Compare commits

..

6 Commits

Author SHA1 Message Date
ff1927a5c6 Add symbols tests
All checks were successful
Validate the build / validate-build (push) Successful in 30s
2025-04-09 00:14:53 +02:00
7223c31154 initial symbol table implementation 2025-04-09 00:14:44 +02:00
8025f7f8e8 Add .import and .export to the input test file 2025-04-09 00:14:44 +02:00
41867694e2 Make main properly return with failure on parsing errors 2025-04-09 00:14:44 +02:00
7596e54191 Add .import and .export directive to the grammar and parser 2025-04-09 00:14:44 +02:00
c43aab3a2d fix parse_immediate to accept label_reference instead of identifier 2025-04-09 00:14:44 +02:00
30 changed files with 106 additions and 1857 deletions

View File

@ -16,10 +16,8 @@ jobs:
echo "http://dl-cdn.alpinelinux.org/alpine/edge/main" >> /etc/apk/repositories
echo "http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
# determine correct clang version and then install it
apk update
RT_VERSION=$(apk search -v compiler-rt | grep -o "compiler-rt-[0-9]*" | head -1 | grep -o "[0-9]*")
apk add --no-cache llvm${RT_VERSION} clang${RT_VERSION} clang${RT_VERSION}-analyzer compiler-rt valgrind
apk add --no-cache llvm19 clang19 clang19-analyzer compiler-rt valgrind
# Verify versions
echo "---------------------"

4
.gitignore vendored
View File

@ -1,5 +1,7 @@
*.o
*.d
/core
/build
/oas
/oas-asan
/oas-msan
/reports

View File

@ -21,7 +21,7 @@ asan:
msan:
make -rRf make/msan.mk all
validate: asan msan debug release
validate: asan msan debug
./validate.sh
analyze:

View File

@ -1,9 +1,9 @@
<program> ::= <statement>*
<statement> ::= <label> | <directive> | <instruction> | <newline>
<statement> ::= <label> | <directive> | <instruction>
<label> ::= <identifier> <colon>
<directive> ::= <dot> (<section_directive> | <export_directive> | <import_directive> ) <newline>
<directive> ::= <dot> (<section_directive> | <export_directive> | <import_directive> )
<section_directive> ::= "section" <identifier>
@ -11,7 +11,7 @@
<import_directive> ::= "import" <identifier>
<instruction> ::= <identifier> <operands> <newline>
<instruction> ::= <identifier> <operands>
<operands> ::= <operand> ( <comma> <operand> )*

View File

@ -1,4 +1,4 @@
CFLAGS?=-Wall -Wextra -Wpedantic -Werror -O2 -std=c23 -flto -fomit-frame-pointer -DNDEBUG -D_POSIX_C_SOURCE=200809L
CFLAGS?=-Wall -Wextra -Wpedantic -O2 -std=c23 -flto -fomit-frame-pointer -DNDEBUG -D_POSIX_C_SOURCE=200809L
LDFLAGS?=-flto -s -Wl,--gc-sections
BUILD_DIR?=build/release/

View File

@ -17,6 +17,10 @@ error_t *ast_node_alloc(ast_node_t **output) {
return nullptr;
}
void ast_node_free_value(ast_node_t *node) {
// TODO: decide how value ownership will work and clean it up here
}
void ast_node_free(ast_node_t *node) {
if (node == nullptr)
return;
@ -26,6 +30,8 @@ void ast_node_free(ast_node_t *node) {
free(node->children);
}
ast_node_free_value(node);
memset(node, 0, sizeof(ast_node_t));
free(node);
}
@ -155,8 +161,6 @@ const char *ast_node_id_to_cstr(node_id_t id) {
return "NODE_ASTERISK";
case NODE_DOT:
return "NODE_DOT";
case NODE_NEWLINE:
return "NODE_NEWLINE";
case NODE_IMPORT:
return "NODE_IMPORT";
case NODE_EXPORT:
@ -176,8 +180,7 @@ static void ast_node_print_internal(ast_node_t *node, int indent) {
}
printf("%s", ast_node_id_to_cstr(node->id));
if (node->token_entry && node->token_entry->token.value &&
node->id != NODE_NEWLINE) {
if (node->token_entry && node->token_entry->token.value) {
printf(" \"%s\"", node->token_entry->token.value);
}
printf("\n");
@ -190,18 +193,3 @@ static void ast_node_print_internal(ast_node_t *node, int indent) {
void ast_node_print(ast_node_t *node) {
ast_node_print_internal(node, 0);
}
void ast_node_prune(ast_node_t *node, node_id_t id) {
size_t new_len = 0;
for (size_t i = 0; i < node->len; i++) {
auto child = node->children[i];
if (child->id == id) {
ast_node_free(child);
continue;
}
ast_node_prune(child, id);
node->children[new_len] = child;
new_len++;
}
node->len = new_len;
}

View File

@ -1,11 +1,9 @@
#ifndef INCLUDE_SRC_AST_H_
#define INCLUDE_SRC_AST_H_
#include "data/registers.h"
#include "error.h"
#include "lexer.h"
#include "tokenlist.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
@ -56,7 +54,6 @@ typedef enum node_id {
NODE_MINUS,
NODE_ASTERISK,
NODE_DOT,
NODE_NEWLINE,
} node_id_t;
typedef struct ast_node ast_node_t;
@ -65,37 +62,6 @@ constexpr size_t node_default_children_cap = 8;
/* 65K ought to be enough for anybody */
constexpr size_t node_max_children_cap = 1 << 16;
typedef struct number {
uint64_t value;
operand_size_t size;
} number_t;
typedef struct register_ {
register_id_t id;
operand_size_t size;
} register_t;
typedef struct opcode_encoding {
uint8_t buffer[32];
size_t len;
} opcode_encoding_t;
typedef struct instruction {
bool has_reference;
opcode_encoding_t encoding;
int64_t address;
} instruction_t;
typedef struct reference {
int64_t offset;
int64_t address;
operand_size_t size;
} reference_t;
typedef struct {
int64_t address;
} label_t;
struct ast_node {
node_id_t id;
tokenlist_entry_t *token_entry;
@ -104,39 +70,14 @@ struct ast_node {
ast_node_t **children;
union {
register_t reg;
number_t number;
instruction_t instruction;
reference_t reference;
label_t label;
struct {
uint64_t value;
int size;
} integer;
char *name;
} value;
};
static inline register_t *ast_node_register_value(ast_node_t *node) {
assert(node->id == NODE_REGISTER);
return &node->value.reg;
}
static inline number_t *ast_node_number_value(ast_node_t *node) {
assert(node->id == NODE_NUMBER);
return &node->value.number;
}
static inline instruction_t *ast_node_instruction_value(ast_node_t *node) {
assert(node->id == NODE_INSTRUCTION);
return &node->value.instruction;
}
static inline reference_t *ast_node_reference_value(ast_node_t *node) {
assert(node->id == NODE_LABEL_REFERENCE);
return &node->value.reference;
}
static inline label_t *ast_node_label_value(ast_node_t *node) {
assert(node->id == NODE_LABEL);
return &node->value.label;
}
/**
* @brief Allocates a new AST node
*
@ -182,17 +123,4 @@ error_t *ast_node_add_child(ast_node_t *node, ast_node_t *child);
*/
void ast_node_print(ast_node_t *node);
/**
* Prune the children with a given id
*
* The tree is recursively visited and all child nodes of a given ID are pruned
* completely. If a node has the giver id, it will get removed along wih all its
* children, even if some of those children have different ids. The root node id
* is never checked so the tree is guaranteed to remain and allocated valid.
*
* @param node The root of the tree you want to prune
* @param id The id of the nodes you want to prune
*/
void ast_node_prune(ast_node_t *node, node_id_t id);
#endif // INCLUDE_SRC_AST_H_

View File

@ -1,6 +0,0 @@
#include "bytes.h"
#include "error.h"
error_t *const err_bytes_no_capacity = &(error_t){
.message = "Not enough capacity in bytes buffer",
};

View File

@ -1,60 +0,0 @@
#ifndef INCLUDE_SRC_BYTES_H_
#define INCLUDE_SRC_BYTES_H_
#include "error.h"
#include <stddef.h>
#include <stdint.h>
#include <string.h>
extern error_t *const err_bytes_no_capacity;
typedef struct bytes {
size_t len;
size_t cap;
uint8_t buffer[];
} bytes_t;
#define LOCAL_BYTES_ANONYMOUS(N) \
&(struct { \
size_t len; \
size_t cap; \
uint8_t buffer[(N)]; \
}) { \
0, (N), {} \
}
#define LOCAL_BYTES(N) (bytes_t *)LOCAL_BYTES_ANONYMOUS(N);
static inline error_t *bytes_append_uint8(bytes_t *bytes, uint8_t value) {
if (bytes->len >= bytes->cap)
return err_bytes_no_capacity;
bytes->buffer[bytes->len++] = value;
return nullptr;
}
static inline error_t *bytes_append_array(bytes_t *dst, size_t n,
uint8_t buffer[static n]) {
if (dst->len + n >= dst->cap)
return err_bytes_no_capacity;
memcpy(dst->buffer + dst->len, buffer, n);
dst->len += n;
return nullptr;
}
static inline error_t *bytes_append_bytes(bytes_t *dst, bytes_t *src) {
return bytes_append_array(dst, src->len, src->buffer);
}
static inline error_t *bytes_append_uint16(bytes_t *dst, uint16_t value) {
return bytes_append_array(dst, sizeof(value), (uint8_t *)&value);
}
static inline error_t *bytes_append_uint32(bytes_t *dst, uint32_t value) {
return bytes_append_array(dst, sizeof(value), (uint8_t *)&value);
}
static inline error_t *bytes_append_uint64(bytes_t *dst, uint64_t value) {
return bytes_append_array(dst, sizeof(value), (uint8_t *)&value);
}
#endif // INCLUDE_SRC_BYTES_H_

View File

@ -1,265 +0,0 @@
#include "opcodes.h"
// clang-format off
opcode_data_t *const opcodes[] = {
// RET
&(opcode_data_t) {
.mnemonic = "ret",
.opcode = 0xC3,
.opcode_extension = opcode_extension_none,
.operand_count = 0,
},
// RET imm16
&(opcode_data_t) {
.mnemonic = "ret",
.opcode = 0xC2,
.opcode_extension = opcode_extension_none,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 },
},
},
// PUSH imm8
&(opcode_data_t) {
.mnemonic = "push",
.opcode = 0x6A,
.opcode_extension = opcode_extension_none,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8},
},
},
// PUSH imm16
&(opcode_data_t) {
.mnemonic = "push",
.opcode = 0x68,
.opcode_extension = opcode_extension_none,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16},
},
},
// PUSH imm32
&(opcode_data_t) {
.mnemonic = "push",
.opcode = 0x68,
.opcode_extension = opcode_extension_none,
.operand_size_prefix = false,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32},
},
},
// PUSH reg16,
&(opcode_data_t) {
.mnemonic = "push",
.opcode = 0x50,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_OPCODE_REGISTER,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// PUSH reg64
&(opcode_data_t) {
.mnemonic = "push",
.opcode = 0x50,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_OPCODE_REGISTER,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// NOT reg16
&(opcode_data_t) {
.mnemonic = "not",
.opcode = 0xF7,
.opcode_extension = 2,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// NOT reg32
&(opcode_data_t) {
.mnemonic = "not",
.opcode = 0xF7,
.opcode_extension = 2,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
},
},
// NOT reg64
&(opcode_data_t) {
.mnemonic = "not",
.opcode = 0xF7,
.opcode_extension = 2,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// NEG reg16
&(opcode_data_t) {
.mnemonic = "neg",
.opcode = 0xF7,
.opcode_extension = 3,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// NEG reg32
&(opcode_data_t) {
.mnemonic = "neg",
.opcode = 0xF7,
.opcode_extension = 3,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
},
},
// NEG reg64
&(opcode_data_t) {
.mnemonic = "neg",
.opcode = 0xF7,
.opcode_extension = 3,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// CALL rel32
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xE8,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// CALL reg64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// CALL mem64
&(opcode_data_t) {
.mnemonic = "call",
.opcode = 0xFF,
.opcode_extension = 2,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
// JMP rel8 (short jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xEB,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_8 },
},
},
// JMP rel16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_16 },
},
},
// JMP reg16
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_size_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_16 },
},
},
// JMP rel32 (near jump)
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xE9,
.opcode_extension = opcode_extension_none,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_IMMEDIATE, .size = OPERAND_SIZE_32 },
},
},
// JMP reg32
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_32 },
},
},
// JMP reg64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_REGISTER, .size = OPERAND_SIZE_64 },
},
},
// JMP mem64
&(opcode_data_t) {
.mnemonic = "jmp",
.opcode = 0xFF,
.opcode_extension = 4,
.encoding_class = ENCODING_DEFAULT,
.rex_w_prefix = true,
.operand_count = 1,
.operands = {
{ .kind = OPERAND_MEMORY, .size = OPERAND_SIZE_64 },
},
},
nullptr,
};

