Implement part of codegen, add asmjit and zydis (#1050)

Implement part of codegen and fix some frontend issues
Add asmjit to emit native code and add zydis to disassemble native code
Can successfully run some simple cases
This commit is contained in:
Wenyong Huang 2022-03-22 12:22:04 +08:00 committed by GitHub
parent 0f2885cd66
commit f7b6cd75c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1787 additions and 209 deletions

File diff suppressed because it is too large Load Diff

View File

@ -118,6 +118,52 @@ fail:
return false;
}
/**
 * Reload a block's result values from the frame and push them onto the
 * value stack.
 *
 * The frame's sp is first rewound to the block's frame_sp_begin. Each
 * result is then fetched with the gen_load_* helper that matches its
 * type: I32/F32 (and reference types when enabled) advance the cell
 * offset by one, I64/F64 advance it by two.
 *
 * @param cc the compilation context
 * @param block the block whose results are to be loaded
 *
 * @return true after every result has been pushed; false is returned
 *         through the `fail` label (presumably targeted by PUSH when
 *         pushing fails)
 */
static bool
load_block_results(JitCompContext *cc, JitBlock *block)
{
    JitFrame *jit_frame = cc->jit_frame;
    JitReg loaded = 0;
    uint32 cell_offset, cell_step, result_idx = 0;

    /* Rewind the jit frame's sp to where the block's stack begins */
    jit_frame->sp = block->frame_sp_begin;

    /* Offset of the first result, counted from the frame's lp */
    cell_offset = (uint32)(jit_frame->sp - jit_frame->lp);

    while (result_idx < block->result_count) {
        cell_step = 0;

        switch (block->result_types[result_idx]) {
            case VALUE_TYPE_I32:
#if WASM_ENABLE_REF_TYPES != 0
            case VALUE_TYPE_EXTERNREF:
            case VALUE_TYPE_FUNCREF:
#endif
                loaded = gen_load_i32(jit_frame, cell_offset);
                cell_step = 1;
                break;
            case VALUE_TYPE_I64:
                loaded = gen_load_i64(jit_frame, cell_offset);
                cell_step = 2;
                break;
            case VALUE_TYPE_F32:
                loaded = gen_load_f32(jit_frame, cell_offset);
                cell_step = 1;
                break;
            case VALUE_TYPE_F64:
                loaded = gen_load_f64(jit_frame, cell_offset);
                cell_step = 2;
                break;
            default:
                /* Unknown result type: the offset is left untouched */
                bh_assert(0);
                break;
        }

        cell_offset += cell_step;
        PUSH(loaded, block->result_types[result_idx]);
        result_idx++;
    }

    return true;
fail:
    return false;
}
static bool
push_jit_block_to_stack_and_pass_params(JitCompContext *cc, JitBlock *block,
JitBasicBlock *basic_block, JitReg cond)
@ -133,7 +179,6 @@ push_jit_block_to_stack_and_pass_params(JitCompContext *cc, JitBlock *block,
we just move param values from current block's value stack to
the new block's value stack */
for (i = 0; i < block->param_count; i++) {
param_index = block->param_count - 1 - i;
jit_value = jit_value_stack_pop(
&cc->block_stack.block_list_end->value_stack);
if (!value_list_head) {
@ -296,8 +341,26 @@ handle_func_return(JitCompContext *cc, JitBlock *block)
NEW_CONST(I32, offsetof(WASMInterpFrame, sp)));
#endif
copy_block_arities(cc, prev_frame_sp, block->result_types,
block->result_count);
if (block->result_count) {
uint32 cell_num =
wasm_get_cell_num(block->result_types, block->result_count);
copy_block_arities(cc, prev_frame_sp, block->result_types,
block->result_count);
#if UINTPTR_MAX == UINT64_MAX
/* prev_frame->sp += cell_num */
GEN_INSN(ADD, prev_frame_sp, prev_frame_sp,
NEW_CONST(I64, cell_num * 4));
GEN_INSN(STI64, prev_frame_sp, prev_frame,
NEW_CONST(I32, offsetof(WASMInterpFrame, sp)));
#else
/* prev_frame->sp += cell_num */
GEN_INSN(ADD, prev_frame_sp, prev_frame_sp,
NEW_CONST(I32, cell_num * 4));
GEN_INSN(STI32, prev_frame_sp, prev_frame,
NEW_CONST(I32, offsetof(WASMInterpFrame, sp)));
#endif
}
/* Free stack space of the current frame:
exec_env->wasm_stack.s.top = cur_frame */
@ -320,14 +383,14 @@ handle_func_return(JitCompContext *cc, JitBlock *block)
/* fp_reg = prev_frame */
GEN_INSN(MOV, cc->fp_reg, prev_frame);
/* return 0 */
GEN_INSN(RETURNBC, NEW_CONST(I32, 0));
GEN_INSN(RETURNBC, NEW_CONST(I32, JIT_INTERP_ACTION_NORMAL), 0, 0);
}
static bool
handle_op_end(JitCompContext *cc, uint8 **p_frame_ip)
handle_op_end(JitCompContext *cc, uint8 **p_frame_ip, bool from_same_block)
{
JitFrame *jit_frame = cc->jit_frame;
JitBlock *block;
JitBlock *block, *block_prev;
JitIncomingInsn *incoming_insn;
JitInsn *insn;
@ -345,6 +408,42 @@ handle_op_end(JitCompContext *cc, uint8 **p_frame_ip)
handle_func_return(cc, block);
SET_BB_END_BCIP(cc->cur_basic_block, *p_frame_ip - 1);
}
else if (block->result_count > 0) {
JitValue *value_list_head = NULL, *value_list_end = NULL;
JitValue *jit_value;
uint32 i;
/* No need to change cc->jit_frame, just move result values
from current block's value stack to previous block's
value stack */
block_prev = block->prev;
for (i = 0; i < block->result_count; i++) {
jit_value = jit_value_stack_pop(&block->value_stack);
bh_assert(jit_value);
if (!value_list_head) {
value_list_head = value_list_end = jit_value;
jit_value->prev = jit_value->next = NULL;
}
else {
jit_value->prev = NULL;
jit_value->next = value_list_head;
value_list_head->prev = jit_value;
value_list_head = jit_value;
}
}
if (!block_prev->value_stack.value_list_head) {
block_prev->value_stack.value_list_head = value_list_head;
block_prev->value_stack.value_list_end = value_list_end;
}
else {
/* Link to the end of previous block's value stack */
block_prev->value_stack.value_list_end->next = value_list_head;
value_list_head->prev = block_prev->value_stack.value_list_end;
block_prev->value_stack.value_list_end = value_list_end;
}
}
/* Pop block and destroy the block */
block = jit_block_stack_pop(&cc->block_stack);
@ -361,8 +460,9 @@ handle_op_end(JitCompContext *cc, uint8 **p_frame_ip)
CREATE_BASIC_BLOCK(block->basic_block_end);
SET_BB_END_BCIP(cc->cur_basic_block, *p_frame_ip - 1);
SET_BB_BEGIN_BCIP(block->basic_block_end, *p_frame_ip);
/* Jump to the end basic block */
BUILD_BR(block->basic_block_end);
if (from_same_block)
/* Jump to the end basic block */
BUILD_BR(block->basic_block_end);
/* Patch the INSNs which jump to this basic block */
incoming_insn = block->incoming_insns_for_end_bb;
@ -384,13 +484,20 @@ handle_op_end(JitCompContext *cc, uint8 **p_frame_ip)
SET_BUILDER_POS(block->basic_block_end);
/* Pop block and load block results */
block = jit_block_stack_pop(&cc->block_stack);
if (block->label_type == LABEL_TYPE_FUNCTION) {
handle_func_return(cc, block);
SET_BB_END_BCIP(cc->cur_basic_block, *p_frame_ip - 1);
}
else {
if (!load_block_results(cc, block)) {
jit_block_destroy(block);
goto fail;
}
}
/* Pop block and destroy the block */
block = jit_block_stack_pop(&cc->block_stack);
jit_block_destroy(block);
return true;
}
@ -420,7 +527,7 @@ handle_op_else(JitCompContext *cc, uint8 **p_frame_ip)
/* The if branch is handled like OP_BLOCK (cond is const and != 0),
just skip the else branch and handle OP_END */
*p_frame_ip = block->wasm_code_end + 1;
return handle_op_end(cc, p_frame_ip);
return handle_op_end(cc, p_frame_ip, true);
}
else {
/* Has else branch and need to translate else branch */
@ -488,7 +595,7 @@ handle_next_reachable_block(JitCompContext *cc, uint8 **p_frame_ip)
}
else if (block->incoming_insns_for_end_bb) {
*p_frame_ip = block->wasm_code_end + 1;
return handle_op_end(cc, p_frame_ip);
return handle_op_end(cc, p_frame_ip, false);
}
else {
jit_block_stack_pop(&cc->block_stack);
@ -635,7 +742,7 @@ jit_compile_op_else(JitCompContext *cc, uint8 **p_frame_ip)
bool
jit_compile_op_end(JitCompContext *cc, uint8 **p_frame_ip)
{
return handle_op_end(cc, p_frame_ip);
return handle_op_end(cc, p_frame_ip, true);
}
#if 0
@ -716,7 +823,7 @@ bool
jit_compile_op_br(JitCompContext *cc, uint32 br_depth, uint8 **p_frame_ip)
{
JitFrame *jit_frame;
JitBlock *block_dst;
JitBlock *block_dst, *block;
JitReg frame_sp_dst;
JitValueSlot *frame_sp_src = NULL;
JitInsn *insn;
@ -733,6 +840,12 @@ jit_compile_op_br(JitCompContext *cc, uint32 br_depth, uint8 **p_frame_ip)
#endif
#endif
/* Check block stack */
if (!(block = cc->block_stack.block_list_end)) {
jit_set_last_error(cc, "WASM block stack underflow");
return false;
}
if (!(block_dst = get_target_block(cc, br_depth))) {
return false;
}
@ -761,15 +874,24 @@ jit_compile_op_br(JitCompContext *cc, uint32 br_depth, uint8 **p_frame_ip)
#endif
offset = offsetof(WASMInterpFrame, lp)
+ (block_dst->frame_sp_begin - jit_frame->lp) * 4;
#if UINTPTR_MAX == UINT64_MAX
GEN_INSN(ADD, frame_sp_dst, cc->fp_reg, NEW_CONST(I64, offset));
#else
GEN_INSN(ADD, frame_sp_dst, cc->fp_reg, NEW_CONST(I32, offset));
#endif
}
gen_commit_values(jit_frame, jit_frame->lp, block->frame_sp_begin);
if (block_dst->label_type == LABEL_TYPE_LOOP) {
if (copy_arities) {
/* Dest block is Loop block, copy loop parameters */
copy_block_arities(cc, frame_sp_dst, block_dst->param_types,
block_dst->param_count);
}
clear_values(jit_frame);
/* Jump to the begin basic block */
BUILD_BR(block_dst->basic_block_entry);
SET_BB_END_BCIP(cc->cur_basic_block, *p_frame_ip - 1);
@ -780,6 +902,9 @@ jit_compile_op_br(JitCompContext *cc, uint32 br_depth, uint8 **p_frame_ip)
copy_block_arities(cc, frame_sp_dst, block_dst->result_types,
block_dst->result_count);
}
clear_values(jit_frame);
/* Jump to the end basic block */
if (!(insn = GEN_INSN(JMP, 0))) {
jit_set_last_error(cc, "generate jmp insn failed");
@ -863,7 +988,11 @@ jit_compile_op_br_if(JitCompContext *cc, uint32 br_depth, uint8 **p_frame_ip)
#endif
offset = offsetof(WASMInterpFrame, lp)
+ (block_dst->frame_sp_begin - jit_frame->lp) * 4;
#if UINTPTR_MAX == UINT64_MAX
GEN_INSN(ADD, frame_sp_dst, cc->fp_reg, NEW_CONST(I64, offset));
#else
GEN_INSN(ADD, frame_sp_dst, cc->fp_reg, NEW_CONST(I32, offset));
#endif
}
if (block_dst->label_type == LABEL_TYPE_LOOP) {

View File

@ -56,6 +56,8 @@ jit_code_cache_free(void *ptr)
/**
 * Register the native code generated for the function being compiled.
 *
 * Stores cc->jitted_addr_begin in the current function's
 * fast_jit_jitted_code field and in the current module's
 * fast_jit_func_ptrs table at index cc->cur_wasm_func_idx.
 *
 * NOTE(review): confirm that cur_wasm_func_idx is always a valid index
 * into fast_jit_func_ptrs (i.e. that it matches how that table is sized
 * when it is allocated).
 *
 * @param cc the compilation context
 *
 * @return true always
 */
bool
jit_pass_register_jitted_code(JitCompContext *cc)
{
    cc->cur_wasm_func->fast_jit_jitted_code = cc->jitted_addr_begin;
    cc->cur_wasm_module->fast_jit_func_ptrs[cc->cur_wasm_func_idx] =
        cc->jitted_addr_begin;
    return true;
}

View File

@ -15,12 +15,8 @@ jit_pass_lower_cg(JitCompContext *cc)
/**
 * Code generation pass: hands the compilation context to
 * jit_codegen_gen_native() and returns its result unchanged.
 *
 * @param cc the compilation context
 *
 * @return whatever jit_codegen_gen_native(cc) returns
 */
bool
jit_pass_codegen(JitCompContext *cc)
{
/* The next-label assertion and the jitted-addr annotation setup below
   are compiled out by `#if 0` */
#if 0
bh_assert(jit_annl_is_enabled_next_label(cc));
if (!jit_annl_enable_jitted_addr(cc))
return false;
#endif
return jit_codegen_gen_native(cc);
}

View File

@ -50,7 +50,8 @@ static JitGlobals jit_globals = {
#else
.passes = compiler_passes_with_dump,
#endif
.code_cache_size = 10 * 1024 * 1024
.code_cache_size = 10 * 1024 * 1024,
.return_to_interp_from_jitted = NULL
};
/* clang-format on */
@ -99,7 +100,7 @@ jit_compiler_destroy()
jit_code_cache_destroy();
}
const JitGlobals *
JitGlobals *
jit_compiler_get_jit_globals()
{
return &jit_globals;
@ -153,7 +154,7 @@ jit_compiler_compile_all(WASMModule *module)
{
JitCompContext *cc;
char *last_error;
bool ret = false;
bool ret = true;
uint32 i;
/* Initialize compilation context. */

View File

@ -19,15 +19,50 @@ typedef struct JitGlobals {
const uint8 *passes;
/* Code cache size. */
uint32 code_cache_size;
char *return_to_interp_from_jitted;
} JitGlobals;
/**
 * What the interpreter is expected to do once JITed code hands
 * control back to it.
 */
typedef enum JitInterpAction {
    /* Normal execution */
    JIT_INTERP_ACTION_NORMAL = 0,
    /* An exception was thrown */
    JIT_INTERP_ACTION_THROWN = 1,
    /* Call a wasm function */
    JIT_INTERP_ACTION_CALL = 2
} JitInterpAction;
/**
* Information exchanged between JITed code and interpreter.
*/
typedef struct JitInterpSwitchInfo {
/* Points to the frame that is passed to JITed code and the frame
that is returned from JITed code. */
that is returned from JITed code */
void *frame;
/* Output values from JITed code of different actions */
union {
/* IP and SP offsets for NORMAL */
struct {
int32 ip;
int32 sp;
} normal;
/* Function called from JITed code for CALL */
struct {
void *function;
} call;
/* Returned integer and/or floating point values for RETURN. This
is also used to pass return values from interpreter to JITed
code if the caller is in JITed code and the callee is in
interpreter. */
struct {
uint32 ival[2];
uint32 fval[2];
uint32 last_return_type;
} ret;
} out;
} JitInterpSwitchInfo;
bool
@ -36,7 +71,7 @@ jit_compiler_init();
void
jit_compiler_destroy();
const JitGlobals *
JitGlobals *
jit_compiler_get_jit_globals();
const char *

View File

@ -146,11 +146,13 @@ jit_dump_insn(JitCompContext *cc, JitInsn *insn)
void
jit_dump_basic_block(JitCompContext *cc, JitBasicBlock *block)
{
unsigned i;
unsigned i, label_index;
void *begin_addr, *end_addr;
JitBasicBlock *block_next;
JitInsn *insn;
JitRegVec preds = jit_basic_block_preds(block);
JitRegVec succs = jit_basic_block_succs(block);
JitReg label = jit_basic_block_label(block);
JitReg label = jit_basic_block_label(block), label_next;
JitReg *reg;
jit_dump_reg(cc, label);
@ -176,16 +178,33 @@ jit_dump_basic_block(JitCompContext *cc, JitBasicBlock *block)
- (uint8 *)cc->cur_wasm_module->load_addr);
os_printf("\n");
if (jit_annl_is_enabled_jitted_addr(cc))
/* Dump assembly. */
jit_codegen_dump_native(
*(jit_annl_jitted_addr(cc, label)),
label != cc->exit_label
? *(jit_annl_jitted_addr(cc, *(jit_annl_next_label(cc, label))))
: cc->jitted_addr_end);
else
if (jit_annl_is_enabled_jitted_addr(cc)) {
begin_addr = *(jit_annl_jitted_addr(cc, label));
if (label == cc->entry_label) {
block_next = cc->_ann._label_basic_block[2];
label_next = jit_basic_block_label(block_next);
end_addr = *(jit_annl_jitted_addr(cc, label_next));
}
else if (label == cc->exit_label) {
end_addr = cc->jitted_addr_end;
}
else {
label_index = jit_reg_no(label);
if (label_index < jit_cc_label_num(cc) - 1)
block_next = cc->_ann._label_basic_block[label_index + 1];
else
block_next = cc->_ann._label_basic_block[1];
label_next = jit_basic_block_label(block_next);
end_addr = *(jit_annl_jitted_addr(cc, label_next));
}
jit_codegen_dump_native(begin_addr, end_addr);
}
else {
/* Dump IR. */
JIT_FOREACH_INSN(block, insn) jit_dump_insn(cc, insn);
}
os_printf(" ; SUCCS(");
@ -279,18 +298,17 @@ dump_cc_ir(JitCompContext *cc)
os_printf("\n\n");
if (jit_annl_is_enabled_next_label(cc))
if (jit_annl_is_enabled_next_label(cc)) {
/* Blocks have been reordered, use that order to dump. */
for (label = cc->entry_label; label;
label = *(jit_annl_next_label(cc, label)))
jit_dump_basic_block(cc, *(jit_annl_basic_block(cc, label)));
else
/* Otherwise, use the default order. */
{
}
else {
/* Otherwise, use the default order. */
jit_dump_basic_block(cc, jit_cc_entry_basic_block(cc));
JIT_FOREACH_BLOCK(cc, i, end, block)
jit_dump_basic_block(cc, block);
JIT_FOREACH_BLOCK(cc, i, end, block) jit_dump_basic_block(cc, block);
jit_dump_basic_block(cc, jit_cc_exit_basic_block(cc));
}

View File

@ -244,7 +244,7 @@ form_and_translate_func(JitCompContext *cc)
if (insn) {
*(jit_insn_opndv(insn, 2)) = NEW_CONST(I32, i);
}
GEN_INSN(RETURNBC, NEW_CONST(I32, i));
GEN_INSN(RETURN, NEW_CONST(I32, JIT_INTERP_ACTION_THROWN));
*(jit_annl_begin_bcip(cc,
jit_basic_block_label(cc->cur_basic_block))) =
@ -351,7 +351,7 @@ init_func_translation(JitCompContext *cc)
GEN_INSN(LDI64, top_boundary, cc->exec_env_reg,
NEW_CONST(I32, offsetof(WASMExecEnv, wasm_stack.s.top_boundary)));
/* frame_boundary = top + frame_size + outs_size */
GEN_INSN(ADD, frame_boundary, top, NEW_CONST(I32, frame_size + outs_size));
GEN_INSN(ADD, frame_boundary, top, NEW_CONST(I64, frame_size + outs_size));
/* if frame_boundary > top_boundary, throw stack overflow exception */
GEN_INSN(CMP, cc->cmp_reg, frame_boundary, top_boundary);
if (!jit_emit_exception(cc, EXCE_OPERAND_STACK_OVERFLOW, JIT_OP_BGTU,
@ -361,13 +361,13 @@ init_func_translation(JitCompContext *cc)
/* Add first and then sub to reduce one used register */
/* new_top = frame_boundary - outs_size = top + frame_size */
GEN_INSN(SUB, new_top, frame_boundary, NEW_CONST(I32, outs_size));
GEN_INSN(SUB, new_top, frame_boundary, NEW_CONST(I64, outs_size));
/* exec_env->wasm_stack.s.top = new_top */
GEN_INSN(STI64, new_top, cc->exec_env_reg,
NEW_CONST(I32, offsetof(WASMExecEnv, wasm_stack.s.top)));
/* frame_sp = frame->lp + local_size */
GEN_INSN(ADD, frame_sp, top,
NEW_CONST(I32, offsetof(WASMInterpFrame, lp) + local_size));
NEW_CONST(I64, offsetof(WASMInterpFrame, lp) + local_size));
/* frame->sp = frame_sp */
GEN_INSN(STI64, frame_sp, top,
NEW_CONST(I32, offsetof(WASMInterpFrame, sp)));

View File

@ -167,8 +167,9 @@ INSN(LOOKUPSWITCH, LookupSwitch, 1, 0)
/* Call and return instructions */
INSN(CALLNATIVE, VReg, 2, 1)
INSN(CALLBC, Reg, 3, 0)
INSN(RETURNBC, Reg, 1, 0)
INSN(CALLBC, Reg, 3, 2)
INSN(RETURNBC, Reg, 3, 0)
INSN(RETURN, Reg, 1, 0)
#if 0
/* Comparison instructions, can be translate to SELECTXXX */

View File

@ -255,7 +255,7 @@ struct WASMFunction {
uint32 const_cell_num;
#endif
#if WASM_ENABLE_FAST_JIT != 0
void *jitted_code;
void *fast_jit_jitted_code;
#endif
};
@ -447,7 +447,7 @@ struct WASMModule {
#if WASM_ENABLE_FAST_JIT != 0
/* point to JITed functions */
void **func_ptrs;
void **fast_jit_func_ptrs;
#endif
};

View File

@ -3767,10 +3767,13 @@ wasm_interp_call_wasm(WASMModuleInstance *module_inst, WASMExecEnv *exec_env,
#if WASM_ENABLE_FAST_JIT == 0
wasm_interp_call_func_bytecode(module_inst, exec_env, function, frame);
#else
JitGlobals *jit_globals = jit_compiler_get_jit_globals();
JitInterpSwitchInfo info;
info.frame = frame;
frame->jitted_return_addr =
(uint8 *)jit_globals->return_to_interp_from_jitted;
jit_interp_switch_to_jitted(exec_env, &info,
function->u.func->jitted_code);
function->u.func->fast_jit_jitted_code);
(void)wasm_interp_call_func_bytecode;
#endif
}

View File

@ -3230,6 +3230,11 @@ load_from_sections(WASMModule *module, WASMSection *sections,
}
#if WASM_ENABLE_FAST_JIT != 0
if (!(module->fast_jit_func_ptrs =
loader_malloc(sizeof(void *) * module->function_count, error_buf,
error_buf_size))) {
return false;
}
if (!jit_compiler_compile_all(module)) {
set_error_buf(error_buf, error_buf_size, "fast jit compilation failed");
return false;
@ -3719,6 +3724,7 @@ wasm_loader_unload(WASMModule *module)
}
}
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0
WASMFastOPCodeNode *fast_opcode =
bh_list_first_elem(&module->fast_opcode_list);
@ -3729,6 +3735,12 @@ wasm_loader_unload(WASMModule *module)
}
os_mutex_destroy(&module->ref_count_lock);
#endif
#if WASM_ENABLE_FAST_JIT != 0
if (module->fast_jit_func_ptrs)
wasm_runtime_free(module->fast_jit_func_ptrs);
#endif
wasm_runtime_free(module);
}

View File

@ -120,6 +120,8 @@ set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--gc-sections -pie -f
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra -Wformat -Wformat-security -Wshadow")
# set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wconversion -Wsign-conversion")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wformat -Wformat-security")
if (WAMR_BUILD_TARGET MATCHES "X86_.*" OR WAMR_BUILD_TARGET STREQUAL "AMD_64")
if (NOT (CMAKE_C_COMPILER MATCHES ".*clang.*" OR CMAKE_C_COMPILER_ID MATCHES ".*Clang"))
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mindirect-branch-register")