two pass encoding and resizing references #22
							
								
								
									
										48
									
								
								src/ast.h
									
									
									
									
									
								
							
							
						
						
									
										48
									
								
								src/ast.h
									
									
									
									
									
								
							@@ -5,6 +5,7 @@
 | 
			
		||||
#include "error.h"
 | 
			
		||||
#include "lexer.h"
 | 
			
		||||
#include "tokenlist.h"
 | 
			
		||||
#include <assert.h>
 | 
			
		||||
#include <stddef.h>
 | 
			
		||||
#include <stdint.h>
 | 
			
		||||
 | 
			
		||||
@@ -75,10 +76,26 @@ typedef struct register_ {
 | 
			
		||||
} register_t;
 | 
			
		||||
 | 
			
		||||
typedef struct opcode_encoding {
 | 
			
		||||
    uint8_t encoding[32];
 | 
			
		||||
    uint8_t buffer[32];
 | 
			
		||||
    size_t len;
 | 
			
		||||
} opcode_encoding_t;
 | 
			
		||||
 | 
			
		||||
typedef struct instruction {
 | 
			
		||||
    bool has_reference;
 | 
			
		||||
    opcode_encoding_t encoding;
 | 
			
		||||
    int64_t address;
 | 
			
		||||
} instruction_t;
 | 
			
		||||
 | 
			
		||||
typedef struct reference {
 | 
			
		||||
    int64_t offset;
 | 
			
		||||
    int64_t address;
 | 
			
		||||
    operand_size_t size;
 | 
			
		||||
} reference_t;
 | 
			
		||||
 | 
			
		||||
typedef struct {
 | 
			
		||||
    int64_t address;
 | 
			
		||||
} label_t;
 | 
			
		||||
 | 
			
		||||
struct ast_node {
 | 
			
		||||
    node_id_t id;
 | 
			
		||||
    tokenlist_entry_t *token_entry;
 | 
			