View File

@ -1,56 +0,0 @@
#ifndef INCLUDE_DATA_OPCODES_H_
#define INCLUDE_DATA_OPCODES_H_
#include "../data/registers.h"
#include <stddef.h>
#include <stdint.h>
constexpr uint8_t rex_prefix = 0x40;
constexpr uint8_t rex_prefix_w = 0x48;
constexpr uint8_t rex_prefix_r = 0x44;
constexpr uint8_t rex_prefix_x = 0x42;
constexpr uint8_t rex_prefix_b = 0x41;
constexpr uint8_t operand_size_prefix = 0x66;
constexpr uint8_t memory_size_prefix = 0x67;
constexpr uint8_t lock_prefix = 0xF0;
constexpr uint8_t repne_prefix = 0xF2;
constexpr uint8_t rep_prefix = 0xF3;
typedef enum encoding_class {
ENCODING_DEFAULT, // use modrm+sib for registers and memory, append
// immediates
ENCODING_OPCODE_REGISTER, // encode the register in the last 3 bits of the
// opcode
} encoding_class_t;
typedef enum operand_kind {
OPERAND_REGISTER,
OPERAND_MEMORY,
OPERAND_IMMEDIATE,
} operand_kind_t;
typedef struct operand_info {
operand_kind_t kind;
operand_size_t size;
} operand_info_t;
constexpr uint8_t opcode_extension_none = 0xFF;
typedef struct opcode_data {
const char *mnemonic;
uint16_t opcode;
uint8_t opcode_extension; // 3 bits for the opcode extension in the reg
// field of a modr/m byte
encoding_class_t encoding_class;
bool operand_size_prefix;
bool address_size_prefix;
bool rex_w_prefix;
size_t operand_count;
operand_info_t operands[3];
} opcode_data_t;
extern opcode_data_t *const opcodes[];
#endif // INCLUDE_DATA_OPCODES_H_

View File

@ -1,92 +0,0 @@
#include "registers.h"
register_data_t *const registers[] = {
// Instruction pointer
&(register_data_t){"rip", REG_RIP, OPERAND_SIZE_64},
&(register_data_t){"eip", REG_RIP, OPERAND_SIZE_32},
&(register_data_t){"ip", REG_RIP, OPERAND_SIZE_16},
// 64-bit general purpose registers
&(register_data_t){"rax", REG_A, OPERAND_SIZE_64},
&(register_data_t){"rcx", REG_C, OPERAND_SIZE_64},
&(register_data_t){"rdx", REG_D, OPERAND_SIZE_64},
&(register_data_t){"rbx", REG_B, OPERAND_SIZE_64},
&(register_data_t){"rsp", REG_SP, OPERAND_SIZE_64},
&(register_data_t){"rbp", REG_BP, OPERAND_SIZE_64},
&(register_data_t){"rsi", REG_SI, OPERAND_SIZE_64},
&(register_data_t){"rdi", REG_DI, OPERAND_SIZE_64},
&(register_data_t){"r8", REG_8, OPERAND_SIZE_64},
&(register_data_t){"r9", REG_9, OPERAND_SIZE_64},
&(register_data_t){"r10", REG_10, OPERAND_SIZE_64},
&(register_data_t){"r11", REG_11, OPERAND_SIZE_64},
&(register_data_t){"r12", REG_12, OPERAND_SIZE_64},
&(register_data_t){"r13", REG_13, OPERAND_SIZE_64},
&(register_data_t){"r14", REG_14, OPERAND_SIZE_64},
&(register_data_t){"r15", REG_15, OPERAND_SIZE_64},
// 32-bit general purpose registers
&(register_data_t){"eax", REG_A, OPERAND_SIZE_32},
&(register_data_t){"ecx", REG_C, OPERAND_SIZE_32},
&(register_data_t){"edx", REG_D, OPERAND_SIZE_32},
&(register_data_t){"ebx", REG_B, OPERAND_SIZE_32},
&(register_data_t){"esp", REG_SP, OPERAND_SIZE_32},
&(register_data_t){"ebp", REG_BP, OPERAND_SIZE_32},
&(register_data_t){"esi", REG_SI, OPERAND_SIZE_32},
&(register_data_t){"edi", REG_DI, OPERAND_SIZE_32},
&(register_data_t){"r8d", REG_8, OPERAND_SIZE_32},
&(register_data_t){"r9d", REG_9, OPERAND_SIZE_32},
&(register_data_t){"r10d", REG_10, OPERAND_SIZE_32},
&(register_data_t){"r11d", REG_11, OPERAND_SIZE_32},
&(register_data_t){"r12d", REG_12, OPERAND_SIZE_32},
&(register_data_t){"r13d", REG_13, OPERAND_SIZE_32},
&(register_data_t){"r14d", REG_14, OPERAND_SIZE_32},
&(register_data_t){"r15d", REG_15, OPERAND_SIZE_32},
// 16-bit general purpose registers
&(register_data_t){"ax", REG_A, OPERAND_SIZE_16},
&(register_data_t){"cx", REG_C, OPERAND_SIZE_16},
&(register_data_t){"dx", REG_D, OPERAND_SIZE_16},
&(register_data_t){"bx", REG_B, OPERAND_SIZE_16},
&(register_data_t){"sp", REG_SP, OPERAND_SIZE_16},
&(register_data_t){"bp", REG_BP, OPERAND_SIZE_16},
&(register_data_t){"si", REG_SI, OPERAND_SIZE_16},
&(register_data_t){"di", REG_DI, OPERAND_SIZE_16},
&(register_data_t){"r8w", REG_8, OPERAND_SIZE_16},
&(register_data_t){"r9w", REG_9, OPERAND_SIZE_16},
&(register_data_t){"r10w", REG_10, OPERAND_SIZE_16},
&(register_data_t){"r11w", REG_11, OPERAND_SIZE_16},
&(register_data_t){"r12w", REG_12, OPERAND_SIZE_16},
&(register_data_t){"r13w", REG_13, OPERAND_SIZE_16},
&(register_data_t){"r14w", REG_14, OPERAND_SIZE_16},
&(register_data_t){"r15w", REG_15, OPERAND_SIZE_16},
// 8-bit general purpose registers (low byte)
&(register_data_t){"al", REG_A, OPERAND_SIZE_8 },
&(register_data_t){"cl", REG_C, OPERAND_SIZE_8 },
&(register_data_t){"dl", REG_D, OPERAND_SIZE_8 },
&(register_data_t){"bl", REG_B, OPERAND_SIZE_8 },
&(register_data_t){"spl", REG_SP, OPERAND_SIZE_8 },
&(register_data_t){"bpl", REG_BP, OPERAND_SIZE_8 },
&(register_data_t){"sil", REG_SI, OPERAND_SIZE_8 },
&(register_data_t){"dil", REG_DI, OPERAND_SIZE_8 },
&(register_data_t){"r8b", REG_8, OPERAND_SIZE_8 },
&(register_data_t){"r9b", REG_9, OPERAND_SIZE_8 },
&(register_data_t){"r10b", REG_10, OPERAND_SIZE_8 },
&(register_data_t){"r11b", REG_11, OPERAND_SIZE_8 },
&(register_data_t){"r12b", REG_12, OPERAND_SIZE_8 },
&(register_data_t){"r13b", REG_13, OPERAND_SIZE_8 },
&(register_data_t){"r14b", REG_14, OPERAND_SIZE_8 },
&(register_data_t){"r15b", REG_15, OPERAND_SIZE_8 },
// x87 floating point registers
&(register_data_t){"st0", REG_ST0, OPERAND_SIZE_80},
&(register_data_t){"st1", REG_ST1, OPERAND_SIZE_80},
&(register_data_t){"st2", REG_ST2, OPERAND_SIZE_80},
&(register_data_t){"st3", REG_ST3, OPERAND_SIZE_80},
&(register_data_t){"st4", REG_ST4, OPERAND_SIZE_80},
&(register_data_t){"st5", REG_ST5, OPERAND_SIZE_80},
&(register_data_t){"st6", REG_ST6, OPERAND_SIZE_80},
&(register_data_t){"st7", REG_ST7, OPERAND_SIZE_80},
nullptr,
};

