Compare commits

..

2 Commits

Author SHA1 Message Date
fab5bedf3d Implement two pass encoding
Some checks failed
Validate the build / validate-build (push) Failing after 50s
First pass:
 - collect information for numbers, registers and which instructions
   contain label references
 - encode all instructions that don't contain label references
 - Set (temporary) addresses for each instruction

Second pass:
 - Collect information about label references (address, offset, size)
 - encode all instructions that contain label references
 - Update (if necessary) addresses for each instruction

 The second pass is repeated at most 10 times, stopping early as soon as
 no instruction changes size.
2025-04-22 02:08:38 +02:00
9a1570e3e5 Add more values to the ast to facilitate encoding
- Add an instruction value that contains the encoding, the address and a
  flag that indicates whether the instruction contains label references
- Add a label value that contains an address
- Add a reference value that contains an offset, an absolute address and an
  operand size
2025-04-22 00:54:12 +02:00
4 changed files with 103 additions and 336 deletions

View File

@ -5,7 +5,6 @@
#include "error.h" #include "error.h"
#include "lexer.h" #include "lexer.h"
#include "tokenlist.h" #include "tokenlist.h"
#include <assert.h>
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@ -80,22 +79,6 @@ typedef struct opcode_encoding {
size_t len; size_t len;
} opcode_encoding_t; } opcode_encoding_t;
typedef struct instruction {
bool has_reference;
opcode_encoding_t encoding;
int64_t address;
} instruction_t;
typedef struct reference {
int64_t offset;
int64_t address;
operand_size_t size;
} reference_t;
typedef struct {
int64_t address;
} label_t;
struct ast_node { struct ast_node {
node_id_t id; node_id_t id;
tokenlist_entry_t *token_entry; tokenlist_entry_t *token_entry;
@ -106,37 +89,22 @@ struct ast_node {
union { union {
register_t reg; register_t reg;
number_t number; number_t number;
instruction_t instruction; struct {
reference_t reference; bool has_reference;
label_t label; opcode_encoding_t encoding;
int64_t address;
} instruction;
struct {
int64_t offset;
int64_t address;
operand_size_t size;
} reference;
struct {
int64_t address;
} label;
} value; } value;
}; };
static inline register_t *ast_node_register_value(ast_node_t *node) {
assert(node->id == NODE_REGISTER);
return &node->value.reg;
}
static inline number_t *ast_node_number_value(ast_node_t *node) {
assert(node->id == NODE_NUMBER);
return &node->value.number;
}
static inline instruction_t *ast_node_instruction_value(ast_node_t *node) {
assert(node->id == NODE_INSTRUCTION);
return &node->value.instruction;
}
static inline reference_t *ast_node_reference_value(ast_node_t *node) {
assert(node->id == NODE_LABEL_REFERENCE);
return &node->value.reference;
}
static inline label_t *ast_node_label_value(ast_node_t *node) {
assert(node->id == NODE_LABEL);
return &node->value.label;
}
/** /**
* @brief Allocates a new AST node * @brief Allocates a new AST node
* *

View File

@ -138,128 +138,8 @@ opcode_data_t *const opcodes[] = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 }, { .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
}, },
}, },
// CALL rel32
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xE8,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// CALL reg64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// CALL mem64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
// JMP rel8 (short jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xEB,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8 },
},
},
// JMP rel16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 },
},
},
// JMP reg16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// JMP rel32 (near jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// JMP reg32
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
},
},
// JMP reg64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// JMP mem64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
nullptr, nullptr,
}; };

View File

@ -269,7 +269,7 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
switch (info->kind) { switch (info->kind) {
case OPERAND_REGISTER: case OPERAND_REGISTER:
return operand->id == NODE_REGISTER && return operand->id == NODE_REGISTER &&
ast_node_register_value(operand)->size == info->size; operand->value.reg.size == info->size;
case OPERAND_MEMORY: case OPERAND_MEMORY:
return operand->id == NODE_MEMORY; return operand->id == NODE_MEMORY;
case OPERAND_IMMEDIATE: { case OPERAND_IMMEDIATE: {
@ -278,10 +278,13 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
ast_node_t *child = operand->children[0]; ast_node_t *child = operand->children[0];
if (child->id == NODE_NUMBER) if (child->id == NODE_NUMBER)
return (ast_node_number_value(child)->size & info->size) > 0; return (child->value.number.size & info->size) > 0;
else if (child->id == NODE_LABEL_REFERENCE) { else if (child->id == NODE_LABEL_REFERENCE)
return info->size &= ast_node_reference_value(child)->size; return info->size == OPERAND_SIZE_32;
} // FIXME: first pass should give us information about the distance of
// the label reference so we can pick a size more appropriately instead
// of just defaulting to 32 bits
break;
} // end OPERAND_IMMEDIATE case } // end OPERAND_IMMEDIATE case
} }
assert(false && "unreachable"); assert(false && "unreachable");
@ -337,7 +340,7 @@ error_t *encode_one_register_in_opcode(encoder_t *encoder,
(void)encoder; (void)encoder;
(void)opcode; (void)opcode;
register_id_t id = ast_node_register_value(operands->children[0])->id; register_id_t id = operands->children[0]->value.reg.id;
encoding->buffer[encoding->len - 1] |= id & 0b111; encoding->buffer[encoding->len - 1] |= id & 0b111;
if ((id & 0b1000) > 0) { if ((id & 0b1000) > 0) {
*rex |= rex_prefix_r; *rex |= rex_prefix_r;
@ -352,7 +355,7 @@ error_t *encode_one_register(encoder_t *encoder, opcode_data_t *opcode,
assert(operands->len == 1); assert(operands->len == 1);
assert(operands->children[0]->id == NODE_REGISTER); assert(operands->children[0]->id == NODE_REGISTER);
register_id_t id = ast_node_register_value(operands->children[0])->id; register_id_t id = operands->children[0]->value.reg.id;
uint8_t modrm = modrm_mod_register; uint8_t modrm = modrm_mod_register;
@ -386,9 +389,9 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
assert(immediate->id == NODE_NUMBER || assert(immediate->id == NODE_NUMBER ||
immediate->id == NODE_LABEL_REFERENCE); immediate->id == NODE_LABEL_REFERENCE);
operand_size_t size = opcode->operands[0].size;
if (immediate->id == NODE_NUMBER) { if (immediate->id == NODE_NUMBER) {
uint64_t value = ast_node_number_value(immediate)->value; uint64_t value = immediate->value.number.value;
operand_size_t size = opcode->operands[0].size;
error_t *err = nullptr; error_t *err = nullptr;
switch (size) { switch (size) {
case OPERAND_SIZE_8: case OPERAND_SIZE_8:
@ -408,21 +411,10 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
} }
return err; return err;
} else { } else {
reference_t *reference = ast_node_reference_value(immediate); // FIXME: this still assumes references are always 32 bit
switch (size) { uint32_t value = 0xDEADBEEF;
case OPERAND_SIZE_64: return bytes_append_uint32(encoding, value);
return bytes_append_uint64(encoding, reference->address);
case OPERAND_SIZE_32:
return bytes_append_uint32(encoding, reference->offset);
case OPERAND_SIZE_16:
return bytes_append_uint16(encoding, reference->offset);
case OPERAND_SIZE_8:
return bytes_append_uint8(encoding, reference->offset);
default:
assert(false && "intentionally unhandled");
} }
}
__builtin_unreachable();
} }
error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode, error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode,
@ -491,8 +483,7 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
return err; return err;
// produce the actual encoding output in the NODE_INSTRUCTION value // produce the actual encoding output in the NODE_INSTRUCTION value
instruction_t *instruction_value = ast_node_instruction_value(instruction); uint8_t *output = instruction->value.instruction.encoding.buffer;
uint8_t *output = instruction_value->encoding.buffer;
size_t output_len = 0; size_t output_len = 0;
// Handle prefixes // Handle prefixes
@ -511,17 +502,11 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
memcpy(output + output_len, encoding->buffer, encoding->len); memcpy(output + output_len, encoding->buffer, encoding->len);
output_len += encoding->len; output_len += encoding->len;
instruction_value->encoding.len = output_len; instruction->value.instruction.encoding.len = output_len;
return nullptr; return nullptr;
} }
/**
* Initial guess for instruction size of instructions that contain a label
* reference
*/
constexpr size_t instruction_size_estimate = 10;
/** /**
* Perform the initial pass over the AST. * Perform the initial pass over the AST.
* *
@ -533,6 +518,7 @@ constexpr size_t instruction_size_estimate = 10;
* - determine estimated addresses of each statement * - determine estimated addresses of each statement
* *
*/ */
constexpr size_t instruction_size_estimate = 10;
error_t *encoder_first_pass(encoder_t *encoder) { error_t *encoder_first_pass(encoder_t *encoder) {
ast_node_t *root = encoder->ast; ast_node_t *root = encoder->ast;
assert(root->id == NODE_PROGRAM); assert(root->id == NODE_PROGRAM);
@ -546,21 +532,19 @@ error_t *encoder_first_pass(encoder_t *encoder) {
return err; return err;
if (statement->id == NODE_INSTRUCTION && if (statement->id == NODE_INSTRUCTION &&
ast_node_instruction_value(statement)->has_reference == false) { statement->value.instruction.has_reference == false) {
err = encoder_encode_instruction(encoder, statement); err = encoder_encode_instruction(encoder, statement);
if (err) if (err)
return err; return err;
instruction_t *instruction = ast_node_instruction_value(statement); statement->value.instruction.address = address;
instruction->address = address; address += statement->value.instruction.encoding.len;
address += instruction->encoding.len;
} else if (statement->id == NODE_INSTRUCTION) { } else if (statement->id == NODE_INSTRUCTION) {
instruction_t *instruction = ast_node_instruction_value(statement); statement->value.instruction.encoding.len =
instruction->encoding.len = instruction_size_estimate; instruction_size_estimate;
instruction->address = address; statement->value.instruction.address = address;
address += instruction_size_estimate; address += instruction_size_estimate;
} else if (statement->id == NODE_LABEL) { } else if (statement->id == NODE_LABEL) {
label_t *label = ast_node_label_value(statement); statement->value.instruction.address = address;
label->address = address;
} }
} }
@ -583,17 +567,17 @@ operand_size_t signed_to_size_mask(int64_t value) {
} }
int64_t statement_offset(ast_node_t *from, ast_node_t *to) { int64_t statement_offset(ast_node_t *from, ast_node_t *to) {
assert(from->id == NODE_INSTRUCTION); assert(from->id == NODE_LABEL || from->id == NODE_INSTRUCTION);
assert(to->id == NODE_LABEL); assert(to->id == NODE_LABEL || to->id == NODE_INSTRUCTION);
instruction_t *instruction = ast_node_instruction_value(from); int64_t from_addr =
int64_t from_addr = instruction->address + instruction->encoding.len; from->value.instruction.address + from->value.instruction.encoding.len;
int64_t to_addr = ast_node_label_value(to)->address; int64_t to_addr = to->value.instruction.address;
return to_addr - from_addr; return to_addr - from_addr;
} }
error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node, error_t *encoder_collect_label_info(encoder_t *encoder, ast_node_t *node,
ast_node_t *statement) { ast_node_t *statement) {
assert(statement->id == NODE_INSTRUCTION); assert(statement->id == NODE_INSTRUCTION);
if (node->id == NODE_LABEL_REFERENCE) { if (node->id == NODE_LABEL_REFERENCE) {
@ -603,7 +587,7 @@ error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node,
symbol->statement->id == NODE_LABEL); symbol->statement->id == NODE_LABEL);
int64_t offset = statement_offset(statement, symbol->statement); int64_t offset = statement_offset(statement, symbol->statement);
int64_t absolute = ast_node_label_value(symbol->statement)->address; int64_t absolute = symbol->statement->value.instruction.address;
operand_size_t size = signed_to_size_mask(offset); operand_size_t size = signed_to_size_mask(offset);
node->value.reference.address = absolute; node->value.reference.address = absolute;
@ -611,38 +595,9 @@ error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node,
node->value.reference.size = size; node->value.reference.size = size;
} }
for (size_t i = 0; i < node->len; ++i) {
error_t *err = encoder_collect_reference_info(
encoder, node->children[i], statement);
if (err)
return err;
}
return nullptr; return nullptr;
} }
bool encoder_should_reencode(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return false;
instruction_t *instruction = ast_node_instruction_value(statement);
return instruction->has_reference;
}
void set_statement_address(ast_node_t *statement, int64_t address) {
if (statement->id == NODE_INSTRUCTION) {
ast_node_instruction_value(statement)->address = address;
} else if (statement->id == NODE_LABEL) {
ast_node_label_value(statement)->address = address;
}
}
size_t get_statement_length(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return 0;
return ast_node_instruction_value(statement)->encoding.len;
}
/** /**
* Perform the second pass. Updates the label info and encodes all instructions * Perform the second pass. Updates the label info and encodes all instructions
* that have a label reference.that performs actual encoding. * that have a label reference.that performs actual encoding.
@ -655,22 +610,28 @@ error_t *encoder_second_pass(encoder_t *encoder, bool *did_update) {
for (size_t i = 0; i < root->len; ++i) { for (size_t i = 0; i < root->len; ++i) {
ast_node_t *statement = root->children[i]; ast_node_t *statement = root->children[i];
set_statement_address(statement, address); if (statement->id == NODE_INSTRUCTION &&
size_t before = get_statement_length(statement); statement->value.instruction.has_reference) {
statement->value.instruction.address = address;
if (encoder_should_reencode(statement)) { size_t before = statement->value.instruction.encoding.len;
error_t *err = error_t *err =
encoder_collect_reference_info(encoder, statement, statement); encoder_collect_label_info(encoder, statement, statement);
if (err) if (err)
return err; return err;
err = encoder_encode_instruction(encoder, statement); err = encoder_encode_instruction(encoder, statement);
if (err) if (err)
return err; return err;
} size_t after = statement->value.instruction.encoding.len;
size_t after = get_statement_length(statement);
*did_update = *did_update || (before != after);
address += after; address += after;
*did_update = *did_update || (before != after);
} else if (statement->id == NODE_INSTRUCTION &&
statement->value.instruction.has_reference) {
statement->value.instruction.address = address;
address += statement->value.instruction.encoding.len;
} else if (statement->id == NODE_LABEL) {
statement->value.label.address = address;
}
address += statement->value.instruction.encoding.len;
} }
return nullptr; return nullptr;
} }

View File

@ -58,19 +58,17 @@ MunitResult test_symbol_add_reference(const MunitParameter params[], void *data)
symbol_table_alloc(&table); symbol_table_alloc(&table);
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *statement = root->children[3]; // The containing statement
munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE); munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, reference, statement); error_t *err = symbol_table_update(table, reference);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 1); munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test"); symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind); munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind);
// For references, the statement should be nullptr munit_assert_ptr_equal(reference, symbol->node);
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test"); munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table); symbol_table_free(table);
@ -92,14 +90,14 @@ MunitResult test_symbol_add_label(const MunitParameter params[], void *data) {
munit_assert_int(label->id, ==, NODE_LABEL); munit_assert_int(label->id, ==, NODE_LABEL);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, label, label); error_t *err = symbol_table_update(table, label);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 1); munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test"); symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind); munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_ptr_equal(label, symbol->statement); munit_assert_ptr_equal(label, symbol->node);
munit_assert_string_equal(symbol->name, "test"); munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table); symbol_table_free(table);
@ -118,19 +116,17 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
symbol_table_alloc(&table); symbol_table_alloc(&table);
ast_node_t *import_directive = root->children[0]->children[1]; ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *statement = root->children[0]; // The containing statement
munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE); munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, import_directive, statement); error_t *err = symbol_table_update(table, import_directive);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 1); munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test"); symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind); munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind);
// For import directives, the statement should be nullptr munit_assert_ptr_equal(import_directive, symbol->node);
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test"); munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table); symbol_table_free(table);
@ -139,56 +135,42 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
return MUNIT_OK; return MUNIT_OK;
} }
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *first_statement, void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *second,
ast_node_t *second, symbol_kind_t second_kind, ast_node_t *second_statement, symbol_kind_t second_kind, bool should_succeed, bool should_update) {
bool should_succeed, bool should_update, ast_node_t *expected_statement) {
symbol_table_t *table = nullptr; symbol_table_t *table = nullptr;
symbol_table_alloc(&table); symbol_table_alloc(&table);
// Add the first symbol munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, first, first_statement); error_t *err = symbol_table_update(table, first);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 1); munit_assert_size(table->len, ==, 1);
// Verify first symbol state
symbol_t *symbol = symbol_table_lookup(table, name); symbol_t *symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind); munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name); munit_assert_string_equal(symbol->name, name);
// Check statement based on symbol kind err = symbol_table_update(table, second);
if (first_kind == SYMBOL_LOCAL) { if (should_succeed)
munit_assert_ptr_equal(first_statement, symbol->statement);
} else {
munit_assert_ptr_null(symbol->statement);
}
// Attempt the second update
err = symbol_table_update(table, second, second_statement);
// Check if update succeeded as expected
if (should_succeed) {
munit_assert_null(err); munit_assert_null(err);
} else { else
munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols); munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
symbol_table_free(table); munit_assert_size(table->len, ==, 1);
return;
}
// Verify symbol after second update
symbol = symbol_table_lookup(table, name); symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol);
// Check if kind updated as expected
if (should_update) { if (should_update) {
munit_assert_not_null(symbol);
munit_assert_int(second_kind, ==, symbol->kind); munit_assert_int(second_kind, ==, symbol->kind);
munit_assert_ptr_equal(second, symbol->node);
munit_assert_string_equal(symbol->name, name);
} else { } else {
munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind); munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name);
} }
// Simply check against the expected statement value
munit_assert_ptr_equal(expected_statement, symbol->statement);
symbol_table_free(table); symbol_table_free(table);
} }
@ -199,43 +181,28 @@ MunitResult test_symbol_upgrade_valid(const MunitParameter params[], void *data)
symbols_setup_test(&root, &list, "tests/input/symbols.asm"); symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2]; ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1]; ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1]; ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// real upgrades // real upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, label, SYMBOL_LOCAL, label, true, true, test_symbol_update("test", reference, SYMBOL_REFERENCE, label, SYMBOL_LOCAL, true, true);
label); test_symbol_update("test", reference, SYMBOL_REFERENCE, import_directive, SYMBOL_IMPORT, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, import_directive, SYMBOL_IMPORT, test_symbol_update("test", reference, SYMBOL_REFERENCE, export_directive, SYMBOL_EXPORT, true, true);
import_statement, true, true, nullptr); test_symbol_update("test", label, SYMBOL_LOCAL, export_directive, SYMBOL_EXPORT, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, true, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, label, export_directive, SYMBOL_EXPORT, export_statement, true,
true, label);
// identity upgrades // identity upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, reference, SYMBOL_REFERENCE, test_symbol_update("test", reference, SYMBOL_REFERENCE, reference, SYMBOL_REFERENCE, true, false);
reference_statement, true, false, nullptr); test_symbol_update("test", label, SYMBOL_LOCAL, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", label, SYMBOL_LOCAL, label, label, SYMBOL_LOCAL, label, true, false, label); test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_directive, SYMBOL_IMPORT, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, import_directive, SYMBOL_IMPORT, test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_directive, SYMBOL_EXPORT, true, false);
import_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, false, nullptr);
// downgrades that are allowed and ignored // downgrades that are allowed and ignored
test_symbol_update("test", label, SYMBOL_LOCAL, label, reference, SYMBOL_REFERENCE, reference_statement, true, test_symbol_update("test", label, SYMBOL_LOCAL, reference, SYMBOL_REFERENCE, true, false);
false, label); test_symbol_update("test", import_directive, SYMBOL_IMPORT, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, reference, SYMBOL_REFERENCE, test_symbol_update("test", export_directive, SYMBOL_EXPORT, reference, SYMBOL_REFERENCE, true, false);
reference_statement, true, false, nullptr); test_symbol_update("test", export_directive, SYMBOL_EXPORT, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, reference, SYMBOL_REFERENCE, test_symbol_update("test", import_directive, SYMBOL_IMPORT, label, SYMBOL_LOCAL, true, false);
reference_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, label, SYMBOL_LOCAL, label, true,
false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, label, SYMBOL_LOCAL, label, true,
false, label);
ast_node_free(root); ast_node_free(root);
tokenlist_free(list); tokenlist_free(list);
@ -249,20 +216,14 @@ MunitResult test_symbol_upgrade_invalid(const MunitParameter params[], void *dat
symbols_setup_test(&root, &list, "tests/input/symbols.asm"); symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0]; ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2]; ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1]; ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1]; ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// invalid upgrades // invalid upgrades
test_symbol_update("test", label, SYMBOL_LOCAL, label, import_directive, SYMBOL_IMPORT, import_statement, false, test_symbol_update("test", label, SYMBOL_LOCAL, import_directive, SYMBOL_IMPORT, false, false);
false, nullptr); test_symbol_update("test", export_directive, SYMBOL_EXPORT, import_directive, SYMBOL_IMPORT, false, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, import_directive, SYMBOL_IMPORT, test_symbol_update("test", import_directive, SYMBOL_IMPORT, export_directive, SYMBOL_EXPORT, false, false);
import_statement, false, false, nullptr);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, export_directive, SYMBOL_EXPORT,
export_statement, false, false, nullptr);
ast_node_free(root); ast_node_free(root);
tokenlist_free(list); tokenlist_free(list);
@ -279,19 +240,17 @@ MunitResult test_symbol_add_export(const MunitParameter params[], void *data) {
symbol_table_alloc(&table); symbol_table_alloc(&table);
ast_node_t *export_directive = root->children[1]->children[1]; ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *statement = root->children[1]; // The containing statement
munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE); munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, export_directive, statement); error_t *err = symbol_table_update(table, export_directive);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 1); munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test"); symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind); munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind);
// For export directives, the statement should be nullptr munit_assert_ptr_equal(export_directive, symbol->node);
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test"); munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table); symbol_table_free(table);
@ -321,7 +280,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *label = root->children[i]; ast_node_t *label = root->children[i];
munit_assert_int(label->id, ==, NODE_LABEL); munit_assert_int(label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, label, label); error_t *err = symbol_table_update(table, label);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, i + 1); munit_assert_size(table->len, ==, i + 1);
@ -333,7 +292,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *final_label = root->children[64]; ast_node_t *final_label = root->children[64];
munit_assert_int(final_label->id, ==, NODE_LABEL); munit_assert_int(final_label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, final_label, final_label); error_t *err = symbol_table_update(table, final_label);
munit_assert_null(err); munit_assert_null(err);
munit_assert_size(table->len, ==, 65); munit_assert_size(table->len, ==, 65);
@ -349,7 +308,6 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
munit_assert_not_null(symbol); munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind); munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_string_equal(symbol->name, name); munit_assert_string_equal(symbol->name, name);
munit_assert_ptr_equal(symbol->statement, root->children[i]);
} }
symbol_table_free(table); symbol_table_free(table);
@ -368,7 +326,7 @@ MunitResult test_symbol_invalid_node(const MunitParameter params[], void *data)
symbol_table_alloc(&table); symbol_table_alloc(&table);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, root, root); error_t *err = symbol_table_update(table, root);
munit_assert_ptr_equal(err, err_symbol_table_invalid_node); munit_assert_ptr_equal(err, err_symbol_table_invalid_node);
munit_assert_size(table->len, ==, 0); munit_assert_size(table->len, ==, 0);