		||||
@@ -89,10 +106,37 @@ struct ast_node {
 | 
			
		||||
    union {
 | 
			
		||||
        register_t reg;
 | 
			
		||||
        number_t number;
 | 
			
		||||
        opcode_encoding_t encoding;
 | 
			
		||||
        instruction_t instruction;
 | 
			
		||||
        reference_t reference;
 | 
			
		||||
        label_t label;
 | 
			
		||||
    } value;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
static inline register_t *ast_node_register_value(ast_node_t *node) {
 | 
			
		||||
    assert(node->id == NODE_REGISTER);
 | 
			
		||||
    return &node->value.reg;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline number_t *ast_node_number_value(ast_node_t *node) {
 | 
			
		||||
    assert(node->id == NODE_NUMBER);
 | 
			
		||||
    return &node->value.number;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline instruction_t *ast_node_instruction_value(ast_node_t *node) {
 | 
			
		||||
    assert(node->id == NODE_INSTRUCTION);
 | 
			
		||||
    return &node->value.instruction;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline reference_t *ast_node_reference_value(ast_node_t *node) {
 | 
			
		||||
    assert(node->id == NODE_LABEL_REFERENCE);
 | 
			
		||||
    return &node->value.reference;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline label_t *ast_node_label_value(ast_node_t *node) {
 | 
			
		||||
    assert(node->id == NODE_LABEL);
 | 
			
		||||
    return &node->value.label;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Allocates a new AST node
 | 
			
		||||
 *
 | 
			
		||||
 
 | 
			
		||||
@@ -138,8 +138,128 @@ opcode_data_t *const opcodes[] = {
 | 
			
		||||
            { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    // CALL rel32
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "call",
 | 
			
		||||
        .opcode = 0xE8,
 | 
			
		||||
        .opcode_extension = opcode_extension_none,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    // CALL reg64
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "call",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 2,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .rex_w_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    // CALL mem64
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "call",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 2,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .rex_w_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    // JMP rel8 (short jump)
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xEB,
 | 
			
		||||
        .opcode_extension = opcode_extension_none,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP rel16
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xE9,
 | 
			
		||||
        .opcode_extension = opcode_extension_none,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_size_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP reg16
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 4,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_size_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP rel32 (near jump)
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xE9,
 | 
			
		||||
        .opcode_extension = opcode_extension_none,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP reg32
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 4,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP reg64
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 4,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .rex_w_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    // JMP mem64
 | 
			
		||||
    &(opcode_data_t) {
 | 
			
		||||
        .mnemonic = "jmp",
 | 
			
		||||
        .opcode = 0xFF,
 | 
			
		||||
        .opcode_extension = 4,
 | 
			
		||||
        .encoding_class = ENCODING_DEFAULT,
 | 
			
		||||
        .rex_w_prefix = true,
 | 
			
		||||
        .operand_count = 1,
 | 
			
		||||
        .operands = {
 | 
			
		||||
            { .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    nullptr,
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,31 @@
 | 
			
		||||
#include <errno.h>
 | 
			
		||||
#include <string.h>
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * General encoder flow:
 | 
			
		||||
 *
 | 
			
		||||
 * There are 2 major passes the encoder does:
 | 
			
		||||
 *
 | 
			
		||||
 * First pass:
 | 
			
		||||
 *   - Run through the AST and collect information:
 | 
			
		||||
 *     - Set register values
 | 
			
		||||
 *     - Parse/set number values
 | 
			
		||||
 *     - Mark all instructions that use label references
 | 
			
		||||
 *   - Encode all instructions that don't use label references
 | 
			
		||||
 *   - Update addresses of all labels and instructions. Use an estimated
 | 
			
		||||
 *     instruction size for those instructions that use label references.
 | 
			
		||||
 *
 | 
			
		||||
 * Second pass:
 | 
			
		||||
 *   - Run through the AST for all instructions that use label references and
 | 
			
		||||
 *     collect size information using the estimated addresses from pass 1
 | 
			
		||||
 *   - Encode label references with the estimated addresses, this fixes their
 | 
			
		||||
 *     size.
 | 
			
		||||
 *   - Update all addresses
 | 
			
		||||
 *
 | 
			
		||||
 * Iteration:
 | 
			
		||||
 *   - Repeat the second pass until addresses converge
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
error_t *const err_encoder_invalid_register =
 | 
			
		||||
    &(error_t){.message = "Invalid register"};
 | 
			
		||||
error_t *const err_encoder_number_overflow =
 | 
			
		||||
@@ -23,13 +48,15 @@ error_t *const err_encoder_not_implemented =
 | 
			
		||||
error_t *const err_encoder_unexpected_length =
 | 
			
		||||
    &(error_t){.message = "Unexpectedly long encoding"};
 | 
			
		||||
 | 
			
		||||
error_t *encoder_alloc(encoder_t **output) {
 | 
			
		||||
error_t *encoder_alloc(encoder_t **output, ast_node_t *ast) {
 | 
			
		||||
    *output = nullptr;
 | 
			
		||||
    encoder_t *encoder = calloc(1, sizeof(encoder_t));
 | 
			
		||||
 | 
			
		||||
    if (encoder == nullptr)
 | 
			
		||||
        return err_allocation_failed;
 | 
			
		||||
 | 
			
		||||
    encoder->ast = ast;
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_alloc(&encoder->symbols);
 | 
			
		||||
    if (err) {
 | 
			
		||||
        free(encoder);
 | 
			
		||||
@@ -213,16 +240,15 @@ static inline uint8_t modrm_rm(uint8_t modrm, register_id_t id) {
 | 
			
		||||
    return (modrm & ~modrm_rm_mask) | (id & 0b111);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Perform the initial pass over the AST. Records all symbols and sets the
 | 
			
		||||
 * values of registers and numbers.
 | 
			
		||||
 */
 | 
			
		||||
error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) {
 | 
			
		||||
error_t *encoder_collect_info(encoder_t *encoder, ast_node_t *node,
 | 
			
		||||
                              ast_node_t *statement) {
 | 
			
		||||
    error_t *err = nullptr;
 | 
			
		||||
 | 
			
		||||
    if (encoder_is_symbols_node(node))
 | 
			
		||||
        err = symbol_table_update(encoder->symbols, node);
 | 
			
		||||
    else if (node->id == NODE_NUMBER)
 | 
			
		||||
    if (encoder_is_symbols_node(node)) {
 | 
			
		||||
        err = symbol_table_update(encoder->symbols, node, statement);
 | 
			
		||||
        if (statement->id == NODE_INSTRUCTION)
 | 
			
		||||
            statement->value.instruction.has_reference = true;
 | 
			
		||||
    } else if (node->id == NODE_NUMBER)
 | 
			
		||||
        err = encoder_set_number_value(node);
 | 
			
		||||
    else if (node->id == NODE_REGISTER)
 | 
			
		||||
        err = encoder_set_register_value(node);
 | 
			
		||||
@@ -230,7 +256,8 @@ error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) {
 | 
			
		||||
        return err;
 | 
			
		||||
 | 
			
		||||
    for (size_t i = 0; i < node->len; ++i) {
 | 
			
		||||
        error_t *err = encoder_first_pass(encoder, node->children[i]);
 | 
			
		||||
        error_t *err =
 | 
			
		||||
            encoder_collect_info(encoder, node->children[i], statement);
 | 
			
		||||
        if (err)
 | 
			
		||||
            return err;
 | 
			
		||||
    }
 | 
			
		||||
@@ -242,7 +269,7 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
 | 
			
		||||
    switch (info->kind) {
 | 
			
		||||
    case OPERAND_REGISTER:
 | 
			
		||||
        return operand->id == NODE_REGISTER &&
 | 
			
		||||
               operand->value.reg.size == info->size;
 | 
			
		||||
               ast_node_register_value(operand)->size == info->size;
 | 
			
		||||
    case OPERAND_MEMORY:
 | 
			
		||||
        return operand->id == NODE_MEMORY;
 | 
			
		||||
    case OPERAND_IMMEDIATE: {
 | 
			
		||||
@@ -251,13 +278,10 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
 | 
			
		||||
        ast_node_t *child = operand->children[0];
 | 
			
		||||
 | 
			
		||||
        if (child->id == NODE_NUMBER)
 | 
			
		||||
            return (child->value.number.size & info->size) > 0;
 | 
			
		||||
        else if (child->id == NODE_LABEL_REFERENCE)
 | 
			
		||||
            return info->size == OPERAND_SIZE_32;
 | 
			
		||||
        // FIXME: first pass should give us information about the distance of
 | 
			
		||||
        // the label reference so we can pick a size more appropriately instead
 | 
			
		||||
        // of just defaulting to 32 bits
 | 
			
		||||
        break;
 | 
			
		||||
            return (ast_node_number_value(child)->size & info->size) > 0;
 | 
			
		||||
        else if (child->id == NODE_LABEL_REFERENCE) {
 | 
			
		||||
            return info->size &= ast_node_reference_value(child)->size;
 | 
			
		||||
        }
 | 
			
		||||
    } // end OPERAND_IMMEDIATE case
 | 
			
		||||
    }
 | 
			
		||||
    assert(false && "unreachable");
 | 
			
		||||
@@ -313,7 +337,7 @@ error_t *encode_one_register_in_opcode(encoder_t *encoder,
 | 
			
		||||
    (void)encoder;
 | 
			
		||||
    (void)opcode;
 | 
			
		||||
 | 
			
		||||
    register_id_t id = operands->children[0]->value.reg.id;
 | 
			
		||||
    register_id_t id = ast_node_register_value(operands->children[0])->id;
 | 
			
		||||
    encoding->buffer[encoding->len - 1] |= id & 0b111;
 | 
			
		||||
    if ((id & 0b1000) > 0) {
 | 
			
		||||
        *rex |= rex_prefix_r;
 | 
			
		||||
@@ -328,7 +352,7 @@ error_t *encode_one_register(encoder_t *encoder, opcode_data_t *opcode,
 | 
			
		||||
    assert(operands->len == 1);
 | 
			
		||||
    assert(operands->children[0]->id == NODE_REGISTER);
 | 
			
		||||
 | 
			
		||||
    register_id_t id = operands->children[0]->value.reg.id;
 | 
			
		||||
    register_id_t id = ast_node_register_value(operands->children[0])->id;
 | 
			
		||||
 | 
			
		||||
    uint8_t modrm = modrm_mod_register;
 | 
			
		||||
 | 
			
		||||
@@ -362,9 +386,9 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
 | 
			
		||||
    assert(immediate->id == NODE_NUMBER ||
 | 
			
		||||
           immediate->id == NODE_LABEL_REFERENCE);
 | 
			
		||||
 | 
			
		||||
    if (immediate->id == NODE_NUMBER) {
 | 
			
		||||
        uint64_t value = immediate->value.number.value;
 | 
			
		||||
    operand_size_t size = opcode->operands[0].size;
 | 
			
		||||
    if (immediate->id == NODE_NUMBER) {
 | 
			
		||||
        uint64_t value = ast_node_number_value(immediate)->value;
 | 
			
		||||
        error_t *err = nullptr;
 | 
			
		||||
        switch (size) {
 | 
			
		||||
        case OPERAND_SIZE_8:
 | 
			
		||||
@@ -384,11 +408,22 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
 | 
			
		||||
        }
 | 
			
		||||
        return err;
 | 
			
		||||
    } else {
 | 
			
		||||
        // FIXME: this still assumes references are always 32 bit
 | 
			
		||||
        uint32_t value = 0xDEADBEEF;
 | 
			
		||||
        return bytes_append_uint32(encoding, value);
 | 
			
		||||
        reference_t *reference = ast_node_reference_value(immediate);
 | 
			
		||||
        switch (size) {
 | 
			
		||||
        case OPERAND_SIZE_64:
 | 
			
		||||
            return bytes_append_uint64(encoding, reference->address);
 | 
			
		||||
        case OPERAND_SIZE_32:
 | 
			
		||||
            return bytes_append_uint32(encoding, reference->offset);
 | 
			
		||||
        case OPERAND_SIZE_16:
 | 
			
		||||
            return bytes_append_uint16(encoding, reference->offset);
 | 
			
		||||
        case OPERAND_SIZE_8:
 | 
			
		||||
            return bytes_append_uint8(encoding, reference->offset);
 | 
			
		||||
        default:
 | 
			
		||||
            assert(false && "intentionally unhandled");
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    __builtin_unreachable();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode,
 | 
			
		||||
                           ast_node_t *operands, bytes_t *encoding,
 | 
			
		||||
@@ -456,7 +491,8 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
 | 
			
		||||
        return err;
 | 
			
		||||
 | 
			
		||||
    // produce the actual encoding output in the NODE_INSTRUCTION value
 | 
			
		||||
    uint8_t *output = instruction->value.encoding.encoding;
 | 
			
		||||
    instruction_t *instruction_value = ast_node_instruction_value(instruction);
 | 
			
		||||
    uint8_t *output = instruction_value->encoding.buffer;
 | 
			
		||||
    size_t output_len = 0;
 | 
			
		||||
 | 
			
		||||
    // Handle prefixes
 | 
			
		||||
@@ -475,24 +511,166 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
 | 
			
		||||
    memcpy(output + output_len, encoding->buffer, encoding->len);
 | 
			
		||||
    output_len += encoding->len;
 | 
			
		||||
 | 
			
		||||
    instruction->value.encoding.len = output_len;
 | 
			
		||||
    instruction_value->encoding.len = output_len;
 | 
			
		||||
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Perform the second pass that performs actual encoding. Will use
 | 
			
		||||
 * placeholder values for label references because instruction size has not
 | 
			
		||||
 * yet been determined.
 | 
			
		||||
 * Initial guess for instruction size of instructions that contain a label
 | 
			
		||||
 * reference
 | 
			
		||||
 */
 | 
			
		||||
error_t *encoder_encoding_pass(encoder_t *encoder, ast_node_t *root) {
 | 
			
		||||
constexpr size_t instruction_size_estimate = 10;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Perform the initial pass over the AST.
 | 
			
		||||
 *
 | 
			
		||||
 * - Collect information about the operands
 | 
			
		||||
 *   - parse and set number values
 | 
			
		||||
 *   - set the register values
 | 
			
		||||
 *   - determine if label references are used by an instruction
 | 
			
		||||
 * - encode instructions that don't use label references
 | 
			
		||||
 * - determine estimated addresses of each statement
 | 
			
		||||
 *
 | 
			
		||||
 */
 | 
			
		||||
error_t *encoder_first_pass(encoder_t *encoder) {
 | 
			
		||||
    ast_node_t *root = encoder->ast;
 | 
			
		||||
    assert(root->id == NODE_PROGRAM);
 | 
			
		||||
 | 
			
		||||
    uintptr_t address = 0;
 | 
			
		||||
 | 
			
		||||
    for (size_t i = 0; i < root->len; ++i) {
 | 
			
		||||
        if (root->children[i]->id != NODE_INSTRUCTION)
 | 
			
		||||
            continue;
 | 
			
		||||
        ast_node_t *instruction = root->children[i];
 | 
			
		||||
        error_t *err = encoder_encode_instruction(encoder, instruction);
 | 
			
		||||
        ast_node_t *statement = root->children[i];
 | 
			
		||||
        error_t *err = encoder_collect_info(encoder, statement, statement);
 | 
			
		||||
        if (err)
 | 
			
		||||
            return err;
 | 
			
		||||
 | 
			
		||||
        if (statement->id == NODE_INSTRUCTION &&
 | 
			
		||||
            ast_node_instruction_value(statement)->has_reference == false) {
 | 
			
		||||
            err = encoder_encode_instruction(encoder, statement);
 | 
			
		||||
            if (err)
 | 
			
		||||
                return err;
 | 
			
		||||
            instruction_t *instruction = ast_node_instruction_value(statement);
 | 
			
		||||
            instruction->address = address;
 | 
			
		||||
            address += instruction->encoding.len;
 | 
			
		||||
        } else if (statement->id == NODE_INSTRUCTION) {
 | 
			
		||||
            instruction_t *instruction = ast_node_instruction_value(statement);
 | 
			
		||||
            instruction->encoding.len = instruction_size_estimate;
 | 
			
		||||
            instruction->address = address;
 | 
			
		||||
            address += instruction_size_estimate;
 | 
			
		||||
        } else if (statement->id == NODE_LABEL) {
 | 
			
		||||
            label_t *label = ast_node_label_value(statement);
 | 
			
		||||
            label->address = address;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
operand_size_t signed_to_size_mask(int64_t value) {
 | 
			
		||||
    operand_size_t size = OPERAND_SIZE_64;
 | 
			
		||||
 | 
			
		||||
    if (value >= INT8_MIN && value <= INT8_MAX)
 | 
			
		||||
        size |= OPERAND_SIZE_8;
 | 
			
		||||
 | 
			
		||||
    if (value >= INT16_MIN && value <= INT16_MAX)
 | 
			
		||||
        size |= OPERAND_SIZE_16;
 | 
			
		||||
 | 
			
		||||
    if (value >= INT32_MIN && value <= INT32_MAX)
 | 
			
		||||
        size |= OPERAND_SIZE_32;
 | 
			
		||||
 | 
			
		||||
    return size;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int64_t statement_offset(ast_node_t *from, ast_node_t *to) {
 | 
			
		||||
    assert(from->id == NODE_INSTRUCTION);
 | 
			
		||||
    assert(to->id == NODE_LABEL);
 | 
			
		||||
 | 
			
		||||
    instruction_t *instruction = ast_node_instruction_value(from);
 | 
			
		||||
    int64_t from_addr = instruction->address + instruction->encoding.len;
 | 
			
		||||
    int64_t to_addr = ast_node_label_value(to)->address;
 | 
			
		||||
 | 
			
		||||
    return to_addr - from_addr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node,
 | 
			
		||||
                                        ast_node_t *statement) {
 | 
			
		||||
    assert(statement->id == NODE_INSTRUCTION);
 | 
			
		||||
    if (node->id == NODE_LABEL_REFERENCE) {
 | 
			
		||||
        const char *name = node->token_entry->token.value;
 | 
			
		||||
        symbol_t *symbol = symbol_table_lookup(encoder->symbols, name);
 | 
			
		||||
        assert(symbol && symbol->statement &&
 | 
			
		||||
               symbol->statement->id == NODE_LABEL);
 | 
			
		||||
 | 
			
		||||
        int64_t offset = statement_offset(statement, symbol->statement);
 | 
			
		||||
        int64_t absolute = ast_node_label_value(symbol->statement)->address;
 | 
			
		||||
        operand_size_t size = signed_to_size_mask(offset);
 | 
			
		||||
 | 
			
		||||
        node->value.reference.address = absolute;
 | 
			
		||||
        node->value.reference.offset = offset;
 | 
			
		||||
        node->value.reference.size = size;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (size_t i = 0; i < node->len; ++i) {
 | 
			
		||||
        error_t *err = encoder_collect_reference_info(
 | 
			
		||||
            encoder, node->children[i], statement);
 | 
			
		||||
        if (err)
 | 
			
		||||
            return err;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool encoder_should_reencode(ast_node_t *statement) {
 | 
			
		||||
    if (statement->id != NODE_INSTRUCTION)
 | 
			
		||||
        return false;
 | 
			
		||||
 | 
			
		||||
    instruction_t *instruction = ast_node_instruction_value(statement);
 | 
			
		||||
    return instruction->has_reference;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void set_statement_address(ast_node_t *statement, int64_t address) {
 | 
			
		||||
    if (statement->id == NODE_INSTRUCTION) {
 | 
			
		||||
        ast_node_instruction_value(statement)->address = address;
 | 
			
		||||
    } else if (statement->id == NODE_LABEL) {
 | 
			
		||||
        ast_node_label_value(statement)->address = address;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
size_t get_statement_length(ast_node_t *statement) {
 | 
			
		||||
    if (statement->id != NODE_INSTRUCTION)
 | 
			
		||||
        return 0;
 | 
			
		||||
    return ast_node_instruction_value(statement)->encoding.len;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Perform the second pass. Updates the label info and encodes all instructions
 | 
			
		||||
 * that have a label reference.that performs actual encoding.
 | 
			
		||||
 */
 | 
			
		||||
error_t *encoder_second_pass(encoder_t *encoder, bool *did_update) {
 | 
			
		||||
    ast_node_t *root = encoder->ast;
 | 
			
		||||
 | 
			
		||||
    *did_update = false;
 | 
			
		||||
    int64_t address = 0;
 | 
			
		||||
    for (size_t i = 0; i < root->len; ++i) {
 | 
			
		||||
        ast_node_t *statement = root->children[i];
 | 
			
		||||
 | 
			
		||||
        set_statement_address(statement, address);
 | 
			
		||||
        size_t before = get_statement_length(statement);
 | 
			
		||||
 | 
			
		||||
        if (encoder_should_reencode(statement)) {
 | 
			
		||||
            error_t *err =
 | 
			
		||||
                encoder_collect_reference_info(encoder, statement, statement);
 | 
			
		||||
            if (err)
 | 
			
		||||
                return err;
 | 
			
		||||
            err = encoder_encode_instruction(encoder, statement);
 | 
			
		||||
            if (err)
 | 
			
		||||
                return err;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        size_t after = get_statement_length(statement);
 | 
			
		||||
        *did_update = *did_update || (before != after);
 | 
			
		||||
        address += after;
 | 
			
		||||
    }
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
@@ -515,12 +693,19 @@ error_t *encoder_check_symbols(encoder_t *encoder) {
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast) {
 | 
			
		||||
    error_t *err = encoder_first_pass(encoder, ast);
 | 
			
		||||
error_t *encoder_encode(encoder_t *encoder) {
 | 
			
		||||
    error_t *err = encoder_first_pass(encoder);
 | 
			
		||||
    if (err)
 | 
			
		||||
        return err;
 | 
			
		||||
    err = encoder_check_symbols(encoder);
 | 
			
		||||
    if (err)
 | 
			
		||||
        return err;
 | 
			
		||||
    return encoder_encoding_pass(encoder, ast);
 | 
			
		||||
 | 
			
		||||
    bool did_update = true;
 | 
			
		||||
    for (int i = 0; i < 10 && did_update; ++i) {
 | 
			
		||||
        err = encoder_second_pass(encoder, &did_update);
 | 
			
		||||
        if (err)
 | 
			
		||||
            return err;
 | 
			
		||||
    }
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@
 | 
			
		||||
 | 
			
		||||
typedef struct encoder {
 | 
			
		||||
    symbol_table_t *symbols;
 | 
			
		||||
    ast_node_t *ast;
 | 
			
		||||
} encoder_t;
 | 
			
		||||
 | 
			
		||||
constexpr uint8_t modrm_mod_memory = 0b00'000'000;
 | 
			
		||||
@@ -16,8 +17,8 @@ constexpr uint8_t modrm_reg_mask = 0b00'111'000;
 | 
			
		||||
constexpr uint8_t modrm_rm_mask = 0b00'000'111;
 | 
			
		||||
constexpr uint8_t modrm_mod_mask = 0b11'000'000;
 | 
			
		||||
 | 
			
		||||
error_t *encoder_alloc(encoder_t **encoder);
 | 
			
		||||
error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast);
 | 
			
		||||
error_t *encoder_alloc(encoder_t **encoder, ast_node_t *ast);
 | 
			
		||||
error_t *encoder_encode(encoder_t *encoder);
 | 
			
		||||
void encoder_free(encoder_t *encoder);
 | 
			
		||||
 | 
			
		||||
extern error_t *const err_encoder_invalid_register;
 | 
			
		||||
 
 | 
			
		||||
@@ -92,7 +92,7 @@ EXPORT       |           |          |   ERR    |          |
 | 
			
		||||
-------------|-----------|----------|----------|----------|
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
bool symbol_table_should_update(symbol_kind_t old, symbol_kind_t new) {
 | 
			
		||||
bool symbol_table_should_upgrade(symbol_kind_t old, symbol_kind_t new) {
 | 
			
		||||
    if (old == SYMBOL_REFERENCE)
 | 
			
		||||
        return new != SYMBOL_REFERENCE;
 | 
			
		||||
    if (old == SYMBOL_LOCAL)
 | 
			
		||||
@@ -112,7 +112,7 @@ bool symbol_table_should_error(symbol_kind_t old, symbol_kind_t new) {
 | 
			
		||||
 * @pre The symbol _must not_ already be in the table.
 | 
			
		||||
 */
 | 
			
		||||
error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
 | 
			
		||||
                          ast_node_t *node) {
 | 
			
		||||
                          ast_node_t *statement) {
 | 
			
		||||
    if (table->len >= table->cap) {
 | 
			
		||||
        error_t *err = symbol_table_grow_cap(table);
 | 
			
		||||
        if (err)
 | 
			
		||||
@@ -122,7 +122,7 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
 | 
			
		||||
    table->symbols[table->len] = (symbol_t){
 | 
			
		||||
        .name = name,
 | 
			
		||||
        .kind = kind,
 | 
			
		||||
        .node = node,
 | 
			
		||||
        .statement = statement,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    table->len += 1;
 | 
			
		||||
@@ -130,23 +130,29 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node) {
 | 
			
		||||
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
 | 
			
		||||
                             ast_node_t *statement) {
 | 
			
		||||
    char *name;
 | 
			
		||||
    symbol_kind_t kind;
 | 
			
		||||
    error_t *err = symbol_table_get_node_info(node, &kind, &name);
 | 
			
		||||
    if (err)
 | 
			
		||||
        return err;
 | 
			
		||||
 | 
			
		||||
    if (kind != SYMBOL_LOCAL)
 | 
			
		||||
        statement = nullptr;
 | 
			
		||||
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, name);
 | 
			
		||||
    if (!symbol)
 | 
			
		||||
        return symbol_table_add(table, name, kind, node);
 | 
			
		||||
        return symbol_table_add(table, name, kind, statement);
 | 
			
		||||
    if (symbol_table_should_error(symbol->kind, kind))
 | 
			
		||||
        return err_symbol_table_incompatible_symbols;
 | 
			
		||||
    if (symbol_table_should_update(symbol->kind, kind)) {
 | 
			
		||||
        symbol->name = name;
 | 
			
		||||
    if (symbol_table_should_upgrade(symbol->kind, kind)) {
 | 
			
		||||
        symbol->kind = kind;
 | 
			
		||||
        symbol->node = node;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (kind == SYMBOL_LOCAL && symbol->statement == nullptr)
 | 
			
		||||
        symbol->statement = statement;
 | 
			
		||||
 | 
			
		||||
    return nullptr;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -29,7 +29,7 @@ typedef enum symbol_kind {
 | 
			
		||||
typedef struct symbol {
 | 
			
		||||
    char *name;
 | 
			
		||||
    symbol_kind_t kind;
 | 
			
		||||
    ast_node_t *node;
 | 
			
		||||
    ast_node_t *statement;
 | 
			
		||||
} symbol_t;
 | 
			
		||||
 | 
			
		||||
typedef struct symbol_table {
 | 
			
		||||
@@ -40,7 +40,8 @@ typedef struct symbol_table {
 | 
			
		||||
 | 
			
		||||
error_t *symbol_table_alloc(symbol_table_t **table);
 | 
			
		||||
void symbol_table_free(symbol_table_t *table);
 | 
			
		||||
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node);
 | 
			
		||||
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
 | 
			
		||||
                             ast_node_t *statement);
 | 
			
		||||
symbol_t *symbol_table_lookup(symbol_table_t *table, const char *name);
 | 
			
		||||
 | 
			
		||||
#endif // INCLUDE_ENCODER_SYMBOLS_H_
 | 
			
		||||
 
 | 
			
		||||
@@ -74,11 +74,11 @@ error_t *print_encoding(tokenlist_t *list) {
 | 
			
		||||
        return result.err;
 | 
			
		||||
 | 
			
		||||
    encoder_t *encoder;
 | 
			
		||||
    error_t *err = encoder_alloc(&encoder);
 | 
			
		||||
    error_t *err = encoder_alloc(&encoder, result.node);
 | 
			
		||||
    if (err)
 | 
			
		||||
        goto cleanup_ast;
 | 
			
		||||
 | 
			
		||||
    err = encoder_encode(encoder, result.node);
 | 
			
		||||
    err = encoder_encode(encoder);
 | 
			
		||||
    if (err)
 | 
			
		||||
        goto cleanup_ast;
 | 
			
		||||
 | 
			
		||||
@@ -88,7 +88,8 @@ error_t *print_encoding(tokenlist_t *list) {
 | 
			
		||||
        if (node->id != NODE_INSTRUCTION)
 | 
			
		||||
            continue;
 | 
			
		||||
 | 
			
		||||
        print_hex(node->value.encoding.len, node->value.encoding.encoding);
 | 
			
		||||
        print_hex(node->value.instruction.encoding.len,
 | 
			
		||||
                  node->value.instruction.encoding.buffer);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    encoder_free(encoder);
 | 
			
		||||
 
 | 
			
		||||
@@ -89,7 +89,8 @@ parse_result_t parse_immediate(tokenlist_entry_t *current) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
parse_result_t parse_memory_expression(tokenlist_entry_t *current) {
 | 
			
		||||
    parser_t parsers[] = {parse_register_expression, parse_identifier, nullptr};
 | 
			
		||||
    parser_t parsers[] = {parse_register_expression, parse_label_reference,
 | 
			
		||||
                          nullptr};
 | 
			
		||||
    return parse_any(current, parsers);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										140
									
								
								tests/symbols.c
									
									
									
									
									
								
							
							
						
						
									
										140
									
								
								tests/symbols.c
									
									
									
									
									
								
							@@ -58,17 +58,19 @@ MunitResult test_symbol_add_reference(const MunitParameter params[], void *data)
 | 
			
		||||
    symbol_table_alloc(&table);
 | 
			
		||||
 | 
			
		||||
    ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
 | 
			
		||||
    ast_node_t *statement = root->children[3]; // The containing statement
 | 
			
		||||
    munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE);
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_update(table, reference);
 | 
			
		||||
    error_t *err = symbol_table_update(table, reference, statement);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, "test");
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
    munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind);
 | 
			
		||||
    munit_assert_ptr_equal(reference, symbol->node);
 | 
			
		||||
    // For references, the statement should be nullptr
 | 
			
		||||
    munit_assert_ptr_null(symbol->statement);
 | 
			
		||||
    munit_assert_string_equal(symbol->name, "test");
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
@@ -90,14 +92,14 @@ MunitResult test_symbol_add_label(const MunitParameter params[], void *data) {
 | 
			
		||||
    munit_assert_int(label->id, ==, NODE_LABEL);
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_update(table, label);
 | 
			
		||||
    error_t *err = symbol_table_update(table, label, label);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, "test");
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
    munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
 | 
			
		||||
    munit_assert_ptr_equal(label, symbol->node);
 | 
			
		||||
    munit_assert_ptr_equal(label, symbol->statement);
 | 
			
		||||
    munit_assert_string_equal(symbol->name, "test");
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
@@ -116,17 +118,19 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
 | 
			
		||||
    symbol_table_alloc(&table);
 | 
			
		||||
 | 
			
		||||
    ast_node_t *import_directive = root->children[0]->children[1];
 | 
			
		||||
    ast_node_t *statement = root->children[0]; // The containing statement
 | 
			
		||||
    munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE);
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_update(table, import_directive);
 | 
			
		||||
    error_t *err = symbol_table_update(table, import_directive, statement);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, "test");
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
    munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind);
 | 
			
		||||
    munit_assert_ptr_equal(import_directive, symbol->node);
 | 
			
		||||
    // For import directives, the statement should be nullptr
 | 
			
		||||
    munit_assert_ptr_null(symbol->statement);
 | 
			
		||||
    munit_assert_string_equal(symbol->name, "test");
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
@@ -135,42 +139,56 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
 | 
			
		||||
    return MUNIT_OK;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *second,
 | 
			
		||||
                        symbol_kind_t second_kind, bool should_succeed, bool should_update) {
 | 
			
		||||
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *first_statement,
 | 
			
		||||
                        ast_node_t *second, symbol_kind_t second_kind, ast_node_t *second_statement,
 | 
			
		||||
                        bool should_succeed, bool should_update, ast_node_t *expected_statement) {
 | 
			
		||||
    symbol_table_t *table = nullptr;
 | 
			
		||||
    symbol_table_alloc(&table);
 | 
			
		||||
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
    error_t *err = symbol_table_update(table, first);
 | 
			
		||||
    // Add the first symbol
 | 
			
		||||
    error_t *err = symbol_table_update(table, first, first_statement);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    // Verify first symbol state
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, name);
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
    munit_assert_int(first_kind, ==, symbol->kind);
 | 
			
		||||
    munit_assert_ptr_equal(first, symbol->node);
 | 
			
		||||
    munit_assert_string_equal(symbol->name, name);
 | 
			
		||||
 | 
			
		||||
    err = symbol_table_update(table, second);
 | 
			
		||||
    if (should_succeed)
 | 
			
		||||
        munit_assert_null(err);
 | 
			
		||||
    else
 | 
			
		||||
        munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    symbol = symbol_table_lookup(table, name);
 | 
			
		||||
    if (should_update) {
 | 
			
		||||
        munit_assert_not_null(symbol);
 | 
			
		||||
        munit_assert_int(second_kind, ==, symbol->kind);
 | 
			
		||||
        munit_assert_ptr_equal(second, symbol->node);
 | 
			
		||||
        munit_assert_string_equal(symbol->name, name);
 | 
			
		||||
    // Check statement based on symbol kind
 | 
			
		||||
    if (first_kind == SYMBOL_LOCAL) {
 | 
			
		||||
        munit_assert_ptr_equal(first_statement, symbol->statement);
 | 
			
		||||
    } else {
 | 
			
		||||
        munit_assert_not_null(symbol);
 | 
			
		||||
        munit_assert_int(first_kind, ==, symbol->kind);
 | 
			
		||||
        munit_assert_ptr_equal(first, symbol->node);
 | 
			
		||||
        munit_assert_string_equal(symbol->name, name);
 | 
			
		||||
        munit_assert_ptr_null(symbol->statement);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Attempt the second update
 | 
			
		||||
    err = symbol_table_update(table, second, second_statement);
 | 
			
		||||
 | 
			
		||||
    // Check if update succeeded as expected
 | 
			
		||||
    if (should_succeed) {
 | 
			
		||||
        munit_assert_null(err);
 | 
			
		||||
    } else {
 | 
			
		||||
        munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
 | 
			
		||||
        symbol_table_free(table);
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Verify symbol after second update
 | 
			
		||||
    symbol = symbol_table_lookup(table, name);
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
 | 
			
		||||
    // Check if kind updated as expected
 | 
			
		||||
    if (should_update) {
 | 
			
		||||
        munit_assert_int(second_kind, ==, symbol->kind);
 | 
			
		||||
    } else {
 | 
			
		||||
        munit_assert_int(first_kind, ==, symbol->kind);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Simply check against the expected statement value
 | 
			
		||||
    munit_assert_ptr_equal(expected_statement, symbol->statement);
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -181,28 +199,43 @@ MunitResult test_symbol_upgrade_valid(const MunitParameter params[], void *data)
 | 
			
		||||
    symbols_setup_test(&root, &list, "tests/input/symbols.asm");
 | 
			
		||||
 | 
			
		||||
    ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
 | 
			
		||||
    ast_node_t *reference_statement = root->children[3];
 | 
			
		||||
    ast_node_t *label = root->children[2];
 | 
			
		||||
    ast_node_t *import_directive = root->children[0]->children[1];
 | 
			
		||||
    ast_node_t *import_statement = root->children[0];
 | 
			
		||||
    ast_node_t *export_directive = root->children[1]->children[1];
 | 
			
		||||
    ast_node_t *export_statement = root->children[1];
 | 
			
		||||
 | 
			
		||||
    // real upgrades
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, label, SYMBOL_LOCAL, true, true);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, import_directive, SYMBOL_IMPORT, true, true);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, export_directive, SYMBOL_EXPORT, true, true);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, export_directive, SYMBOL_EXPORT, true, true);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, label, SYMBOL_LOCAL, label, true, true,
 | 
			
		||||
                       label);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, import_directive, SYMBOL_IMPORT,
 | 
			
		||||
                       import_statement, true, true, nullptr);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, export_directive, SYMBOL_EXPORT,
 | 
			
		||||
                       export_statement, true, true, nullptr);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, label, export_directive, SYMBOL_EXPORT, export_statement, true,
 | 
			
		||||
                       true, label);
 | 
			
		||||
 | 
			
		||||
    // identity upgrades
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, reference, SYMBOL_REFERENCE, true, false);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, label, SYMBOL_LOCAL, true, false);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_directive, SYMBOL_IMPORT, true, false);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_directive, SYMBOL_EXPORT, true, false);
 | 
			
		||||
    test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, reference, SYMBOL_REFERENCE,
 | 
			
		||||
                       reference_statement, true, false, nullptr);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, label, label, SYMBOL_LOCAL, label, true, false, label);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, import_directive, SYMBOL_IMPORT,
 | 
			
		||||
                       import_statement, true, false, nullptr);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, export_directive, SYMBOL_EXPORT,
 | 
			
		||||
                       export_statement, true, false, nullptr);
 | 
			
		||||
 | 
			
		||||
    // downgrades that are allowed and ignored
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, reference, SYMBOL_REFERENCE, true, false);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, reference, SYMBOL_REFERENCE, true, false);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, reference, SYMBOL_REFERENCE, true, false);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, label, SYMBOL_LOCAL, true, false);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, label, SYMBOL_LOCAL, true, false);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, label, reference, SYMBOL_REFERENCE, reference_statement, true,
 | 
			
		||||
                       false, label);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, reference, SYMBOL_REFERENCE,
 | 
			
		||||
                       reference_statement, true, false, nullptr);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, reference, SYMBOL_REFERENCE,
 | 
			
		||||
                       reference_statement, true, false, nullptr);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, label, SYMBOL_LOCAL, label, true,
 | 
			
		||||
                       false, label);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, label, SYMBOL_LOCAL, label, true,
 | 
			
		||||
                       false, label);
 | 
			
		||||
 | 
			
		||||
    ast_node_free(root);
 | 
			
		||||
    tokenlist_free(list);
 | 
			
		||||
@@ -216,14 +249,20 @@ MunitResult test_symbol_upgrade_invalid(const MunitParameter params[], void *dat
 | 
			
		||||
    symbols_setup_test(&root, &list, "tests/input/symbols.asm");
 | 
			
		||||
 | 
			
		||||
    ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
 | 
			
		||||
    ast_node_t *reference_statement = root->children[3];
 | 
			
		||||
    ast_node_t *label = root->children[2];
 | 
			
		||||
    ast_node_t *import_directive = root->children[0]->children[1];
 | 
			
		||||
    ast_node_t *import_statement = root->children[0];
 | 
			
		||||
    ast_node_t *export_directive = root->children[1]->children[1];
 | 
			
		||||
    ast_node_t *export_statement = root->children[1];
 | 
			
		||||
 | 
			
		||||
    // invalid upgrades
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, import_directive, SYMBOL_IMPORT, false, false);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, import_directive, SYMBOL_IMPORT, false, false);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, export_directive, SYMBOL_EXPORT, false, false);
 | 
			
		||||
    test_symbol_update("test", label, SYMBOL_LOCAL, label, import_directive, SYMBOL_IMPORT, import_statement, false,
 | 
			
		||||
                       false, nullptr);
 | 
			
		||||
    test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, import_directive, SYMBOL_IMPORT,
 | 
			
		||||
                       import_statement, false, false, nullptr);
 | 
			
		||||
    test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, export_directive, SYMBOL_EXPORT,
 | 
			
		||||
                       export_statement, false, false, nullptr);
 | 
			
		||||
 | 
			
		||||
    ast_node_free(root);
 | 
			
		||||
    tokenlist_free(list);
 | 
			
		||||
@@ -240,17 +279,19 @@ MunitResult test_symbol_add_export(const MunitParameter params[], void *data) {
 | 
			
		||||
    symbol_table_alloc(&table);
 | 
			
		||||
 | 
			
		||||
    ast_node_t *export_directive = root->children[1]->children[1];
 | 
			
		||||
    ast_node_t *statement = root->children[1]; // The containing statement
 | 
			
		||||
    munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE);
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_update(table, export_directive);
 | 
			
		||||
    error_t *err = symbol_table_update(table, export_directive, statement);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 1);
 | 
			
		||||
 | 
			
		||||
    symbol_t *symbol = symbol_table_lookup(table, "test");
 | 
			
		||||
    munit_assert_not_null(symbol);
 | 
			
		||||
    munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind);
 | 
			
		||||
    munit_assert_ptr_equal(export_directive, symbol->node);
 | 
			
		||||
    // For export directives, the statement should be nullptr
 | 
			
		||||
    munit_assert_ptr_null(symbol->statement);
 | 
			
		||||
    munit_assert_string_equal(symbol->name, "test");
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
@@ -280,7 +321,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
 | 
			
		||||
        ast_node_t *label = root->children[i];
 | 
			
		||||
        munit_assert_int(label->id, ==, NODE_LABEL);
 | 
			
		||||
 | 
			
		||||
        error_t *err = symbol_table_update(table, label);
 | 
			
		||||
        error_t *err = symbol_table_update(table, label, label);
 | 
			
		||||
        munit_assert_null(err);
 | 
			
		||||
        munit_assert_size(table->len, ==, i + 1);
 | 
			
		||||
 | 
			
		||||
@@ -292,7 +333,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
 | 
			
		||||
    ast_node_t *final_label = root->children[64];
 | 
			
		||||
    munit_assert_int(final_label->id, ==, NODE_LABEL);
 | 
			
		||||
 | 
			
		||||
    error_t *err = symbol_table_update(table, final_label);
 | 
			
		||||
    error_t *err = symbol_table_update(table, final_label, final_label);
 | 
			
		||||
    munit_assert_null(err);
 | 
			
		||||
    munit_assert_size(table->len, ==, 65);
 | 
			
		||||
 | 
			
		||||
@@ -308,6 +349,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
 | 
			
		||||
        munit_assert_not_null(symbol);
 | 
			
		||||
        munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
 | 
			
		||||
        munit_assert_string_equal(symbol->name, name);
 | 
			
		||||
        munit_assert_ptr_equal(symbol->statement, root->children[i]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    symbol_table_free(table);
 | 
			
		||||
@@ -326,7 +368,7 @@ MunitResult test_symbol_invalid_node(const MunitParameter params[], void *data)
 | 
			
		||||
    symbol_table_alloc(&table);
 | 
			
		||||
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
    error_t *err = symbol_table_update(table, root);
 | 
			
		||||
    error_t *err = symbol_table_update(table, root, root);
 | 
			
		||||
    munit_assert_ptr_equal(err, err_symbol_table_invalid_node);
 | 
			
		||||
    munit_assert_size(table->len, ==, 0);
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user