View File

@ -1,82 +0,0 @@
#ifndef INCLUDE_DATA_REGISTERS_H_
#define INCLUDE_DATA_REGISTERS_H_
typedef enum operand_size {
OPERAND_SIZE_INVALID = 0,
OPERAND_SIZE_8 = 1 << 0,
OPERAND_SIZE_16 = 1 << 1,
OPERAND_SIZE_32 = 1 << 2,
OPERAND_SIZE_64 = 1 << 3,
OPERAND_SIZE_80 = 1 << 4,
OPERAND_SIZE_128 = 1 << 5,
OPERAND_SIZE_256 = 1 << 6,
OPERAND_SIZE_512 = 1 << 7,
} operand_size_t;
static inline operand_size_t bits_to_operand_size(int bits) {
switch (bits) {
case 8:
return OPERAND_SIZE_8;
case 16:
return OPERAND_SIZE_16;
case 32:
return OPERAND_SIZE_32;
case 64:
return OPERAND_SIZE_64;
case 80:
return OPERAND_SIZE_80;
case 128:
return OPERAND_SIZE_128;
case 256:
return OPERAND_SIZE_256;
case 512:
return OPERAND_SIZE_512;
default:
return OPERAND_SIZE_INVALID;
}
}
typedef enum register_id {
// Special registers
REG_RIP = -1,
// General purpose registers
REG_A = 0x0000,
REG_C,
REG_D,
REG_B,
REG_SP,
REG_BP,
REG_SI,
REG_DI,
REG_8,
REG_9,
REG_10,
REG_11,
REG_12,
REG_13,
REG_14,
REG_15,
REG_ST0 = 0x1000,
REG_ST1,
REG_ST2,
REG_ST3,
REG_ST4,
REG_ST5,
REG_ST6,
REG_ST7,
} register_id_t;
typedef struct register_data {
const char *name;
register_id_t id;
operand_size_t size;
} register_data_t;
extern register_data_t *const registers[];
#endif // INCLUDE_DATA_REGISTERS_H_

View File

