Compare commits

...

6 Commits

Author SHA1 Message Date
7cefc3564d Implement one immediate label reference operand
All checks were successful
Validate the build / validate-build (push) Successful in 43s
Also adds opcode data for jmp and call
2025-04-24 14:45:57 +02:00
c848995ad6 Implement two pass encoding
First pass:
 - collect information for numbers, registers and which instructions
   contain label references
 - encode all instructions that don't contain label references
 - Set (temporary) addresses for each instruction

Second pass:
 - Collect information about label references (address, offset, size)
 - encode all instructions that contain label references
 - Update (if necessary) addresses for each instruction

 The second pass is iterated 10 times or until no instructions change
 size, whichever comes first.
2025-04-24 14:45:46 +02:00
5272fdb227 Add more values to the ast to facilitate encoding
- Add a instruction value that contains the encoding, the address and a
  flag to indicate if this instruction contains label references
- Add label value that contains an address
- Add reference value that contains offset, an absolute address and an
  operand size
- define types for all value options in the union
- define accessor functions for all the values in the union
2025-04-23 15:57:04 +02:00
0acc3f27f3 Update symbols tests for new API 2025-04-23 15:56:46 +02:00
9c6b69e187 Symbol table now keeps track of label statements
Before it kept track of a more specific node that referenced the symbol
in some way. Now it will only keep track of the actual label defining
statements. This is done to facilitate encoding. The encoder can now go
from a symbol name to the statement that defines the symbol.

Restructure the encoder to deal with this and pass the correct statement
to the symbol update function.
2025-04-18 14:00:08 +02:00
530e3fb423 Fix parse_memory_expression to use parse_label_reference
All checks were successful
Validate the build / validate-build (push) Successful in 37s
2025-04-17 23:28:44 +02:00
9 changed files with 506 additions and 105 deletions

View File

