diff --git a/src/encoder/encoder.c b/src/encoder/encoder.c index d3b999a..86c5263 100644 --- a/src/encoder/encoder.c +++ b/src/encoder/encoder.c @@ -23,13 +23,15 @@ error_t *const err_encoder_not_implemented = error_t *const err_encoder_unexpected_length = &(error_t){.message = "Unexpectedly long encoding"}; -error_t *encoder_alloc(encoder_t **output) { +error_t *encoder_alloc(encoder_t **output, ast_node_t *ast) { *output = nullptr; encoder_t *encoder = calloc(1, sizeof(encoder_t)); if (encoder == nullptr) return err_allocation_failed; + encoder->ast = ast; + error_t *err = symbol_table_alloc(&encoder->symbols); if (err) { free(encoder); @@ -213,15 +215,12 @@ static inline uint8_t modrm_rm(uint8_t modrm, register_id_t id) { return (modrm & ~modrm_rm_mask) | (id & 0b111); } -/** - * Perform the initial pass over the AST. Records all symbols and sets the - * values of registers and numbers. - */ -error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) { +error_t *encoder_collect_info(encoder_t *encoder, ast_node_t *node, + ast_node_t *statement) { error_t *err = nullptr; if (encoder_is_symbols_node(node)) - err = symbol_table_update(encoder->symbols, node); + err = symbol_table_update(encoder->symbols, node, statement); else if (node->id == NODE_NUMBER) err = encoder_set_number_value(node); else if (node->id == NODE_REGISTER) @@ -230,7 +229,33 @@ error_t *encoder_first_pass(encoder_t *encoder, ast_node_t *node) { return err; for (size_t i = 0; i < node->len; ++i) { - error_t *err = encoder_first_pass(encoder, node->children[i]); + error_t *err = + encoder_collect_info(encoder, node->children[i], statement); + if (err) + return err; + } + + return nullptr; +} + +/** + * Perform the initial pass over the AST. + * + * - Collect information about the operands + * - parse and set number values + * - set the register values + * - determine if label references are used by an instruction + * - encode instructions that don't use label references + * - determine estimated addresses of each statement + * + */ +error_t *encoder_first_pass(encoder_t *encoder) { + ast_node_t *root = encoder->ast; + assert(root->id == NODE_PROGRAM); + + for (size_t i = 0; i < root->len; ++i) { + ast_node_t *statement = root->children[i]; + error_t *err = encoder_collect_info(encoder, statement, statement); if (err) return err; } @@ -485,7 +510,9 @@ error_t *encoder_encode_instruction(encoder_t *encoder, * placeholder values for label references because instruction size has not * yet been determined. */ -error_t *encoder_encoding_pass(encoder_t *encoder, ast_node_t *root) { +error_t *encoder_second_pass(encoder_t *encoder) { + ast_node_t *root = encoder->ast; + for (size_t i = 0; i < root->len; ++i) { if (root->children[i]->id != NODE_INSTRUCTION) continue; @@ -515,12 +542,12 @@ error_t *encoder_check_symbols(encoder_t *encoder) { return nullptr; } -error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast) { - error_t *err = encoder_first_pass(encoder, ast); +error_t *encoder_encode(encoder_t *encoder) { + error_t *err = encoder_first_pass(encoder); if (err) return err; err = encoder_check_symbols(encoder); if (err) return err; - return encoder_encoding_pass(encoder, ast); + return encoder_second_pass(encoder); } diff --git a/src/encoder/encoder.h b/src/encoder/encoder.h index 45d34d7..39311bb 100644 --- a/src/encoder/encoder.h +++ b/src/encoder/encoder.h @@ -5,6 +5,7 @@ typedef struct encoder { symbol_table_t *symbols; + ast_node_t *ast; } encoder_t; constexpr uint8_t modrm_mod_memory = 0b00'000'000; @@ -16,8 +17,8 @@ constexpr uint8_t modrm_reg_mask = 0b00'111'000; constexpr uint8_t modrm_rm_mask = 0b00'000'111; constexpr uint8_t modrm_mod_mask = 0b11'000'000; -error_t *encoder_alloc(encoder_t **encoder); -error_t *encoder_encode(encoder_t *encoder, ast_node_t *ast); +error_t *encoder_alloc(encoder_t **encoder, ast_node_t *ast); +error_t *encoder_encode(encoder_t *encoder); void encoder_free(encoder_t *encoder); extern error_t *const err_encoder_invalid_register; diff --git a/src/encoder/symbols.c b/src/encoder/symbols.c index 29f5330..0095fce 100644 --- a/src/encoder/symbols.c +++ b/src/encoder/symbols.c @@ -92,7 +92,7 @@ EXPORT | | | ERR | | -------------|-----------|----------|----------|----------| */ -bool symbol_table_should_update(symbol_kind_t old, symbol_kind_t new) { +bool symbol_table_should_upgrade(symbol_kind_t old, symbol_kind_t new) { if (old == SYMBOL_REFERENCE) return new != SYMBOL_REFERENCE; if (old == SYMBOL_LOCAL) @@ -112,7 +112,7 @@ bool symbol_table_should_error(symbol_kind_t old, symbol_kind_t new) { * @pre The symbol _must not_ already be in the table. */ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, - ast_node_t *node) { + ast_node_t *statement) { if (table->len >= table->cap) { error_t *err = symbol_table_grow_cap(table); if (err) @@ -122,7 +122,7 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, table->symbols[table->len] = (symbol_t){ .name = name, .kind = kind, - .node = node, + .statement = statement, }; table->len += 1; @@ -130,23 +130,29 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind, return nullptr; } -error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node) { +error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node, + ast_node_t *statement) { char *name; symbol_kind_t kind; error_t *err = symbol_table_get_node_info(node, &kind, &name); if (err) return err; + if (kind != SYMBOL_LOCAL) + statement = nullptr; + symbol_t *symbol = symbol_table_lookup(table, name); if (!symbol) - return symbol_table_add(table, name, kind, node); + return symbol_table_add(table, name, kind, statement); if (symbol_table_should_error(symbol->kind, kind)) return err_symbol_table_incompatible_symbols; - if (symbol_table_should_update(symbol->kind, kind)) { - symbol->name = name; + if (symbol_table_should_upgrade(symbol->kind, kind)) { symbol->kind = kind; - symbol->node = node; } + + if (kind == SYMBOL_LOCAL && symbol->statement == nullptr) + symbol->statement = statement; + return nullptr; } diff --git a/src/encoder/symbols.h b/src/encoder/symbols.h index 9c4e7f7..ba0c144 100644 --- a/src/encoder/symbols.h +++ b/src/encoder/symbols.h @@ -29,7 +29,7 @@ typedef enum symbol_kind { typedef struct symbol { char *name; symbol_kind_t kind; - ast_node_t *node; + ast_node_t *statement; } symbol_t; typedef struct symbol_table { @@ -40,7 +40,8 @@ typedef struct symbol_table { error_t *symbol_table_alloc(symbol_table_t **table); void symbol_table_free(symbol_table_t *table); -error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node); +error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node, + ast_node_t *statement); symbol_t *symbol_table_lookup(symbol_table_t *table, const char *name); #endif // INCLUDE_ENCODER_SYMBOLS_H_ diff --git a/src/main.c b/src/main.c index fb00342..17d9335 100644 --- a/src/main.c +++ b/src/main.c @@ -74,11 +74,11 @@ error_t *print_encoding(tokenlist_t *list) { return result.err; encoder_t *encoder; - error_t *err = encoder_alloc(&encoder); + error_t *err = encoder_alloc(&encoder, result.node); if (err) goto cleanup_ast; - err = encoder_encode(encoder, result.node); + err = encoder_encode(encoder); if (err) goto cleanup_ast;