@ -1,711 +0,0 @@
#include "encoder.h"
#include "../bytes.h"
#include "../data/opcodes.h"
#include "symbols.h"
#include <assert.h>
#include <errno.h>
#include <string.h>
/**
* General encoder flow:
*
* There are 2 major passes the encoder does:
*
* First pass:
* - Run through the AST and collect information:
* - Set register values
* - Parse/set number values
* - Mark all instructions that use label references
* - Encode all instructions that don't use label references
* - Update addresses of all labels and instructions. Use an estimated
* instruction size for those instructions that use label references.
*
* Second pass:
* - Run through the AST for all instructions that use label references and
* collect size information using the estimated addresses from pass 1
* - Encode label references with the estimated addresses, this fixes their
* size.
* - Update all addresses
*
* Iteration:
* - Repeat the second pass until addresses converge
*/
error_t *const err_encoder_invalid_register =
&(error_t){.message = "Invalid register"};
error_t *const err_encoder_number_overflow =
&(error_t){.message = "Number overflows the storage"};
error_t *const err_encoder_invalid_number_format =
&(error_t){.message = "Invalid number format"};
error_t *const err_encoder_invalid_size_suffix =
&(error_t){.message = "Invalid number size suffix"};
error_t *const err_encoder_unknown_symbol_reference =
&(error_t){.message = "Referenced an unknown symbol"};
error_t *const err_encoder_no_encoding_found =
&(error_t){.message = "No encoding found for instruction"};
error_t *const err_encoder_not_implemented =
&(error_t){.message = "Implementation for this opcode is missing"};
error_t *const err_encoder_unexpected_length =
&(error_t){.message = "Unexpectedly long encoding"};
error_t *encoder_alloc(encoder_t **output, ast_node_t *ast) {
*output = nullptr;
encoder_t *encoder = calloc(1, sizeof(encoder_t));
if (encoder == nullptr)
return err_allocation_failed;
encoder->ast = ast;
error_t *err = symbol_table_alloc(&encoder->symbols);
if (err) {
free(encoder);
return err;
}
*output = encoder;
return nullptr;
}
void encoder_free(encoder_t *encoder) {
if (encoder == nullptr)
return;
symbol_table_free(encoder->symbols);
free(encoder);
}
bool encoder_is_symbols_node(ast_node_t *node) {
switch (node->id) {
case NODE_LABEL:
case NODE_LABEL_REFERENCE:
case NODE_EXPORT_DIRECTIVE:
case NODE_IMPORT_DIRECTIVE:
return true;
default:
return false;
}
}
int encoder_get_number_base(ast_node_t *number) {
switch (number->children[0]->id) {
case NODE_BINARY:
return 2;
case NODE_OCTAL:
return 8;
case NODE_DECIMAL:
return 10;
case NODE_HEXADECIMAL:
return 16;
default:
assert(false);
}
__builtin_unreachable();
}
bool is_valid_size_suffix(int bits) {
switch (bits) {
case 0:
case 8:
case 16:
case 32:
case 64:
return true;
default:
return false;
}
}
bool is_overflow(uint64_t value, int bits) {
if (bits == 0 || bits >= 64)
return false;
uint64_t max_value = (1ULL << bits) - 1;
return value > max_value;
}
operand_size_t encoder_get_size_mask(uint64_t value, int bits) {
if (bits != 0)
return bits_to_operand_size(bits);
operand_size_t mask = OPERAND_SIZE_64;
if (value < (1ULL << 8))
mask |= OPERAND_SIZE_8;
if (value < (1ULL << 16))
mask |= OPERAND_SIZE_16;
if (value < (1ULL << 32))
mask |= OPERAND_SIZE_32;
return mask;
}
error_t *encoder_set_number_value(ast_node_t *node) {
assert(node->id == NODE_NUMBER);
assert(node->children[0]);
const char *number = node->children[0]->token_entry->token.value;
int base = encoder_get_number_base(node);
if (base != 10)
number += 2; // all except base 10 use a 0x, 0o or 0b prefix
char *endptr;
errno = 0;
uint64_t value = strtoull(number, &endptr, base);
if (errno == ERANGE)
return err_encoder_number_overflow;
if (endptr == number)
return err_encoder_invalid_number_format;
int bits = 0;
if (*endptr == ':') {
const char *suffix = endptr + 1;
bits = strtol(suffix, &endptr, 10);
if (endptr == suffix)
return err_encoder_invalid_number_format;
}
if (*endptr != '\0')
return err_encoder_invalid_number_format;
if (!is_valid_size_suffix(bits))
return err_encoder_invalid_size_suffix;
if (is_overflow(value, bits))
return err_encoder_number_overflow;
node->value.number.value = value;
node->value.number.size = encoder_get_size_mask(value, bits);
return nullptr;
}
error_t *encoder_set_register_value(ast_node_t *node) {
assert(node->id == NODE_REGISTER);
const char *value = node->token_entry->token.value;
for (size_t i = 0; registers[i] != nullptr; ++i) {
if (strcmp(value, registers[i]->name) == 0) {
node->value.reg.id = registers[i]->id;
node->value.reg.size = registers[i]->size;
return nullptr;
}
}
return err_encoder_invalid_register;
}
/**
* Set the opcode extension in the modrm field
*/
static inline uint8_t modrm_extension(uint8_t modrm, uint8_t extension) {
assert(extension != opcode_extension_none);
assert((extension & 0b111) == extension);
return (modrm & ~modrm_reg_mask) | extension << 3;
}
/**
* Return the rex bit for reg field in modrm
*/
static inline uint8_t modrm_reg_rex(uint8_t rex, register_id_t id) {
if (id & 0b1000)
rex |= rex_prefix_r;
return rex;
}
/**
* update modrm reg field with the given register, must be used alongside
* modrm_reg_rex
*/
static inline uint8_t modrm_reg(uint8_t modrm, register_id_t id) {
return (modrm & ~modrm_reg_mask) | (id & 0b111) << 3;
}
/**
* Return the rex bit for rm field in modrm
*/
static inline uint8_t modrm_rm_rex(uint8_t rex, register_id_t id) {
if (id & 0b1000)
rex |= rex_prefix_b;
return rex;
}
/**
* update modrm rm field with the given register, must be used alongside
* modrm_rm_rex
*/
static inline uint8_t modrm_rm(uint8_t modrm, register_id_t id) {
assert((modrm & modrm_mod_mask) == modrm_mod_register);
return (modrm & ~modrm_rm_mask) | (id & 0b111);
}
error_t *encoder_collect_info(encoder_t *encoder, ast_node_t *node,
ast_node_t *statement) {
error_t *err = nullptr;
if (encoder_is_symbols_node(node)) {
err = symbol_table_update(encoder->symbols, node, statement);
if (statement->id == NODE_INSTRUCTION)
statement->value.instruction.has_reference = true;
} else if (node->id == NODE_NUMBER)
err = encoder_set_number_value(node);
else if (node->id == NODE_REGISTER)
err = encoder_set_register_value(node);
if (err)
return err;
for (size_t i = 0; i < node->len; ++i) {
error_t *err =
encoder_collect_info(encoder, node->children[i], statement);
if (err)
return err;
}
return nullptr;
}
bool is_operand_match(operand_info_t *info, ast_node_t *operand) {
switch (info->kind) {
case OPERAND_REGISTER:
return operand->id == NODE_REGISTER &&
ast_node_register_value(operand)->size == info->size;
case OPERAND_MEMORY:
return operand->id == NODE_MEMORY;
case OPERAND_IMMEDIATE: {
if (operand->id != NODE_IMMEDIATE)
return false;
ast_node_t *child = operand->children[0];
if (child->id == NODE_NUMBER)
return (ast_node_number_value(child)->size & info->size) > 0;
else if (child->id == NODE_LABEL_REFERENCE) {
return info->size &= ast_node_reference_value(child)->size;
}
} // end OPERAND_IMMEDIATE case
}
assert(false && "unreachable");
__builtin_unreachable();
}
bool is_opcode_match(opcode_data_t *opcode, const char *mnemonic,
ast_node_t *operands) {
if (strcmp(opcode->mnemonic, mnemonic) != 0)
return false;
if (opcode->operand_count != operands->len)
return false;
for (size_t i = 0; i < operands->len; ++i) {
if (!is_operand_match(&opcode->operands[i], operands->children[i]))
return false;
}
return true;
}
error_t *encoder_get_opcode_data(ast_node_t *instruction, ast_node_t *operands,
opcode_data_t **opcode_out) {
const char *mnemonic = instruction->children[0]->token_entry->token.value;
for (size_t i = 0; opcodes[i]; ++i) {
opcode_data_t *opcode = opcodes[i];
if (is_opcode_match(opcode, mnemonic, operands)) {
*opcode_out = opcode;
return nullptr;
}
}
return err_encoder_no_encoding_found;
}
error_t *encode_two_operand(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
(void)encoder;
(void)opcode;
(void)operands;
(void)encoding;
(void)rex;
assert(encoding->len >= 1 && "must have 1+ opcode byte in buffer already");
return err_encoder_not_implemented;
}
error_t *encode_one_register_in_opcode(encoder_t *encoder,
opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
(void)encoder;
(void)opcode;
register_id_t id = ast_node_register_value(operands->children[0])->id;
encoding->buffer[encoding->len - 1] |= id & 0b111;
if ((id & 0b1000) > 0) {
*rex |= rex_prefix_r;
}
return nullptr;
}
error_t *encode_one_register(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
(void)encoder;
assert(operands->len == 1);
assert(operands->children[0]->id == NODE_REGISTER);
register_id_t id = ast_node_register_value(operands->children[0])->id;
uint8_t modrm = modrm_mod_register;
if (opcode->opcode_extension != opcode_extension_none) {
// register goes in rm field, extension goes in mod field
modrm = modrm_extension(modrm, opcode->opcode_extension);
modrm = modrm_rm(modrm, id);
*rex = modrm_rm_rex(*rex, id);
} else {
// register goes in reg field
// NOTE:
// it's actually likely this case just doesn't exist at all and all
// opcodes that take one register in modr/m _all_ have extended opcdes
modrm = modrm_reg(modrm, id);
*rex = modrm_reg_rex(*rex, id);
}
return bytes_append_uint8(encoding, modrm);
}
error_t *encode_one_immediate(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
(void)encoder;
(void)opcode;
(void)rex;
assert(operands->len == 1);
assert(operands->children[0]->id == NODE_IMMEDIATE);
assert(operands->children[0]->len == 1);
ast_node_t *immediate = operands->children[0]->children[0];
assert(immediate->id == NODE_NUMBER ||
immediate->id == NODE_LABEL_REFERENCE);
operand_size_t size = opcode->operands[0].size;
if (immediate->id == NODE_NUMBER) {
uint64_t value = ast_node_number_value(immediate)->value;
error_t *err = nullptr;
switch (size) {
case OPERAND_SIZE_8:
err = bytes_append_uint8(encoding, value);
break;
case OPERAND_SIZE_16:
err = bytes_append_uint16(encoding, value);
break;
case OPERAND_SIZE_32:
err = bytes_append_uint32(encoding, value);
break;
case OPERAND_SIZE_64:
err = bytes_append_uint64(encoding, value);
break;
default:
assert(false && "intentionally unhandled");
}
return err;
} else {
reference_t *reference = ast_node_reference_value(immediate);
switch (size) {
case OPERAND_SIZE_64:
return bytes_append_uint64(encoding, reference->address);
case OPERAND_SIZE_32:
return bytes_append_uint32(encoding, reference->offset);
case OPERAND_SIZE_16:
return bytes_append_uint16(encoding, reference->offset);
case OPERAND_SIZE_8:
return bytes_append_uint8(encoding, reference->offset);
default:
assert(false && "intentionally unhandled");
}
}
__builtin_unreachable();
}
error_t *encode_one_memory(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
(void)encoder;
(void)opcode;
(void)operands;
(void)encoding;
(void)rex;
return err_encoder_not_implemented;
}
error_t *encode_one_operand(encoder_t *encoder, opcode_data_t *opcode,
ast_node_t *operands, bytes_t *encoding,
uint8_t *rex) {
switch (opcode->operands[0].kind) {
case OPERAND_REGISTER:
if (opcode->encoding_class == ENCODING_OPCODE_REGISTER)
return encode_one_register_in_opcode(encoder, opcode, operands,
encoding, rex);
else
return encode_one_register(encoder, opcode, operands, encoding,
rex);
case OPERAND_MEMORY:
return encode_one_memory(encoder, opcode, operands, encoding, rex);
case OPERAND_IMMEDIATE:
return encode_one_immediate(encoder, opcode, operands, encoding, rex);
}
}
error_t *encoder_encode_instruction(encoder_t *encoder,
ast_node_t *instruction) {
ast_node_t *operands = instruction->children[1];
opcode_data_t *opcode = nullptr;
error_t *err = encoder_get_opcode_data(instruction, operands, &opcode);
if (err)
return err;
uint8_t rex = 0;
bytes_t *encoding = LOCAL_BYTES(32);
if (opcode->opcode > 0xFF &&
(err = bytes_append_uint8(encoding, opcode->opcode >> 8)))
return err;
if ((err = bytes_append_uint8(encoding, opcode->opcode & 0xFF)))
return err;
// NOTE:operand encoders all expect the opcode to be in the buffer already.
// Some of them rely on this to encode the register value in the opcode
// byte.
switch (opcode->operand_count) {
case 0:
break;
case 1:
err = encode_one_operand(encoder, opcode, operands, encoding, &rex);
break;
case 2:
err = encode_two_operand(encoder, opcode, operands, encoding, &rex);
break;
default:
err = err_encoder_not_implemented;
}
if (err)
return err;
// produce the actual encoding output in the NODE_INSTRUCTION value
instruction_t *instruction_value = ast_node_instruction_value(instruction);
uint8_t *output = instruction_value->encoding.buffer;
size_t output_len = 0;
// Handle prefixes
if (opcode->rex_w_prefix)
rex = rex_prefix_w;
if (opcode->address_size_prefix)
output[output_len++] = memory_size_prefix;
if (opcode->operand_size_prefix)
output[output_len++] = operand_size_prefix;
if (rex > 0)
output[output_len++] = rex;
// copy the encoded opcode and operands
if (encoding->len > 20)
return err_encoder_unexpected_length;
memcpy(output + output_len, encoding->buffer, encoding->len);
output_len += encoding->len;
instruction_value->encoding.len = output_len;
return nullptr;
}
/**
* Initial guess for instruction size of instructions that contain a label
* reference
*/
constexpr size_t instruction_size_estimate = 10;
/**
* Perform the initial pass over the AST.
*
* - Collect information about the operands
* - parse and set number values
* - set the register values
* - determine if label references are used by an instruction
* - encode instructions that don't use label references
* - determine estimated addresses of each statement
*
*/
error_t *encoder_first_pass(encoder_t *encoder) {
ast_node_t *root = encoder->ast;
assert(root->id == NODE_PROGRAM);
uintptr_t address = 0;
for (size_t i = 0; i < root->len; ++i) {
ast_node_t *statement = root->children[i];
error_t *err = encoder_collect_info(encoder, statement, statement);
if (err)
return err;
if (statement->id == NODE_INSTRUCTION &&
ast_node_instruction_value(statement)->has_reference == false) {
err = encoder_encode_instruction(encoder, statement);
if (err)
return err;
instruction_t *instruction = ast_node_instruction_value(statement);
instruction->address = address;
address += instruction->encoding.len;
} else if (statement->id == NODE_INSTRUCTION) {
instruction_t *instruction = ast_node_instruction_value(statement);
instruction->encoding.len = instruction_size_estimate;
instruction->address = address;
address += instruction_size_estimate;
} else if (statement->id == NODE_LABEL) {
label_t *label = ast_node_label_value(statement);
label->address = address;
}
}
return nullptr;
}
operand_size_t signed_to_size_mask(int64_t value) {
operand_size_t size = OPERAND_SIZE_64;
if (value >= INT8_MIN && value <= INT8_MAX)
size |= OPERAND_SIZE_8;
if (value >= INT16_MIN && value <= INT16_MAX)
size |= OPERAND_SIZE_16;
if (value >= INT32_MIN && value <= INT32_MAX)
size |= OPERAND_SIZE_32;
return size;
}
int64_t statement_offset(ast_node_t *from, ast_node_t *to) {
assert(from->id == NODE_INSTRUCTION);
assert(to->id == NODE_LABEL);
instruction_t *instruction = ast_node_instruction_value(from);
int64_t from_addr = instruction->address + instruction->encoding.len;
int64_t to_addr = ast_node_label_value(to)->address;
return to_addr - from_addr;
}
error_t *encoder_collect_reference_info(encoder_t *encoder, ast_node_t *node,
ast_node_t *statement) {
assert(statement->id == NODE_INSTRUCTION);
if (node->id == NODE_LABEL_REFERENCE) {
const char *name = node->token_entry->token.value;
symbol_t *symbol = symbol_table_lookup(encoder->symbols, name);
assert(symbol && symbol->statement &&
symbol->statement->id == NODE_LABEL);
int64_t offset = statement_offset(statement, symbol->statement);
int64_t absolute = ast_node_label_value(symbol->statement)->address;
operand_size_t size = signed_to_size_mask(offset);
node->value.reference.address = absolute;
node->value.reference.offset = offset;
node->value.reference.size = size;
}
for (size_t i = 0; i < node->len; ++i) {
error_t *err = encoder_collect_reference_info(
encoder, node->children[i], statement);
if (err)
return err;
}
return nullptr;
}
bool encoder_should_reencode(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return false;
instruction_t *instruction = ast_node_instruction_value(statement);
return instruction->has_reference;
}
void set_statement_address(ast_node_t *statement, int64_t address) {
if (statement->id == NODE_INSTRUCTION) {
ast_node_instruction_value(statement)->address = address;
} else if (statement->id == NODE_LABEL) {
ast_node_label_value(statement)->address = address;
}
}
size_t get_statement_length(ast_node_t *statement) {
if (statement->id != NODE_INSTRUCTION)
return 0;
return ast_node_instruction_value(statement)->encoding.len;
}
/**
* Perform the second pass. Updates the label info and encodes all instructions
* that have a label reference.that performs actual encoding.
*/
error_t *encoder_second_pass(encoder_t *encoder, bool *did_update) {
ast_node_t *root = encoder->ast;
*did_update = false;
int64_t address = 0;
for (size_t i = 0; i < root->len; ++i) {
ast_node_t *statement = root->children[i];
set_statement_address(statement, address);
size_t before = get_statement_length(statement);
if (encoder_should_reencode(statement)) {
error_t *err =
encoder_collect_reference_info(encoder, statement, statement);
if (err)
return err;
err = encoder_encode_instruction(encoder, statement);
if (err)
return err;
}
size_t after = get_statement_length(statement);
*did_update = *did_update || (before != after);
address += after;
}
return nullptr;
}
opcode_data_t *encoder_find_opcode(ast_node_t *instruction) {
for (size_t i = 0; opcodes[i] != nullptr; ++i) {
const char *mnemonic =
instruction->children[0]->token_entry->token.value;
ast_node_t *operands = instruction->children[1];
if (is_opcode_match(opcodes[i], mnemonic, operands))
return opcodes[i];
}
return nullptr;
}
error_t *encoder_check_symbols(encoder_t *encoder) {
for (size_t i = 0; i < encoder->symbols->len; ++i)
if (encoder->symbols->symbols[i].kind == SYMBOL_REFERENCE)
return err_encoder_unknown_symbol_reference;
return nullptr;
}
error_t *encoder_encode(encoder_t *encoder) {
error_t *err = encoder_first_pass(encoder);
if (err)
return err;
err = encoder_check_symbols(encoder);
if (err)
return err;
bool did_update = true;
for (int i = 0; i < 10 && did_update; ++i) {
err = encoder_second_pass(encoder, &did_update);
if (err)
return err;
}
return nullptr;
}