@ -5,6 +5,7 @@
#include "error.h"
#include "lexer.h"
#include "tokenlist.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
@ -75,10 +76,26 @@ typedef struct register_ {
} register_t;
typedef struct opcode_encoding {
uint8_t encoding[32];
uint8_t buffer[32];
size_t len;
} opcode_encoding_t;
typedef struct instruction {
bool has_reference;
opcode_encoding_t encoding;
int64_t address;
} instruction_t;
typedef struct reference {
int64_t offset;
int64_t address;
operand_size_t size;
} reference_t;
typedef struct {
int64_t address;
} label_t;
struct ast_node {
node_id_t id;
tokenlist_entry_t *token_entry;
@ -89,10 +106,37 @@ struct ast_node {
union {
register_t reg;
number_t number;
opcode_encoding_t encoding;
instruction_t instruction;
reference_t reference;
label_t label;
} value;
};
static inline register_t *ast_node_register_value(ast_node_t *node) {
assert(node->id == NODE_REGISTER);
return &node->value.reg;
}
static inline number_t *ast_node_number_value(ast_node_t *node) {
assert(node->id == NODE_NUMBER);
return &node->value.number;
}
static inline instruction_t *ast_node_instruction_value(ast_node_t *node) {
assert(node->id == NODE_INSTRUCTION);
return &node->value.instruction;
}
static inline reference_t *ast_node_reference_value(ast_node_t *node) {
assert(node->id == NODE_LABEL_REFERENCE);
return &node->value.reference;
}
static inline label_t *ast_node_label_value(ast_node_t *node) {
assert(node->id == NODE_LABEL);
return &node->value.label;
}
/**
* @brief Allocates a new AST node
*

View File

@ -138,8 +138,128 @@ opcode_data_t *const opcodes[] = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// CALL rel32
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xE8,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// CALL reg64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// CALL mem64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
// JMP rel8 (short jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xEB,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8 },
},
},
// JMP rel16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 },
},
},
// JMP reg16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// JMP rel32 (near jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// JMP reg32
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
},
},
// JMP reg64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// JMP mem64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
nullptr,
};

View File

@ -6,6 +6,31 @@
#include <errno.h>
#include <string.h>
/**
* General encoder flow:
*
* There are 2 major passes the encoder does:
*
* First pass:
* - Run through the AST and collect information:
* - Set register values
* - Parse/set number values
* - Mark all instructions that use label references
* - Encode all instructions that don't use label references
* - Update addresses of all labels and instructions. Use an estimated
* instruction size for those instructions that use label references.
*
* Second pass:
* - Run through the AST for all instructions that use label references and
* collect size information using the estimated addresses from pass 1
* - Encode label references with the estimated addresses, this fixes their
* size.
* - Update all addresses
*
* Iteration:
* - Repeat the second pass until addresses converge
*/
error_t *const err_encoder_invalid_register =
&(error_t){.message = "Invalid register"};
error_t *const err_encoder_number_overflow =
@ -23,13 +48,15 @@ error_t *const err_encoder_not_implemented =
error_t *const err_encoder_unexpected_length =
&(error_t){.message = "Unexpectedly long encoding"};
error_t *encoder_alloc(encoder_t **output) {
error_t *encoder_alloc(encoder_t **output, ast_node_t *ast) {
*output = nullptr;
encoder_t *encoder = calloc(1, sizeof(encoder_t));
if (encoder == nullptr)
return err_allocation_failed;
encoder->ast = ast;
error_t *err = symbol_table_alloc(&encoder->symbols);
if (err) {
free(encoder);
@ -213,16 +240,15 @@ static inline uint8_t modrm_rm(uint8_t modrm, register_id_t id) {
return (modrm & ~modrm_rm_mask) | (id & 0b111);
}
/**
* Perform the initial pass over the AST. Records all symbols and sets the
* values of registers and numbers.
*/
error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) {
error_t *encoder_collect_info(encoder_t *encoder, ast_node_t *node,
ast_node_t *statement) {
error_t *err = nullptr;
if (encoder_is_symbols_node(node))
err = symbol_table_update(encoder->symbols, node);
else if (node->id == NODE_NUMBER)
if (encoder_is_symbols_node(node)) {
err = symbol_table_update(encoder->symbols, node, statement);
if (statement->id == NODE_INSTRUCTION)
statement->value.instruction.has_reference = true;
} else if (node->id == NODE_NUMBER)
err = encoder_set_number_value(node);
else if (node->id == NODE_REGISTER)
err = encoder_set_register_value(node);
@ -230,7 +256,8 @@ error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) {
return err;
for (size_t i = 0; i < node->len; ++i) {
error_t *err = encoder_first_pass(encoder, node->children[i]);
error_t *err =
encoder_collect_info(encoder, node->children[i], statement);
if (err)
return err;
}
@ -242,7 +269,7 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
switch (info->kind) {
case OPERAND_REGISTER:
return operand->id == NODE_REGISTER &&
operand->value.reg.size == info->size;
ast_node_register_value(operand)->size == info->size;
case OPERAND_MEMORY:
return operand->id == NODE_MEMORY;
case OPERAND_IMMEDIATE: {
@ -251,13 +278,10 @@ bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
ast_node_t *child = operand->children[0];
if (child->id == NODE_NUMBER)
return (child->value.number.size & info->size) > 0;
else if (child->id == NODE_LABEL_REFERENCE)
return info->size == OPERAND_SIZE_32;
// FIXME: first pass should give us information about the distance of
// the label reference so we can pick a size more appropriately instead
// of just defaulting to 32 bits
break;
return (ast_node_number_value(child)->size & info->size) > 0;
else if (child->id == NODE_LABEL_REFERENCE) {
return info->size &= ast_node_reference_value(child)->size;
}
} // end OPERAND_IMMEDIATE case
}
assert(false && "unreachable");
@ -313,7 +337,7 @@ error_t *encode_one_register_in_opcode(encoder_t *encoder,
(void)encoder;
(void)opcode;
register_id_t id = operands->children[0]->value.reg.id;
register_id_t id = ast_node_register_value(operands->children[0])->id;
encoding->buffer[encoding->len - 1] |= id & 0b111;
if ((id & 0b1000) > 0) {
*rex |= rex_prefix_r;
@ -328,7 +352,7 @@ error_t *encode_one_register(encoder_t *encoder, opcode_data_t *opcode,
assert(operands->len == 1);
assert(operands->children[0]->id == NODE_REGISTER);
register_id_t id = operands->children[0]->value.reg.id;
register_id_t id = ast_node_register_value(operands->children[0])->id;
uint8_t modrm = modrm_mod_register;
@ -362,9 +386,9 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
assert(immediate->id == NODE_NUMBER ||
immediate->id == NODE_LABEL_REFERENCE);
if (immediate->id == NODE_NUMBER) {
uint64_t value = immediate->value.number.value;
operand_size_t size = opcode->operands[0].size;
if (immediate->id == NODE_NUMBER) {
uint64_t value = ast_node_number_value(immediate)->value;
error_t *err = nullptr;
switch (size) {
case OPERAND_SIZE_8:
@ -384,11 +408,22 @@ error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
}
return err;
} else {
// FIXME: this still assumes references are always 32 bit
uint32_t value = 0xDEADBEEF;
return bytes_append_uint32(encoding, value);
reference_t *reference = ast_node_reference_value(immediate);
switch (size) {
case OPERAND_SIZE_64:
return bytes_append_uint64(encoding, reference->address);
case OPERAND_SIZE_32:
return bytes_append_uint32(encoding, reference->offset);
case OPERAND_SIZE_16:
return bytes_append_uint16(encoding, reference->offset);
case OPERAND_SIZE_8:
return bytes_append_uint8(encoding, reference->offset);
default:
assert(false && "intentionally unhandled");
}
}
__builtin_unreachable();
}
error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
@ -456,7 +491,8 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
return err;
// produce the actual encoding output in the NODE_INSTRUCTION value
uint8_t *output = instruction->value.encoding.encoding;
instruction_t *instruction_value = ast_node_instruction_value(instruction);
uint8_t *output = instruction_value->encoding.buffer;
size_t output_len = 0;
// Handle prefixes
@ -475,24 +511,166 @@ error_t *encoder_encode_instruction(encoder_t *encoder,
memcpy(output + output_len, encoding->buffer, encoding->len);
output_len += encoding->len;
instruction->value.encoding.len = output_len;
instruction_value->encoding.len = output_len;
return nullptr;
}
/**
* Perform the second pass that performs actual encoding. Will use
* placeholder values for label references because instruction size has not
* yet been determined.
* Initial guess for instruction size of instructions that contain a label
* reference
*/
error_t *encoder_encoding_pass(encoder_t *encoder, ast_node_t *root) {
constexpr size_t instruction_size_estimate = 10;
/**
* Perform the initial pass over the AST.
*
* - Collect information about the operands
* - parse and set number values
* - set the register values
* - determine if label references are used by an instruction
* - encode instructions that don't use label references
* - determine estimated addresses of each statement
*
*/
error_t *encoder_first_pass(encoder_t *encoder) {
ast_node_t *root = encoder->ast;
assert(root->id == NODE_PROGRAM);
uintptr_t address = 0;
for (size_t i = 0; i < root->len; ++i) {
if (root->children[i]->id != NODE_INSTRUCTION)
continue;
ast_node_t *instruction = root->children[i];
error_t *err = encoder_encode_instruction(encoder, instruction);
ast_node_t *statement = root->children[i];
error_t *err = encoder_collect_info(encoder, statement, statement);
if (err)
return err;
if (statement->id == NODE_INSTRUCTION &&
ast_node_instruction_value(statement)->has_reference == false) {
err = encoder_encode_instruction(encoder, statement);
if (err)
return err;
instruction_t *instruction = ast_node_instruction_value(statement);
instruction->address = address;
address += instruction->encoding.len;
} else if (statement->id == NODE_INSTRUCTION) {
instruction_t *instruction = ast_node_instruction_value(statement);
instruction->encoding.len = instruction_size_estimate;
instruction->address = address;
address += instruction_size_estimate;
} else if (statement->id == NODE_LABEL) {
label_t *label = ast_node_label_value(statement);
label->address = address;
}
}
return nullptr;
}
operand_size_t signed_to_size_mask(int64_t value) {
operand_size_t size = OPERAND_SIZE_64;
if (value >= INT8_MIN && value <= INT8_MAX)
size |= OPERAND_SIZE_8;
if (value >= INT16_MIN && value <= INT16_MAX)
size |= OPERAND_SIZE_16;
if (value >= INT32_MIN && value <= INT32_MAX)
size |= OPERAND_SIZE_32;
return size;
}
int64_t statement_offset(ast_node_t *from, ast_node_t *to) {
assert(from->id == NODE_INSTRUCTION);
assert(to->id == NODE_LABEL);
instruction_t *instruction = ast_node_instruction_value(from);
int64_t from_addr = instruction->address + instruction->encoding.len;
int64_t to_addr = ast_node_label_value(to)->address;
return to_addr - from_addr;
}
error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node,
ast_node_t *statement) {
assert(statement->id == NODE_INSTRUCTION);
if (node->id == NODE_LABEL_REFERENCE) {
const char *name = node->token_entry->token.value;
symbol_t *symbol = symbol_table_lookup(encoder->symbols, name);
assert(symbol && symbol->statement &&
symbol->statement->id == NODE_LABEL);
int64_t offset = statement_offset(statement, symbol->statement);
int64_t absolute = ast_node_label_value(symbol->statement)->address;
operand_size_t size = signed_to_size_mask(offset);
node->value.reference.address = absolute;
node->value.reference.offset = offset;
node->value.reference.size = size;
}
for (size_t i = 0; i < node->len; ++i) {
error_t *err = encoder_collect_reference_info(
encoder, node->children[i], statement);
if (err)
return err;
}
return nullptr;
}
bool encoder_should_reencode(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return false;
instruction_t *instruction = ast_node_instruction_value(statement);
return instruction->has_reference;
}
void set_statement_address(ast_node_t *statement, int64_t address) {
if (statement->id == NODE_INSTRUCTION) {
ast_node_instruction_value(statement)->address = address;
} else if (statement->id == NODE_LABEL) {
ast_node_label_value(statement)->address = address;
}
}
size_t get_statement_length(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return 0;
return ast_node_instruction_value(statement)->encoding.len;
}
/**
* Perform the second pass. Updates the label info and encodes all instructions
* that have a label reference.that performs actual encoding.
*/
error_t *encoder_second_pass(encoder_t *encoder, bool *did_update) {
ast_node_t *root = encoder->ast;
*did_update = false;
int64_t address = 0;
for (size_t i = 0; i < root->len; ++i) {
ast_node_t *statement = root->children[i];
set_statement_address(statement, address);
size_t before = get_statement_length(statement);
if (encoder_should_reencode(statement)) {
error_t *err =
encoder_collect_reference_info(encoder, statement, statement);
if (err)
return err;
err = encoder_encode_instruction(encoder, statement);
if (err)
return err;
}
size_t after = get_statement_length(statement);
*did_update = *did_update || (before != after);
address += after;
}
return nullptr;
}
@ -515,12 +693,19 @@ error_t *encoder_check_symbols(encoder_t *encoder) {
return nullptr;
}
error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast) {
error_t *err = encoder_first_pass(encoder, ast);
error_t *encoder_encode(encoder_t *encoder) {
error_t *err = encoder_first_pass(encoder);
if (err)
return err;
err = encoder_check_symbols(encoder);
if (err)
return err;
return encoder_encoding_pass(encoder, ast);
bool did_update = true;
for (int i = 0; i < 10 && did_update; ++i) {
err = encoder_second_pass(encoder, &did_update);
if (err)
return err;
}
return nullptr;
}

View File

@ -5,6 +5,7 @@
typedef struct encoder {
symbol_table_t *symbols;
ast_node_t *ast;
} encoder_t;
constexpr uint8_t modrm_mod_memory = 0b00'000'000;
@ -16,8 +17,8 @@ constexpr uint8_t modrm_reg_mask = 0b00'111'000;
constexpr uint8_t modrm_rm_mask = 0b00'000'111;
constexpr uint8_t modrm_mod_mask = 0b11'000'000;
error_t *encoder_alloc(encoder_t **encoder);
error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast);
error_t *encoder_alloc(encoder_t **encoder, ast_node_t *ast);
error_t *encoder_encode(encoder_t *encoder);
void encoder_free(encoder_t *encoder);
extern error_t *const err_encoder_invalid_register;

View File

@ -92,7 +92,7 @@ EXPORT | | | ERR | |
-------------|-----------|----------|----------|----------|
*/
bool symbol_table_should_update(symbol_kind_t old, symbol_kind_t new) {
bool symbol_table_should_upgrade(symbol_kind_t old, symbol_kind_t new) {
if (old == SYMBOL_REFERENCE)
return new != SYMBOL_REFERENCE;
if (old == SYMBOL_LOCAL)
@ -112,7 +112,7 @@ bool symbol_table_should_error(symbol_kind_t old, symbol_kind_t new) {
* @pre The symbol _must not_ already be in the table.
*/
error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
ast_node_t *node) {
ast_node_t *statement) {
if (table->len >= table->cap) {
error_t *err = symbol_table_grow_cap(table);
if (err)
@ -122,7 +122,7 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
table->symbols[table->len] = (symbol_t){
.name = name,
.kind = kind,
.node = node,
.statement = statement,
};
table->len += 1;
@ -130,23 +130,29 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
return nullptr;
}
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node) {
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
ast_node_t *statement) {
char *name;
symbol_kind_t kind;
error_t *err = symbol_table_get_node_info(node, &kind, &name);
if (err)
return err;
if (kind != SYMBOL_LOCAL)
statement = nullptr;
symbol_t *symbol = symbol_table_lookup(table, name);
if (!symbol)
return symbol_table_add(table, name, kind, node);
return symbol_table_add(table, name, kind, statement);
if (symbol_table_should_error(symbol->kind, kind))
return err_symbol_table_incompatible_symbols;
if (symbol_table_should_update(symbol->kind, kind)) {
symbol->name = name;
if (symbol_table_should_upgrade(symbol->kind, kind)) {
symbol->kind = kind;
symbol->node = node;
}
if (kind == SYMBOL_LOCAL && symbol->statement == nullptr)
symbol->statement = statement;
return nullptr;
}

View File

@ -29,7 +29,7 @@ typedef enum symbol_kind {
typedef struct symbol {
char *name;
symbol_kind_t kind;
ast_node_t *node;
ast_node_t *statement;
} symbol_t;
typedef struct symbol_table {
@ -40,7 +40,8 @@ typedef struct symbol_table {
error_t *symbol_table_alloc(symbol_table_t **table);
void symbol_table_free(symbol_table_t *table);
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node);
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
ast_node_t *statement);
symbol_t *symbol_table_lookup(symbol_table_t *table, const char *name);
#endif // INCLUDE_ENCODER_SYMBOLS_H_

View File

@ -74,11 +74,11 @@ error_t *print_encoding(tokenlist_t *list) {
return result.err;
encoder_t *encoder;
error_t *err = encoder_alloc(&encoder);
error_t *err = encoder_alloc(&encoder, result.node);
if (err)
goto cleanup_ast;
err = encoder_encode(encoder, result.node);
err = encoder_encode(encoder);
if (err)
goto cleanup_ast;
@ -88,7 +88,8 @@ error_t *print_encoding(tokenlist_t *list) {
if (node->id != NODE_INSTRUCTION)
continue;
print_hex(node->value.encoding.len, node->value.encoding.encoding);
print_hex(node->value.instruction.encoding.len,
node->value.instruction.encoding.buffer);
}
encoder_free(encoder);

View File

@ -89,7 +89,8 @@ parse_result_t parse_immediate(tokenlist_entry_t *current) {
}
parse_result_t parse_memory_expression(tokenlist_entry_t *current) {
parser_t parsers[] = {parse_register_expression, parse_identifier, nullptr};
parser_t parsers[] = {parse_register_expression, parse_label_reference,
nullptr};
return parse_any(current, parsers);
}

View File

@ -58,17 +58,19 @@ MunitResult test_symbol_add_reference(const MunitParameter params[], void *data)
symbol_table_alloc(&table);
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *statement = root->children[3]; // The containing statement
munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, reference);
error_t *err = symbol_table_update(table, reference, statement);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind);
munit_assert_ptr_equal(reference, symbol->node);
// For references, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -90,14 +92,14 @@ MunitResult test_symbol_add_label(const MunitParameter params[], void *data) {
munit_assert_int(label->id, ==, NODE_LABEL);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, label);
error_t *err = symbol_table_update(table, label, label);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_ptr_equal(label, symbol->node);
munit_assert_ptr_equal(label, symbol->statement);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -116,17 +118,19 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
symbol_table_alloc(&table);
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *statement = root->children[0]; // The containing statement
munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, import_directive);
error_t *err = symbol_table_update(table, import_directive, statement);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind);
munit_assert_ptr_equal(import_directive, symbol->node);
// For import directives, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -135,42 +139,56 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
return MUNIT_OK;
}
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *second,
symbol_kind_t second_kind, bool should_succeed, bool should_update) {
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *first_statement,
ast_node_t *second, symbol_kind_t second_kind, ast_node_t *second_statement,
bool should_succeed, bool should_update, ast_node_t *expected_statement) {
symbol_table_t *table = nullptr;
symbol_table_alloc(&table);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, first);
// Add the first symbol
error_t *err = symbol_table_update(table, first, first_statement);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
// Verify first symbol state
symbol_t *symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name);
err = symbol_table_update(table, second);
if (should_succeed)
munit_assert_null(err);
else
munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
munit_assert_size(table->len, ==, 1);
symbol = symbol_table_lookup(table, name);
if (should_update) {
munit_assert_not_null(symbol);
munit_assert_int(second_kind, ==, symbol->kind);
munit_assert_ptr_equal(second, symbol->node);
munit_assert_string_equal(symbol->name, name);
// Check statement based on symbol kind
if (first_kind == SYMBOL_LOCAL) {
munit_assert_ptr_equal(first_statement, symbol->statement);
} else {
munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name);
munit_assert_ptr_null(symbol->statement);
}
// Attempt the second update
err = symbol_table_update(table, second, second_statement);
// Check if update succeeded as expected
if (should_succeed) {
munit_assert_null(err);
} else {
munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
symbol_table_free(table);
return;
}
// Verify symbol after second update
symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol);
// Check if kind updated as expected
if (should_update) {
munit_assert_int(second_kind, ==, symbol->kind);
} else {
munit_assert_int(first_kind, ==, symbol->kind);
}
// Simply check against the expected statement value
munit_assert_ptr_equal(expected_statement, symbol->statement);
symbol_table_free(table);
}
@ -181,28 +199,43 @@ MunitResult test_symbol_upgrade_valid(const MunitParameter params[], void *data)
symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// real upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, label, SYMBOL_LOCAL, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, import_directive, SYMBOL_IMPORT, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, export_directive, SYMBOL_EXPORT, true, true);
test_symbol_update("test", label, SYMBOL_LOCAL, export_directive, SYMBOL_EXPORT, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, label, SYMBOL_LOCAL, label, true, true,
label);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, import_directive, SYMBOL_IMPORT,
import_statement, true, true, nullptr);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, true, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, label, export_directive, SYMBOL_EXPORT, export_statement, true,
true, label);
// identity upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", label, SYMBOL_LOCAL, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_directive, SYMBOL_IMPORT, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_directive, SYMBOL_EXPORT, true, false);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, label, label, SYMBOL_LOCAL, label, true, false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, import_directive, SYMBOL_IMPORT,
import_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, false, nullptr);
// downgrades that are allowed and ignored
test_symbol_update("test", label, SYMBOL_LOCAL, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", label, SYMBOL_LOCAL, label, reference, SYMBOL_REFERENCE, reference_statement, true,
false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, label, SYMBOL_LOCAL, label, true,
false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, label, SYMBOL_LOCAL, label, true,
false, label);
ast_node_free(root);
tokenlist_free(list);
@ -216,14 +249,20 @@ MunitResult test_symbol_upgrade_invalid(const MunitParameter params[], void *dat
symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// invalid upgrades
test_symbol_update("test", label, SYMBOL_LOCAL, import_directive, SYMBOL_IMPORT, false, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, import_directive, SYMBOL_IMPORT, false, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, export_directive, SYMBOL_EXPORT, false, false);
test_symbol_update("test", label, SYMBOL_LOCAL, label, import_directive, SYMBOL_IMPORT, import_statement, false,
false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, import_directive, SYMBOL_IMPORT,
import_statement, false, false, nullptr);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, export_directive, SYMBOL_EXPORT,
export_statement, false, false, nullptr);
ast_node_free(root);
tokenlist_free(list);
@ -240,17 +279,19 @@ MunitResult test_symbol_add_export(const MunitParameter params[], void *data) {
symbol_table_alloc(&table);
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *statement = root->children[1]; // The containing statement
munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, export_directive);
error_t *err = symbol_table_update(table, export_directive, statement);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind);
munit_assert_ptr_equal(export_directive, symbol->node);
// For export directives, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -280,7 +321,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *label = root->children[i];
munit_assert_int(label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, label);
error_t *err = symbol_table_update(table, label, label);
munit_assert_null(err);
munit_assert_size(table->len, ==, i + 1);
@ -292,7 +333,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *final_label = root->children[64];
munit_assert_int(final_label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, final_label);
error_t *err = symbol_table_update(table, final_label, final_label);
munit_assert_null(err);
munit_assert_size(table->len, ==, 65);
@ -308,6 +349,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_string_equal(symbol->name, name);
munit_assert_ptr_equal(symbol->statement, root->children[i]);
}
symbol_table_free(table);
@ -326,7 +368,7 @@ MunitResult test_symbol_invalid_node(const MunitParameter params[], void *data)
symbol_table_alloc(&table);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, root);
error_t *err = symbol_table_update(table, root, root);
munit_assert_ptr_equal(err, err_symbol_table_invalid_node);
munit_assert_size(table->len, ==, 0);