diff --git a/src/ast.h b/src/ast.h index 12c9b79..8e4fc42 100644 --- a/src/ast.h +++ b/src/ast.h @@ -5,6 +5,7 @@ #include "error.h" #include "lexer.h" #include "tokenlist.h" +#include #include #include @@ -75,10 +76,26 @@ typedef struct register_ { } register_t; typedef struct opcode_encoding { - uint8_t encoding[32]; + uint8_t buffer[32]; size_t len; } opcode_encoding_t; +typedef struct instruction { + bool has_reference; + opcode_encoding_t encoding; + int64_t address; +} instruction_t; + +typedef struct reference { + int64_t offset; + int64_t address; + operand_size_t size; +} reference_t; + +typedef struct { + int64_t address; +} label_t; + struct ast_node { node_id_t id; tokenlist_entry_t *token_entry; @@ -89,10 +106,37 @@ struct ast_node { union { register_t reg; number_t number; - opcode_encoding_t encoding; + instruction_t instruction; + reference_t reference; + label_t label; } value; }; +static inline register_t *ast_node_register_value(ast_node_t *node) { + assert(node->id == NODE_REGISTER); + return &node->value.reg; +} + +static inline number_t *ast_node_number_value(ast_node_t *node) { + assert(node->id == NODE_NUMBER); + return &node->value.number; +} + +static inline instruction_t *ast_node_instruction_value(ast_node_t *node) { + assert(node->id == NODE_INSTRUCTION); + return &node->value.instruction; +} + +static inline reference_t *ast_node_reference_value(ast_node_t *node) { + assert(node->id == NODE_LABEL_REFERENCE); + return &node->value.reference; +} + +static inline label_t *ast_node_label_value(ast_node_t *node) { + assert(node->id == NODE_LABEL); + return &node->value.label; +} + /** * @brief Allocates a new AST node * diff --git a/src/data/opcodes.c b/src/data/opcodes.c index f74f68a..d793d69 100644 --- a/src/data/opcodes.c +++ b/src/data/opcodes.c @@ -138,8 +138,128 @@ opcode_data_t *const opcodes[] = { { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 }, }, }, + // CALL rel32 + &(opcode_data_t) { + .mnemonic = "call", + .opcode = 0xE8, 
+ .opcode_extension = opcode_extension_none, + .encoding_class = ENCODING_DEFAULT, + .operand_count = 1, + .operands = { + { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 }, + }, + }, + // CALL reg64 + &(opcode_data_t) { + .mnemonic = "call", + .opcode = 0xFF, + .opcode_extension = 2, + .encoding_class = ENCODING_DEFAULT, + .rex_w_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 }, + }, + }, + // CALL mem64 + &(opcode_data_t) { + .mnemonic = "call", + .opcode = 0xFF, + .opcode_extension = 2, + .encoding_class = ENCODING_DEFAULT, + .rex_w_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 }, + }, + }, + // JMP rel8 (short jump) + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xEB, + .opcode_extension = opcode_extension_none, + .encoding_class = ENCODING_DEFAULT, + .operand_count = 1, + .operands = { + { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8 }, + }, + }, + // JMP rel16 + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xE9, + .opcode_extension = opcode_extension_none, + .encoding_class = ENCODING_DEFAULT, + .operand_size_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 }, + }, + }, + // JMP reg16 + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xFF, + .opcode_extension = 4, + .encoding_class = ENCODING_DEFAULT, + .operand_size_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 }, + }, + }, + + // JMP rel32 (near jump) + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xE9, + .opcode_extension = opcode_extension_none, + .encoding_class = ENCODING_DEFAULT, + .operand_count = 1, + .operands = { + { .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 }, + }, + }, + + // JMP reg32 + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xFF, + .opcode_extension = 4, + .encoding_class = ENCODING_DEFAULT, + 
.operand_count = 1, + .operands = { + { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 }, + }, + }, + + // JMP reg64 + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xFF, + .opcode_extension = 4, + .encoding_class = ENCODING_DEFAULT, + .rex_w_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 }, + }, + }, + + // JMP mem64 + &(opcode_data_t) { + .mnemonic = "jmp", + .opcode = 0xFF, + .opcode_extension = 4, + .encoding_class = ENCODING_DEFAULT, + .rex_w_prefix = true, + .operand_count = 1, + .operands = { + { .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 }, + }, + }, nullptr, }; diff --git a/src/encoder/encoder.c b/src/encoder/encoder.c index d3b999a..ea462c4 100644 --- a/src/encoder/encoder.c +++ b/src/encoder/encoder.c @@ -6,6 +6,31 @@ #include #include +/** + * General encoder flow: + * + * There are 2 major passes the encoder does: + * + * First pass: + * - Run through the AST and collect information: + * - Set register values + * - Parse/set number values + * - Mark all instructions that use label references + * - Encode all instructions that don't use label references + * - Update addresses of all labels and instructions. Use an estimated + * instruction size for those instructions that use label references. + * + * Second pass: + * - Run through the AST for all instructions that use label references and + * collect size information using the estimated addresses from pass 1 + * - Encode label references with the estimated addresses, this fixes their + * size. 
+ * - Update all addresses + * + * Iteration: + * - Repeat the second pass until addresses converge + */ + error_t *const err_encoder_invalid_register = &(error_t){.message = "Invalid register"}; error_t *const err_encoder_number_overflow = @@ -23,13 +48,15 @@ error_t *const err_encoder_not_implemented = error_t *const err_encoder_unexpected_length = &(error_t){.message = "Unexpectedly long encoding"}; -error_t *encoder_alloc(encoder_t **output) { +error_t *encoder_alloc(encoder_t **output, ast_node_t *ast) { *output = nullptr; encoder_t *encoder = calloc(1, sizeof(encoder_t)); if (encoder == nullptr) return err_allocation_failed; + encoder->ast = ast; + error_t *err = symbol_table_alloc(&encoder->symbols); if (err) { free(encoder); @@ -213,16 +240,15 @@ static inline uint8_t modrm_rm(uint8_t modrm, register_id_t id) { return (modrm & ~modrm_rm_mask) | (id & 0b111); } -/** - * Perform the initial pass over the AST. Records all symbols and sets the - * values of registers and numbers. - */ -error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) { +error_t *encoder_collect_info(encoder_t *encoder, ast_node_t *node, + ast_node_t *statement) { error_t *err = nullptr; - if (encoder_is_symbols_node(node)) - err = symbol_table_update(encoder->symbols, node); - else if (node->id == NODE_NUMBER) + if (encoder_is_symbols_node(node)) { + err = symbol_table_update(encoder->symbols, node, statement); + if (statement->id == NODE_INSTRUCTION) + statement->value.instruction.has_reference = true; + } else if (node->id == NODE_NUMBER) err = encoder_set_number_value(node); else if (node->id == NODE_REGISTER) err = encoder_set_register_value(node); @@ -230,7 +256,8 @@ error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) { return err; for (size_t i = 0; i < node->len; ++i) { - error_t *err = encoder_first_pass(encoder, node->children[i]); + error_t *err = + encoder_collect_info(encoder, node->children[i], statement); if (err) return err; } @@ -242,7 +269,7 @@ 
bool is_operand_match(operand_info_t *info, ast_node_t *operand) { switch (info->kind) { case OPERAND_REGISTER: return operand->id == NODE_REGISTER && - operand->value.reg.size == info->size; + ast_node_register_value(operand)->size == info->size; case OPERAND_MEMORY: return operand->id == NODE_MEMORY; case OPERAND_IMMEDIATE: { @@ -251,13 +278,10 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) { ast_node_t *child = operand->children[0]; if (child->id == NODE_NUMBER) - return (child->value.number.size & info->size) > 0; - else if (child->id == NODE_LABEL_REFERENCE) - return info->size == OPERAND_SIZE_32; - // FIXME: first pass should give us information about the distance of - // the label reference so we can pick a size more appropriately instead - // of just defaulting to 32 bits - break; + return (ast_node_number_value(child)->size & info->size) > 0; + else if (child->id == NODE_LABEL_REFERENCE) + // Only test the mask: `&=` would overwrite the shared opcode table. + return (info->size & ast_node_reference_value(child)->size) > 0; } // end OPERAND_IMMEDIATE case } assert(false && "unreachable"); @@ -313,7 +337,7 @@ error_t *encode_one_register_in_opcode(encoder_t *encoder, (void)encoder; (void)opcode; - register_id_t id = operands->children[0]->value.reg.id; + register_id_t id = ast_node_register_value(operands->children[0])->id; encoding->buffer[encoding->len - 1] |= id & 0b111; if ((id & 0b1000) > 0) { *rex |= rex_prefix_r; @@ -328,7 +352,7 @@ error_t *encode_one_register(encoder_t *encoder, opcode_data_t *opcode, assert(operands->len == 1); assert(operands->children[0]->id == NODE_REGISTER); - register_id_t id = operands->children[0]->value.reg.id; + register_id_t id = ast_node_register_value(operands->children[0])->id; uint8_t modrm = modrm_mod_register; @@ -362,9 +386,9 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode, assert(immediate->id == NODE_NUMBER || immediate->id == NODE_LABEL_REFERENCE); + operand_size_t size = opcode->operands[0].size; if (immediate->id == NODE_NUMBER) { 
- uint64_t value = immediate->value.number.value; - operand_size_t size = opcode->operands[0].size; + uint64_t value = ast_node_number_value(immediate)->value; error_t *err = nullptr; switch (size) { case OPERAND_SIZE_8: @@ -384,10 +408,21 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode, } return err; } else { - // FIXME: this still assumes references are always 32 bit - uint32_t value = 0xDEADBEEF; - return bytes_append_uint32(encoding, value); + reference_t *reference = ast_node_reference_value(immediate); + switch (size) { + case OPERAND_SIZE_64: + return bytes_append_uint64(encoding, reference->address); + case OPERAND_SIZE_32: + return bytes_append_uint32(encoding, reference->offset); + case OPERAND_SIZE_16: + return bytes_append_uint16(encoding, reference->offset); + case OPERAND_SIZE_8: + return bytes_append_uint8(encoding, reference->offset); + default: + assert(false && "intentionally unhandled"); + } } + __builtin_unreachable(); } error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode, @@ -456,7 +491,8 @@ error_t *encoder_encode_instruction(encoder_t *encoder, return err; // produce the actual encoding output in the NODE_INSTRUCTION value - uint8_t *output = instruction->value.encoding.encoding; + instruction_t *instruction_value = ast_node_instruction_value(instruction); + uint8_t *output = instruction_value->encoding.buffer; size_t output_len = 0; // Handle prefixes @@ -475,24 +511,166 @@ error_t *encoder_encode_instruction(encoder_t *encoder, memcpy(output + output_len, encoding->buffer, encoding->len); output_len += encoding->len; - instruction->value.encoding.len = output_len; + instruction_value->encoding.len = output_len; return nullptr; } /** - * Perform the second pass that performs actual encoding. Will use - * placeholder values for label references because instruction size has not - * yet been determined. 
+ * Initial guess for instruction size of instructions that contain a label + * reference */ -error_t *encoder_encoding_pass(encoder_t *encoder, ast_node_t *root) { +constexpr size_t instruction_size_estimate = 10; + +/** + * Perform the initial pass over the AST. + * + * - Collect information about the operands + * - parse and set number values + * - set the register values + * - determine if label references are used by an instruction + * - encode instructions that don't use label references + * - determine estimated addresses of each statement + * + */ +error_t *encoder_first_pass(encoder_t *encoder) { + ast_node_t *root = encoder->ast; + assert(root->id == NODE_PROGRAM); + + int64_t address = 0; + for (size_t i = 0; i < root->len; ++i) { - if (root->children[i]->id != NODE_INSTRUCTION) - continue; - ast_node_t *instruction = root->children[i]; - error_t *err = encoder_encode_instruction(encoder, instruction); + ast_node_t *statement = root->children[i]; + error_t *err = encoder_collect_info(encoder, statement, statement); if (err) return err; + + if (statement->id == NODE_INSTRUCTION && + ast_node_instruction_value(statement)->has_reference == false) { + err = encoder_encode_instruction(encoder, statement); + if (err) + return err; + instruction_t *instruction = ast_node_instruction_value(statement); + instruction->address = address; + address += instruction->encoding.len; + } else if (statement->id == NODE_INSTRUCTION) { + instruction_t *instruction = ast_node_instruction_value(statement); + instruction->encoding.len = instruction_size_estimate; + instruction->address = address; + address += instruction_size_estimate; + } else if (statement->id == NODE_LABEL) { + label_t *label = ast_node_label_value(statement); + label->address = address; + } + } + + return nullptr; +} + +operand_size_t signed_to_size_mask(int64_t value) { + operand_size_t size = OPERAND_SIZE_64; + + if (value >= INT8_MIN && value <= INT8_MAX) + size |= OPERAND_SIZE_8; + + if (value >= 
INT16_MIN && value <= INT16_MAX) + size |= OPERAND_SIZE_16; + + if (value >= INT32_MIN && value <= INT32_MAX) + size |= OPERAND_SIZE_32; + + return size; +} + +int64_t statement_offset(ast_node_t *from, ast_node_t *to) { + assert(from->id == NODE_INSTRUCTION); + assert(to->id == NODE_LABEL); + + instruction_t *instruction = ast_node_instruction_value(from); + int64_t from_addr = instruction->address + instruction->encoding.len; + int64_t to_addr = ast_node_label_value(to)->address; + + return to_addr - from_addr; +} + +error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node, + ast_node_t *statement) { + assert(statement->id == NODE_INSTRUCTION); + if (node->id == NODE_LABEL_REFERENCE) { + const char *name = node->token_entry->token.value; + symbol_t *symbol = symbol_table_lookup(encoder->symbols, name); + assert(symbol && symbol->statement && + symbol->statement->id == NODE_LABEL); + + int64_t offset = statement_offset(statement, symbol->statement); + int64_t absolute = ast_node_label_value(symbol->statement)->address; + operand_size_t size = signed_to_size_mask(offset); + + node->value.reference.address = absolute; + node->value.reference.offset = offset; + node->value.reference.size = size; + } + + for (size_t i = 0; i < node->len; ++i) { + error_t *err = encoder_collect_reference_info( + encoder, node->children[i], statement); + if (err) + return err; + } + + return nullptr; +} + +bool encoder_should_reencode(ast_node_t *statement) { + if (statement->id != NODE_INSTRUCTION) + return false; + + instruction_t *instruction = ast_node_instruction_value(statement); + return instruction->has_reference; +} + +void set_statement_address(ast_node_t *statement, int64_t address) { + if (statement->id == NODE_INSTRUCTION) { + ast_node_instruction_value(statement)->address = address; + } else if (statement->id == NODE_LABEL) { + ast_node_label_value(statement)->address = address; + } +} + +size_t get_statement_length(ast_node_t *statement) { + if 
(statement->id != NODE_INSTRUCTION) + return 0; + return ast_node_instruction_value(statement)->encoding.len; +} + +/** + * Perform the second pass. Updates the label info and encodes all instructions + * that have a label reference. + */ +error_t *encoder_second_pass(encoder_t *encoder, bool *did_update) { + ast_node_t *root = encoder->ast; + + *did_update = false; + int64_t address = 0; + for (size_t i = 0; i < root->len; ++i) { + ast_node_t *statement = root->children[i]; + + set_statement_address(statement, address); + size_t before = get_statement_length(statement); + + if (encoder_should_reencode(statement)) { + error_t *err = + encoder_collect_reference_info(encoder, statement, statement); + if (err) + return err; + err = encoder_encode_instruction(encoder, statement); + if (err) + return err; + } + + size_t after = get_statement_length(statement); + *did_update = *did_update || (before != after); + address += after; } return nullptr; } @@ -515,12 +693,19 @@ error_t *encoder_check_symbols(encoder_t *encoder) { return nullptr; } -error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast) { - error_t *err = encoder_first_pass(encoder, ast); +error_t *encoder_encode(encoder_t *encoder) { + error_t *err = encoder_first_pass(encoder); if (err) return err; err = encoder_check_symbols(encoder); if (err) return err; - return encoder_encoding_pass(encoder, ast); + + bool did_update = true; + for (int i = 0; i < 10 && did_update; ++i) { + err = encoder_second_pass(encoder, &did_update); + if (err) + return err; + } + return nullptr; } diff --git a/src/encoder/encoder.h b/src/encoder/encoder.h index 45d34d7..39311bb 100644 --- a/src/encoder/encoder.h +++ b/src/encoder/encoder.h @@ -5,6 +5,7 @@ typedef struct encoder { symbol_table_t *symbols; + ast_node_t *ast; } encoder_t; constexpr uint8_t modrm_mod_memory = 0b00'000'000; @@ -16,8 +16,8 @@ constexpr uint8_t modrm_reg_mask = 0b00'111'000; constexpr uint8_t modrm_rm_mask = 0b00'000'111; 
constexpr uint8_t modrm_mod_mask = 0b11'000'000; -error_t *encoder_alloc(encoder_t **encoder); -error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast); +error_t *encoder_alloc(encoder_t **encoder, ast_node_t *ast); +error_t *encoder_encode(encoder_t *encoder); void encoder_free(encoder_t *encoder); extern error_t *const err_encoder_invalid_register; diff --git a/src/encoder/symbols.c b/src/encoder/symbols.c index 29f5330..0095fce 100644 --- a/src/encoder/symbols.c +++ b/src/encoder/symbols.c @@ -92,7 +92,7 @@ EXPORT | | | ERR | | -------------|-----------|----------|----------|----------| */ -bool symbol_table_should_update(symbol_kind_t old, symbol_kind_t new) { +bool symbol_table_should_upgrade(symbol_kind_t old, symbol_kind_t new) { if (old == SYMBOL_REFERENCE) return new != SYMBOL_REFERENCE; if (old == SYMBOL_LOCAL) @@ -112,7 +112,7 @@ bool symbol_table_should_error(symbol_kind_t old, symbol_kind_t new) { * @pre The symbol _must not_ already be in the table. */ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, - ast_node_t *node) { + ast_node_t *statement) { if (table->len >= table->cap) { error_t *err = symbol_table_grow_cap(table); if (err) @@ -122,7 +122,7 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, table->symbols[table->len] = (symbol_t){ .name = name, .kind = kind, - .node = node, + .statement = statement, }; table->len += 1; @@ -130,23 +130,29 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, return nullptr; } -error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node) { +error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node, + ast_node_t *statement) { char *name; symbol_kind_t kind; error_t *err = symbol_table_get_node_info(node, &kind, &name); if (err) return err; + if (kind != SYMBOL_LOCAL) + statement = nullptr; + symbol_t *symbol = symbol_table_lookup(table, name); if (!symbol) - return symbol_table_add(table, name, 
kind, node); + return symbol_table_add(table, name, kind, statement); if (symbol_table_should_error(symbol->kind, kind)) return err_symbol_table_incompatible_symbols; - if (symbol_table_should_update(symbol->kind, kind)) { - symbol->name = name; + if (symbol_table_should_upgrade(symbol->kind, kind)) { symbol->kind = kind; - symbol->node = node; } + + if (kind == SYMBOL_LOCAL && symbol->statement == nullptr) + symbol->statement = statement; + return nullptr; } diff --git a/src/encoder/symbols.h b/src/encoder/symbols.h index 9c4e7f7..ba0c144 100644 --- a/src/encoder/symbols.h +++ b/src/encoder/symbols.h @@ -29,7 +29,7 @@ typedef enum symbol_kind { typedef struct symbol { char *name; symbol_kind_t kind; - ast_node_t *node; + ast_node_t *statement; } symbol_t; typedef struct symbol_table { @@ -40,7 +40,8 @@ typedef struct symbol_table { error_t *symbol_table_alloc(symbol_table_t **table); void symbol_table_free(symbol_table_t *table); -error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node); +error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node, + ast_node_t *statement); symbol_t *symbol_table_lookup(symbol_table_t *table, const char *name); #endif // INCLUDE_ENCODER_SYMBOLS_H_ diff --git a/src/main.c b/src/main.c index fb00342..93e3aff 100644 --- a/src/main.c +++ b/src/main.c @@ -74,11 +74,11 @@ error_t *print_encoding(tokenlist_t *list) { return result.err; encoder_t *encoder; - error_t *err = encoder_alloc(&encoder); + error_t *err = encoder_alloc(&encoder, result.node); if (err) goto cleanup_ast; - err = encoder_encode(encoder, result.node); + err = encoder_encode(encoder); if (err) goto cleanup_ast; @@ -88,7 +88,8 @@ error_t *print_encoding(tokenlist_t *list) { if (node->id != NODE_INSTRUCTION) continue; - print_hex(node->value.encoding.len, node->value.encoding.encoding); + print_hex(node->value.instruction.encoding.len, + node->value.instruction.encoding.buffer); } encoder_free(encoder); diff --git a/src/parser/parser.c 
b/src/parser/parser.c index 05dcbf0..827e518 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -89,7 +89,8 @@ parse_result_t parse_immediate(tokenlist_entry_t *current) { } parse_result_t parse_memory_expression(tokenlist_entry_t *current) { - parser_t parsers[] = {parse_register_expression, parse_identifier, nullptr}; + parser_t parsers[] = {parse_register_expression, parse_label_reference, + nullptr}; return parse_any(current, parsers); } diff --git a/tests/symbols.c b/tests/symbols.c index f844e34..3808f03 100644 --- a/tests/symbols.c +++ b/tests/symbols.c @@ -58,17 +58,19 @@ MunitResult test_symbol_add_reference(const MunitParameter params[], void *data) symbol_table_alloc(&table); ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; + ast_node_t *statement = root->children[3]; // The containing statement munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE); munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, reference); + error_t *err = symbol_table_update(table, reference, statement); munit_assert_null(err); munit_assert_size(table->len, ==, 1); symbol_t *symbol = symbol_table_lookup(table, "test"); munit_assert_not_null(symbol); munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind); - munit_assert_ptr_equal(reference, symbol->node); + // For references, the statement should be nullptr + munit_assert_ptr_null(symbol->statement); munit_assert_string_equal(symbol->name, "test"); symbol_table_free(table); @@ -90,14 +92,14 @@ MunitResult test_symbol_add_label(const MunitParameter params[], void *data) { munit_assert_int(label->id, ==, NODE_LABEL); munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, label); + error_t *err = symbol_table_update(table, label, label); munit_assert_null(err); munit_assert_size(table->len, ==, 1); symbol_t *symbol = symbol_table_lookup(table, "test"); munit_assert_not_null(symbol); munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind); - 
munit_assert_ptr_equal(label, symbol->node); + munit_assert_ptr_equal(label, symbol->statement); munit_assert_string_equal(symbol->name, "test"); symbol_table_free(table); @@ -116,17 +118,19 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) { symbol_table_alloc(&table); ast_node_t *import_directive = root->children[0]->children[1]; + ast_node_t *statement = root->children[0]; // The containing statement munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE); munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, import_directive); + error_t *err = symbol_table_update(table, import_directive, statement); munit_assert_null(err); munit_assert_size(table->len, ==, 1); symbol_t *symbol = symbol_table_lookup(table, "test"); munit_assert_not_null(symbol); munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind); - munit_assert_ptr_equal(import_directive, symbol->node); + // For import directives, the statement should be nullptr + munit_assert_ptr_null(symbol->statement); munit_assert_string_equal(symbol->name, "test"); symbol_table_free(table); @@ -135,42 +139,56 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) { return MUNIT_OK; } -void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *second, - symbol_kind_t second_kind, bool should_succeed, bool should_update) { +void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *first_statement, + ast_node_t *second, symbol_kind_t second_kind, ast_node_t *second_statement, + bool should_succeed, bool should_update, ast_node_t *expected_statement) { symbol_table_t *table = nullptr; symbol_table_alloc(&table); - munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, first); + // Add the first symbol + error_t *err = symbol_table_update(table, first, first_statement); munit_assert_null(err); munit_assert_size(table->len, ==, 1); 
+ // Verify first symbol state symbol_t *symbol = symbol_table_lookup(table, name); munit_assert_not_null(symbol); munit_assert_int(first_kind, ==, symbol->kind); - munit_assert_ptr_equal(first, symbol->node); munit_assert_string_equal(symbol->name, name); - err = symbol_table_update(table, second); - if (should_succeed) - munit_assert_null(err); - else - munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols); - munit_assert_size(table->len, ==, 1); - - symbol = symbol_table_lookup(table, name); - if (should_update) { - munit_assert_not_null(symbol); - munit_assert_int(second_kind, ==, symbol->kind); - munit_assert_ptr_equal(second, symbol->node); - munit_assert_string_equal(symbol->name, name); + // Check statement based on symbol kind + if (first_kind == SYMBOL_LOCAL) { + munit_assert_ptr_equal(first_statement, symbol->statement); } else { - munit_assert_not_null(symbol); - munit_assert_int(first_kind, ==, symbol->kind); - munit_assert_ptr_equal(first, symbol->node); - munit_assert_string_equal(symbol->name, name); + munit_assert_ptr_null(symbol->statement); } + // Attempt the second update + err = symbol_table_update(table, second, second_statement); + + // Check if update succeeded as expected + if (should_succeed) { + munit_assert_null(err); + } else { + munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols); + symbol_table_free(table); + return; + } + + // Verify symbol after second update + symbol = symbol_table_lookup(table, name); + munit_assert_not_null(symbol); + + // Check if kind updated as expected + if (should_update) { + munit_assert_int(second_kind, ==, symbol->kind); + } else { + munit_assert_int(first_kind, ==, symbol->kind); + } + + // Simply check against the expected statement value + munit_assert_ptr_equal(expected_statement, symbol->statement); + symbol_table_free(table); } @@ -181,28 +199,43 @@ MunitResult test_symbol_upgrade_valid(const MunitParameter params[], void *data) symbols_setup_test(&root, &list, 
"tests/input/symbols.asm"); ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; + ast_node_t *reference_statement = root->children[3]; ast_node_t *label = root->children[2]; ast_node_t *import_directive = root->children[0]->children[1]; + ast_node_t *import_statement = root->children[0]; ast_node_t *export_directive = root->children[1]->children[1]; + ast_node_t *export_statement = root->children[1]; // real upgrades - test_symbol_update("test", reference, SYMBOL_REFERENCE, label, SYMBOL_LOCAL, true, true); - test_symbol_update("test", reference, SYMBOL_REFERENCE, import_directive, SYMBOL_IMPORT, true, true); - test_symbol_update("test", reference, SYMBOL_REFERENCE, export_directive, SYMBOL_EXPORT, true, true); - test_symbol_update("test", label, SYMBOL_LOCAL, export_directive, SYMBOL_EXPORT, true, true); + test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, label, SYMBOL_LOCAL, label, true, true, + label); + test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, import_directive, SYMBOL_IMPORT, + import_statement, true, true, nullptr); + test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, export_directive, SYMBOL_EXPORT, + export_statement, true, true, nullptr); + test_symbol_update("test", label, SYMBOL_LOCAL, label, export_directive, SYMBOL_EXPORT, export_statement, true, + true, label); // identity upgrades - test_symbol_update("test", reference, SYMBOL_REFERENCE, reference, SYMBOL_REFERENCE, true, false); - test_symbol_update("test", label, SYMBOL_LOCAL, label, SYMBOL_LOCAL, true, false); - test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_directive, SYMBOL_IMPORT, true, false); - test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_directive, SYMBOL_EXPORT, true, false); + test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, reference, SYMBOL_REFERENCE, + reference_statement, true, false, nullptr); + 
test_symbol_update("test", label, SYMBOL_LOCAL, label, label, SYMBOL_LOCAL, label, true, false, label); + test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, import_directive, SYMBOL_IMPORT, + import_statement, true, false, nullptr); + test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, export_directive, SYMBOL_EXPORT, + export_statement, true, false, nullptr); // downgrades that are allowed and ignored - test_symbol_update("test", label, SYMBOL_LOCAL, reference, SYMBOL_REFERENCE, true, false); - test_symbol_update("test", import_directive, SYMBOL_IMPORT, reference, SYMBOL_REFERENCE, true, false); - test_symbol_update("test", export_directive, SYMBOL_EXPORT, reference, SYMBOL_REFERENCE, true, false); - test_symbol_update("test", export_directive, SYMBOL_EXPORT, label, SYMBOL_LOCAL, true, false); - test_symbol_update("test", import_directive, SYMBOL_IMPORT, label, SYMBOL_LOCAL, true, false); + test_symbol_update("test", label, SYMBOL_LOCAL, label, reference, SYMBOL_REFERENCE, reference_statement, true, + false, label); + test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, reference, SYMBOL_REFERENCE, + reference_statement, true, false, nullptr); + test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, reference, SYMBOL_REFERENCE, + reference_statement, true, false, nullptr); + test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, label, SYMBOL_LOCAL, label, true, + false, label); + test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, label, SYMBOL_LOCAL, label, true, + false, label); ast_node_free(root); tokenlist_free(list); @@ -216,14 +249,20 @@ MunitResult test_symbol_upgrade_invalid(const MunitParameter params[], void *dat symbols_setup_test(&root, &list, "tests/input/symbols.asm"); ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; + ast_node_t *reference_statement = 
root->children[3]; ast_node_t *label = root->children[2]; ast_node_t *import_directive = root->children[0]->children[1]; + ast_node_t *import_statement = root->children[0]; ast_node_t *export_directive = root->children[1]->children[1]; + ast_node_t *export_statement = root->children[1]; // invalid upgrades - test_symbol_update("test", label, SYMBOL_LOCAL, import_directive, SYMBOL_IMPORT, false, false); - test_symbol_update("test", export_directive, SYMBOL_EXPORT, import_directive, SYMBOL_IMPORT, false, false); - test_symbol_update("test", import_directive, SYMBOL_IMPORT, export_directive, SYMBOL_EXPORT, false, false); + test_symbol_update("test", label, SYMBOL_LOCAL, label, import_directive, SYMBOL_IMPORT, import_statement, false, + false, nullptr); + test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, import_directive, SYMBOL_IMPORT, + import_statement, false, false, nullptr); + test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, export_directive, SYMBOL_EXPORT, + export_statement, false, false, nullptr); ast_node_free(root); tokenlist_free(list); @@ -240,17 +279,19 @@ MunitResult test_symbol_add_export(const MunitParameter params[], void *data) { symbol_table_alloc(&table); ast_node_t *export_directive = root->children[1]->children[1]; + ast_node_t *statement = root->children[1]; // The containing statement munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE); munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, export_directive); + error_t *err = symbol_table_update(table, export_directive, statement); munit_assert_null(err); munit_assert_size(table->len, ==, 1); symbol_t *symbol = symbol_table_lookup(table, "test"); munit_assert_not_null(symbol); munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind); - munit_assert_ptr_equal(export_directive, symbol->node); + // For export directives, the statement should be nullptr + munit_assert_ptr_null(symbol->statement); 
munit_assert_string_equal(symbol->name, "test"); symbol_table_free(table); @@ -280,7 +321,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data) ast_node_t *label = root->children[i]; munit_assert_int(label->id, ==, NODE_LABEL); - error_t *err = symbol_table_update(table, label); + error_t *err = symbol_table_update(table, label, label); munit_assert_null(err); munit_assert_size(table->len, ==, i + 1); @@ -292,7 +333,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data) ast_node_t *final_label = root->children[64]; munit_assert_int(final_label->id, ==, NODE_LABEL); - error_t *err = symbol_table_update(table, final_label); + error_t *err = symbol_table_update(table, final_label, final_label); munit_assert_null(err); munit_assert_size(table->len, ==, 65); @@ -308,6 +349,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data) munit_assert_not_null(symbol); munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind); munit_assert_string_equal(symbol->name, name); + munit_assert_ptr_equal(symbol->statement, root->children[i]); } symbol_table_free(table); @@ -326,7 +368,7 @@ MunitResult test_symbol_invalid_node(const MunitParameter params[], void *data) symbol_table_alloc(&table); munit_assert_size(table->len, ==, 0); - error_t *err = symbol_table_update(table, root); + error_t *err = symbol_table_update(table, root, root); munit_assert_ptr_equal(err, err_symbol_table_invalid_node); munit_assert_size(table->len, ==, 0);