View File

@ -1,33 +0,0 @@
#ifndef INCLUDE_ENCODER_ENCODER_H_
#define INCLUDE_ENCODER_ENCODER_H_
#include "symbols.h"
typedef struct encoder {
symbol_table_t *symbols;
ast_node_t *ast;
} encoder_t;
constexpr uint8_t modrm_mod_memory = 0b00'000'000;
constexpr uint8_t modrm_mod_memory_displacement8 = 0b01'000'000;
constexpr uint8_t modrm_mod_memory_displacement32 = 0b10'000'000;
constexpr uint8_t modrm_mod_register = 0b11'000'000;
constexpr uint8_t modrm_reg_mask = 0b00'111'000;
constexpr uint8_t modrm_rm_mask = 0b00'000'111;
constexpr uint8_t modrm_mod_mask = 0b11'000'000;
error_t *encoder_alloc(encoder_t **encoder, ast_node_t *ast);
error_t *encoder_encode(encoder_t *encoder);
void encoder_free(encoder_t *encoder);
extern error_t *const err_encoder_invalid_register;
extern error_t *const err_encoder_number_overflow;
extern error_t *const err_encoder_invalid_number_format;
extern error_t *const err_encoder_invalid_size_suffix;
extern error_t *const err_encoder_unknown_symbol_reference;
extern error_t *const err_encoder_no_encoding_found;
extern error_t *const err_encoder_not_implemented;
extern error_t *const err_encoder_unexpected_length;
#endif // INCLUDE_ENCODER_ENCODER_H_

View File

@ -92,7 +92,7 @@ EXPORT | | | ERR | |
-------------|-----------|----------|----------|----------|
*/
bool symbol_table_should_upgrade(symbol_kind_t old, symbol_kind_t new) {
bool symbol_table_should_update(symbol_kind_t old, symbol_kind_t new) {
if (old == SYMBOL_REFERENCE)
return new != SYMBOL_REFERENCE;
if (old == SYMBOL_LOCAL)
@ -112,7 +112,7 @@ bool symbol_table_should_error(symbol_kind_t old, symbol_kind_t new) {
* @pre The symbol _must not_ already be in the table.
*/
error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
ast_node_t *statement) {
ast_node_t *node) {
if (table->len >= table->cap) {
error_t *err = symbol_table_grow_cap(table);
if (err)
@ -122,7 +122,7 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
table->symbols[table->len] = (symbol_t){
.name = name,
.kind = kind,
.statement = statement,
.node = node,
};
table->len += 1;
@ -130,29 +130,23 @@ error_t *symbol_table_add(symbol_table_t *table, char *name, symbol_kind_t kind,
return nullptr;
}
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
ast_node_t *statement) {
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node) {
char *name;
symbol_kind_t kind;
error_t *err = symbol_table_get_node_info(node, &kind, &name);
if (err)
return err;
if (kind != SYMBOL_LOCAL)
statement = nullptr;
symbol_t *symbol = symbol_table_lookup(table, name);
if (!symbol)
return symbol_table_add(table, name, kind, statement);
return symbol_table_add(table, name, kind, node);
if (symbol_table_should_error(symbol->kind, kind))
return err_symbol_table_incompatible_symbols;
if (symbol_table_should_upgrade(symbol->kind, kind)) {
if (symbol_table_should_update(symbol->kind, kind)) {
symbol->name = name;
symbol->kind = kind;
symbol->node = node;
}
if (kind == SYMBOL_LOCAL && symbol->statement == nullptr)
symbol->statement = statement;
return nullptr;
}

View File

@ -29,7 +29,7 @@ typedef enum symbol_kind {
typedef struct symbol {
char *name;
symbol_kind_t kind;
ast_node_t *statement;
ast_node_t *node;
} symbol_t;
typedef struct symbol_table {
@ -40,8 +40,7 @@ typedef struct symbol_table {
error_t *symbol_table_alloc(symbol_table_t **table);
void symbol_table_free(symbol_table_t *table);
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node,
ast_node_t *statement);
error_t *symbol_table_update(symbol_table_t *table, ast_node_t *node);
symbol_t *symbol_table_lookup(symbol_table_t *table, const char *name);
#endif // INCLUDE_ENCODER_SYMBOLS_H_

View File

@ -136,7 +136,7 @@ error_t *lexer_open(lexer_t *lex, char *path) {
*
* @pre There must be at least n characters in the input buffer
*/
void lexer_shift_buffer(lexer_t *lex, size_t n) {
void lexer_shift_buffer(lexer_t *lex, int n) {
assert(lex->buffer_count >= n);
lex->buffer_count -= n;
memmove(lex->buffer, lex->buffer + n, lex->buffer_count);

View File

@ -1,5 +1,3 @@
#include "ast.h"
#include "encoder/encoder.h"
#include "error.h"
#include "lexer.h"
#include "parser/parser.h"
@ -10,13 +8,7 @@
#include <stdlib.h>
#include <string.h>
typedef enum mode {
MODE_INVALID = -1,
MODE_AST,
MODE_TEXT,
MODE_TOKENS,
MODE_ENCODING,
} mode_t;
typedef enum mode { MODE_AST, MODE_TEXT, MODE_TOKENS } mode_t;
void print_tokens(tokenlist_t *list) {
for (auto entry = list->head; entry; entry = entry->next) {
@ -58,62 +50,18 @@ error_t *print_ast(tokenlist_t *list) {
return nullptr;
}
void print_hex(size_t len, uint8_t bytes[static len]) {
for (size_t i = 0; i < len; i++) {
printf("%02x", bytes[i]);
if (i < len - 1) {
printf(" ");
}
}
printf("\n");
}
error_t *print_encoding(tokenlist_t *list) {
parse_result_t result = parse(list->head);
if (result.err)
return result.err;
encoder_t *encoder;
error_t *err = encoder_alloc(&encoder, result.node);
if (err)
goto cleanup_ast;
err = encoder_encode(encoder);
if (err)
goto cleanup_ast;
ast_node_t *root = result.node;
for (size_t i = 0; i < root->len; ++i) {
ast_node_t *node = root->children[i];
if (node->id != NODE_INSTRUCTION)
continue;
print_hex(node->value.instruction.encoding.len,
node->value.instruction.encoding.buffer);
}
encoder_free(encoder);
ast_node_free(result.node);
return nullptr;
cleanup_ast:
ast_node_free(result.node);
return err;
}
int get_execution_mode(int argc, char *argv[]) {
if (argc != 3)
return MODE_INVALID;
if (argc != 3 || (strcmp(argv[1], "tokens") != 0 &&
strcmp(argv[1], "text") != 0 && strcmp(argv[1], "ast"))) {
puts("Usage: oas [tokens|text|ast] <filename>");
exit(1);
}
if (strcmp(argv[1], "tokens") == 0)
return MODE_TOKENS;
if (strcmp(argv[1], "text") == 0)
return MODE_TEXT;
if (strcmp(argv[1], "ast") == 0)
return MODE_AST;
if (strcmp(argv[1], "encoding") == 0)
return MODE_ENCODING;
return MODE_INVALID;
return MODE_AST;
}
error_t *do_action(mode_t mode, tokenlist_t *list) {
@ -126,20 +74,12 @@ error_t *do_action(mode_t mode, tokenlist_t *list) {
return nullptr;
case MODE_AST:
return print_ast(list);
case MODE_ENCODING:
return print_encoding(list);
case MODE_INVALID:
/* can't happen */
}
__builtin_unreachable();
}
int main(int argc, char *argv[]) {
mode_t mode = get_execution_mode(argc, argv);
if (mode == MODE_INVALID) {
puts("Usage: oas [tokens|text|ast|encoding] <filename>");
exit(1);
}
char *filename = argv[2];
lexer_t *lex = &(lexer_t){};

View File

@ -1,5 +1,4 @@
#include "combinators.h"
#include "util.h"
// Parse a list of the given parser delimited by the given token id. Does not
// store the delimiters in the parent node
@ -123,12 +122,5 @@ parse_result_t parse_consecutive(tokenlist_entry_t *current, node_id_t id,
}
current = result.next;
}
// token stream ended before we matched all parsers
if (parser != nullptr) {
ast_node_free(all);
return parse_no_match();
}
return parse_success(all, current);
}

View File

@ -89,8 +89,7 @@ parse_result_t parse_immediate(tokenlist_entry_t *current) {
}
parse_result_t parse_memory_expression(tokenlist_entry_t *current) {
parser_t parsers[] = {parse_register_expression, parse_label_reference,
nullptr};
parser_t parsers[] = {parse_register_expression, parse_identifier, nullptr};
return parse_any(current, parsers);
}
@ -137,28 +136,21 @@ parse_result_t parse_directive_options(tokenlist_entry_t *current) {
}
parse_result_t parse_directive(tokenlist_entry_t *current) {
parser_t parsers[] = {parse_dot, parse_directive_options, parse_newline,
nullptr};
parser_t parsers[] = {parse_dot, parse_directive_options, nullptr};
return parse_consecutive(current, NODE_DIRECTIVE, parsers);
}
parse_result_t parse_instruction(tokenlist_entry_t *current) {
parser_t parsers[] = {parse_identifier, parse_operands, parse_newline,
nullptr};
parser_t parsers[] = {parse_identifier, parse_operands, nullptr};
return parse_consecutive(current, NODE_INSTRUCTION, parsers);
}
parse_result_t parse_statement(tokenlist_entry_t *current) {
parser_t parsers[] = {parse_label, parse_directive, parse_instruction,
parse_newline, nullptr};
nullptr};
return parse_any(current, parsers);
}
parse_result_t parse(tokenlist_entry_t *current) {
current = tokenlist_skip_trivia(current);
parse_result_t result =
parse_many(current, NODE_PROGRAM, true, parse_statement);
if (result.node != nullptr)
ast_node_prune(result.node, NODE_NEWLINE);
return result;
return parse_many(current, NODE_PROGRAM, true, parse_statement);
}

View File

@ -1,6 +1,5 @@
#include "primitives.h"
#include "../ast.h"
#include "../data/registers.h"
#include <string.h>
parse_result_t parse_identifier(tokenlist_entry_t *current) {
@ -63,18 +62,28 @@ parse_result_t parse_dot(tokenlist_entry_t *current) {
return parse_token(current, TOKEN_DOT, NODE_DOT, nullptr);
}
parse_result_t parse_newline(tokenlist_entry_t *current) {
return parse_token(current, TOKEN_NEWLINE, NODE_NEWLINE, nullptr);
}
parse_result_t parse_label_reference(tokenlist_entry_t *current) {
return parse_token(current, TOKEN_IDENTIFIER, NODE_LABEL_REFERENCE,
nullptr);
}
const char *registers[] = {
// 64-bit registers
"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10",
"r11", "r12", "r13", "r14", "r15",
// 32-bit registers
"eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d",
"r10d", "r11d", "r12d", "r13d", "r14d", "r15d",
// 16-bit registers
"ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w",
"r11w", "r12w", "r13w", "r14w", "r15w",
// 8-bit low registers
"al", "cl", "dl", "bl", "spl", "bpl", "sil", "dil", "r8b", "r9b", "r10b",
"r11b", "r12b", "r13b", "r14b", "r15b", nullptr};
bool is_register_token(lexer_token_t *token) {
for (size_t i = 0; registers[i] != nullptr; ++i)
if (strcmp(token->value, registers[i]->name) == 0)
if (strcmp(token->value, registers[i]) == 0)
return true;
return false;
}

View File

@ -18,7 +18,6 @@ parse_result_t parse_plus(tokenlist_entry_t *current);
parse_result_t parse_minus(tokenlist_entry_t *current);
parse_result_t parse_asterisk(tokenlist_entry_t *current);
parse_result_t parse_dot(tokenlist_entry_t *current);
parse_result_t parse_newline(tokenlist_entry_t *current);
parse_result_t parse_label_reference(tokenlist_entry_t *current);
/* These are "primitives" with a different name and some extra validation on top

View File

@ -86,6 +86,7 @@ bool is_trivia(tokenlist_entry_t *trivia) {
switch (trivia->token.id) {
case TOKEN_WHITESPACE:
case TOKEN_COMMENT:
case TOKEN_NEWLINE:
return true;
default:
return false;

View File

@ -1,164 +0,0 @@
#include "../src/bytes.h"
#include "munit.h"
MunitResult test_bytes_initializer(const MunitParameter params[], void *data) {
(void)params;
(void)data;
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
for (size_t i = 0; i < 16; ++i)
munit_assert_uint8(bytes->buffer[i], ==, 0);
return MUNIT_OK;
}
MunitResult test_bytes_append_uint8(const MunitParameter params[], void *data) {
(void)params;
(void)data;
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
for (size_t i = 0; i < 16; ++i) {
error_t *err = bytes_append_uint8(bytes, (uint8_t)i);
munit_assert_null(err);
munit_assert_uint8(bytes->buffer[i], ==, (uint8_t)i);
}
error_t *err = bytes_append_uint8(bytes, 0xFF);
munit_assert_ptr(err, ==, err_bytes_no_capacity);
return MUNIT_OK;
}
MunitResult test_bytes_append_array(const MunitParameter params[], void *data) {
(void)params;
(void)data;
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
uint8_t test_array[] = {0x01, 0x02, 0x03, 0x04, 0x05};
size_t array_len = sizeof(test_array) / sizeof(test_array[0]);
error_t *err = bytes_append_array(bytes, array_len, test_array);
munit_assert_null(err);
munit_assert_size(bytes->len, ==, array_len);
for (size_t i = 0; i < array_len; ++i) {
munit_assert_uint8(bytes->buffer[i], ==, test_array[i]);
}
uint8_t second_array[] = {0x06, 0x07, 0x08};
size_t second_len = sizeof(second_array) / sizeof(second_array[0]);
err = bytes_append_array(bytes, second_len, second_array);
munit_assert_null(err);
munit_assert_size(bytes->len, ==, array_len + second_len);
for (size_t i = 0; i < second_len; ++i) {
munit_assert_uint8(bytes->buffer[array_len + i], ==, second_array[i]);
}
uint8_t overflow_array[10] = {0}; // Array that would exceed capacity
err = bytes_append_array(bytes, sizeof(overflow_array), overflow_array);
munit_assert_ptr(err, ==, err_bytes_no_capacity);
munit_assert_size(bytes->len, ==, array_len + second_len);
return MUNIT_OK;
}
MunitResult test_bytes_append_bytes(const MunitParameter params[], void *data) {
(void)params;
(void)data;
bytes_t *src = LOCAL_BYTES(8);
bytes_t *dst = LOCAL_BYTES(16);
// Fill source bytes with test data
for (uint8_t i = 0; i < 5; ++i) {
error_t *err = bytes_append_uint8(src, i + 1);
munit_assert_null(err);
}
munit_assert_size(src->len, ==, 5);
// Append source to destination
error_t *err = bytes_append_bytes(dst, src);
munit_assert_null(err);
munit_assert_size(dst->len, ==, src->len);
// Verify destination contents match source
for (size_t i = 0; i < src->len; ++i) {
munit_assert_uint8(dst->buffer[i], ==, src->buffer[i]);
}
// Fill source with more data and append again
for (uint8_t i = 0; i < 3; ++i) {
err = bytes_append_uint8(src, i + 6);
munit_assert_null(err);
}
munit_assert_size(src->len, ==, 8);
// Append updated source
err = bytes_append_bytes(dst, src);
munit_assert_null(err);
munit_assert_size(dst->len, ==, 13); // 5 + 8
// Test capacity boundary
src->len = 4; // manually set length to barely not fit
err = bytes_append_bytes(dst, src);
munit_assert_ptr(err, ==, err_bytes_no_capacity);
munit_assert_size(dst->len, ==, 13); // Length unchanged after error
return MUNIT_OK;
}
MunitResult test_bytes_append_uint16(const MunitParameter params[], void *data) {
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
bytes_append_uint16(bytes, 0xFFAA);
munit_assert_size(bytes->len, ==, 2);
munit_assert_uint8(bytes->buffer[0], ==, 0xAA);
munit_assert_uint8(bytes->buffer[1], ==, 0xFF);
return MUNIT_OK;
}
MunitResult test_bytes_append_uint32(const MunitParameter params[], void *data) {
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
bytes_append_uint32(bytes, 0xAABBCCDD);
munit_assert_size(bytes->len, ==, 4);
munit_assert_uint8(bytes->buffer[0], ==, 0xDD);
munit_assert_uint8(bytes->buffer[1], ==, 0xCC);
munit_assert_uint8(bytes->buffer[2], ==, 0xBB);
munit_assert_uint8(bytes->buffer[3], ==, 0xAA);
return MUNIT_OK;
}
MunitResult test_bytes_append_uint64(const MunitParameter params[], void *data) {
bytes_t *bytes = LOCAL_BYTES(16);
munit_assert_size(bytes->len, ==, 0);
munit_assert_size(bytes->cap, ==, 16);
bytes_append_uint64(bytes, 0xAABBCCDDEEFF9988);
munit_assert_size(bytes->len, ==, 8);
munit_assert_uint8(bytes->buffer[0], ==, 0x88);
munit_assert_uint8(bytes->buffer[1], ==, 0x99);
munit_assert_uint8(bytes->buffer[2], ==, 0xFF);
munit_assert_uint8(bytes->buffer[3], ==, 0xEE);
munit_assert_uint8(bytes->buffer[4], ==, 0xDD);
munit_assert_uint8(bytes->buffer[5], ==, 0xCC);
munit_assert_uint8(bytes->buffer[6], ==, 0xBB);
munit_assert_uint8(bytes->buffer[7], ==, 0xAA);
return MUNIT_OK;
}
MunitTest bytes_tests[] = {
{"/initializer", test_bytes_initializer, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_uint8", test_bytes_append_uint8, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_array", test_bytes_append_array, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_bytes", test_bytes_append_bytes, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_uint16", test_bytes_append_uint16, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_uint32", test_bytes_append_uint32, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/append_uint64", test_bytes_append_uint64, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{nullptr, nullptr, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr}
};

View File

@ -1,5 +0,0 @@
; regression test for two issues:
; - parsing two zero operand instructions in a row
; - a zero operand instruction just before eof
syscall
ret

View File

@ -1,5 +0,0 @@
; sample program with trivia on the head of the tokenlist
_start:
xor rax, rax
call exit

View File

@ -2,18 +2,14 @@
extern MunitTest ast_tests[];
extern MunitTest lexer_tests[];
extern MunitTest regression_tests[];
extern MunitTest symbols_tests[];
extern MunitTest bytes_tests[];
int main(int argc, char *argv[MUNIT_ARRAY_PARAM(argc + 1)]) {
MunitSuite suites[] = {
{"/regression", regression_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/ast", ast_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/lexer", lexer_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/symbols", symbols_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/bytes", bytes_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{nullptr, nullptr, nullptr, 0, MUNIT_SUITE_OPTION_NONE},
{"/ast", ast_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/lexer", lexer_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{"/symbols", symbols_tests, nullptr, 1, MUNIT_SUITE_OPTION_NONE},
{nullptr, nullptr, nullptr, 0, MUNIT_SUITE_OPTION_NONE},
};
MunitSuite master_suite = {"/oas", nullptr, suites, 1, MUNIT_SUITE_OPTION_NONE};

View File

@ -1,68 +0,0 @@
#include "../src/ast.h"
#include "../src/parser/parser.h"
#include "munit.h"
MunitResult test_regression_trivia_head(const MunitParameter params[], void *data) {
(void)params;
(void)data;
lexer_t *lex = &(lexer_t){};
error_t *err = lexer_open(lex, "tests/input/regression/test_trivia_head.asm");
munit_assert_null(err);
tokenlist_t *list;
err = tokenlist_alloc(&list);
munit_assert_null(err);
err = tokenlist_fill(list, lex);
munit_assert_null(err);
parse_result_t result = parse(list->head);
munit_assert_null(result.err);
munit_assert_null(result.next);
ast_node_free(result.node);
tokenlist_free(list);
return MUNIT_OK;
}
MunitResult test_no_operands_eof(const MunitParameter params[], void *data) {
(void)params;
(void)data;
lexer_t *lex = &(lexer_t){};
error_t *err = lexer_open(lex, "tests/input/regression/test_no_operands_eof.asm");
munit_assert_null(err);
tokenlist_t *list;
err = tokenlist_alloc(&list);
munit_assert_null(err);
err = tokenlist_fill(list, lex);
munit_assert_null(err);
parse_result_t result = parse(list->head);
munit_assert_null(result.err);
munit_assert_null(result.next);
// Both children should be instructions
munit_assert_size(result.node->len, ==, 2);
munit_assert_int(result.node->children[0]->id, ==, NODE_INSTRUCTION);
munit_assert_int(result.node->children[1]->id, ==, NODE_INSTRUCTION);
// And they should have empty operands
munit_assert_size(result.node->children[0]->len, ==, 2);
munit_assert_size(result.node->children[1]->len, ==, 2);
munit_assert_size(result.node->children[0]->children[1]->len, ==, 0);
munit_assert_size(result.node->children[1]->children[1]->len, ==, 0);
ast_node_free(result.node);
tokenlist_free(list);
return MUNIT_OK;
}
MunitTest regression_tests[] = {
{"/trivia_head", test_regression_trivia_head, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{"/no_operands_eof", test_no_operands_eof, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr},
{nullptr, nullptr, nullptr, nullptr, MUNIT_TEST_OPTION_NONE, nullptr}
};

View File

@ -58,19 +58,17 @@ MunitResult test_symbol_add_reference(const MunitParameter params[], void *data)
symbol_table_alloc(&table);
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *statement = root->children[3]; // The containing statement
munit_assert_int(reference->id, ==, NODE_LABEL_REFERENCE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, reference, statement);
error_t *err = symbol_table_update(table, reference);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_REFERENCE, ==, symbol->kind);
// For references, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_ptr_equal(reference, symbol->node);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -92,14 +90,14 @@ MunitResult test_symbol_add_label(const MunitParameter params[], void *data) {
munit_assert_int(label->id, ==, NODE_LABEL);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, label, label);
error_t *err = symbol_table_update(table, label);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_ptr_equal(label, symbol->statement);
munit_assert_ptr_equal(label, symbol->node);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -118,19 +116,17 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
symbol_table_alloc(&table);
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *statement = root->children[0]; // The containing statement
munit_assert_int(import_directive->id, ==, NODE_IMPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, import_directive, statement);
error_t *err = symbol_table_update(table, import_directive);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_IMPORT, ==, symbol->kind);
// For import directives, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_ptr_equal(import_directive, symbol->node);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -139,56 +135,42 @@ MunitResult test_symbol_add_import(const MunitParameter params[], void *data) {
return MUNIT_OK;
}
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *first_statement,
ast_node_t *second, symbol_kind_t second_kind, ast_node_t *second_statement,
bool should_succeed, bool should_update, ast_node_t *expected_statement) {
void test_symbol_update(const char *name, ast_node_t *first, symbol_kind_t first_kind, ast_node_t *second,
symbol_kind_t second_kind, bool should_succeed, bool should_update) {
symbol_table_t *table = nullptr;
symbol_table_alloc(&table);
// Add the first symbol
error_t *err = symbol_table_update(table, first, first_statement);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, first);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
// Verify first symbol state
symbol_t *symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name);
// Check statement based on symbol kind
if (first_kind == SYMBOL_LOCAL) {
munit_assert_ptr_equal(first_statement, symbol->statement);
} else {
munit_assert_ptr_null(symbol->statement);
}
// Attempt the second update
err = symbol_table_update(table, second, second_statement);
// Check if update succeeded as expected
if (should_succeed) {
err = symbol_table_update(table, second);
if (should_succeed)
munit_assert_null(err);
} else {
else
munit_assert_ptr_equal(err, err_symbol_table_incompatible_symbols);
symbol_table_free(table);
return;
}
munit_assert_size(table->len, ==, 1);
// Verify symbol after second update
symbol = symbol_table_lookup(table, name);
munit_assert_not_null(symbol);
// Check if kind updated as expected
if (should_update) {
munit_assert_not_null(symbol);
munit_assert_int(second_kind, ==, symbol->kind);
munit_assert_ptr_equal(second, symbol->node);
munit_assert_string_equal(symbol->name, name);
} else {
munit_assert_not_null(symbol);
munit_assert_int(first_kind, ==, symbol->kind);
munit_assert_ptr_equal(first, symbol->node);
munit_assert_string_equal(symbol->name, name);
}
// Simply check against the expected statement value
munit_assert_ptr_equal(expected_statement, symbol->statement);
symbol_table_free(table);
}
@ -199,43 +181,28 @@ MunitResult test_symbol_upgrade_valid(const MunitParameter params[], void *data)
symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// real upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, label, SYMBOL_LOCAL, label, true, true,
label);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, import_directive, SYMBOL_IMPORT,
import_statement, true, true, nullptr);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, true, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, label, export_directive, SYMBOL_EXPORT, export_statement, true,
true, label);
test_symbol_update("test", reference, SYMBOL_REFERENCE, label, SYMBOL_LOCAL, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, import_directive, SYMBOL_IMPORT, true, true);
test_symbol_update("test", reference, SYMBOL_REFERENCE, export_directive, SYMBOL_EXPORT, true, true);
test_symbol_update("test", label, SYMBOL_LOCAL, export_directive, SYMBOL_EXPORT, true, true);
// identity upgrades
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, label, label, SYMBOL_LOCAL, label, true, false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, import_directive, SYMBOL_IMPORT,
import_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, export_directive, SYMBOL_EXPORT,
export_statement, true, false, nullptr);
test_symbol_update("test", reference, SYMBOL_REFERENCE, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", label, SYMBOL_LOCAL, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_directive, SYMBOL_IMPORT, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_directive, SYMBOL_EXPORT, true, false);
// downgrades that are allowed and ignored
test_symbol_update("test", label, SYMBOL_LOCAL, label, reference, SYMBOL_REFERENCE, reference_statement, true,
false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, reference, SYMBOL_REFERENCE,
reference_statement, true, false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, label, SYMBOL_LOCAL, label, true,
false, label);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, label, SYMBOL_LOCAL, label, true,
false, label);
test_symbol_update("test", label, SYMBOL_LOCAL, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, reference, SYMBOL_REFERENCE, true, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, label, SYMBOL_LOCAL, true, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, label, SYMBOL_LOCAL, true, false);
ast_node_free(root);
tokenlist_free(list);
@ -249,20 +216,14 @@ MunitResult test_symbol_upgrade_invalid(const MunitParameter params[], void *dat
symbols_setup_test(&root, &list, "tests/input/symbols.asm");
ast_node_t *reference = root->children[3]->children[1]->children[0]->children[0];
ast_node_t *reference_statement = root->children[3];
ast_node_t *label = root->children[2];
ast_node_t *import_directive = root->children[0]->children[1];
ast_node_t *import_statement = root->children[0];
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *export_statement = root->children[1];
// invalid upgrades
test_symbol_update("test", label, SYMBOL_LOCAL, label, import_directive, SYMBOL_IMPORT, import_statement, false,
false, nullptr);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, export_statement, import_directive, SYMBOL_IMPORT,
import_statement, false, false, nullptr);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, import_statement, export_directive, SYMBOL_EXPORT,
export_statement, false, false, nullptr);
test_symbol_update("test", label, SYMBOL_LOCAL, import_directive, SYMBOL_IMPORT, false, false);
test_symbol_update("test", export_directive, SYMBOL_EXPORT, import_directive, SYMBOL_IMPORT, false, false);
test_symbol_update("test", import_directive, SYMBOL_IMPORT, export_directive, SYMBOL_EXPORT, false, false);
ast_node_free(root);
tokenlist_free(list);
@ -279,19 +240,17 @@ MunitResult test_symbol_add_export(const MunitParameter params[], void *data) {
symbol_table_alloc(&table);
ast_node_t *export_directive = root->children[1]->children[1];
ast_node_t *statement = root->children[1]; // The containing statement
munit_assert_int(export_directive->id, ==, NODE_EXPORT_DIRECTIVE);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, export_directive, statement);
error_t *err = symbol_table_update(table, export_directive);
munit_assert_null(err);
munit_assert_size(table->len, ==, 1);
symbol_t *symbol = symbol_table_lookup(table, "test");
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_EXPORT, ==, symbol->kind);
// For export directives, the statement should be nullptr
munit_assert_ptr_null(symbol->statement);
munit_assert_ptr_equal(export_directive, symbol->node);
munit_assert_string_equal(symbol->name, "test");
symbol_table_free(table);
@ -321,7 +280,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *label = root->children[i];
munit_assert_int(label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, label, label);
error_t *err = symbol_table_update(table, label);
munit_assert_null(err);
munit_assert_size(table->len, ==, i + 1);
@ -333,7 +292,7 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
ast_node_t *final_label = root->children[64];
munit_assert_int(final_label->id, ==, NODE_LABEL);
error_t *err = symbol_table_update(table, final_label, final_label);
error_t *err = symbol_table_update(table, final_label);
munit_assert_null(err);
munit_assert_size(table->len, ==, 65);
@ -349,7 +308,6 @@ MunitResult test_symbol_table_growth(const MunitParameter params[], void *data)
munit_assert_not_null(symbol);
munit_assert_int(SYMBOL_LOCAL, ==, symbol->kind);
munit_assert_string_equal(symbol->name, name);
munit_assert_ptr_equal(symbol->statement, root->children[i]);
}
symbol_table_free(table);
@ -368,7 +326,7 @@ MunitResult test_symbol_invalid_node(const MunitParameter params[], void *data)
symbol_table_alloc(&table);
munit_assert_size(table->len, ==, 0);
error_t *err = symbol_table_update(table, root, root);
error_t *err = symbol_table_update(table, root);
munit_assert_ptr_equal(err, err_symbol_table_invalid_node);
munit_assert_size(table->len, ==, 0);