From cf7b01ad82140815754ebcddfaee520b0d43371f Mon Sep 17 00:00:00 2001 From: Wenyong Huang Date: Mon, 21 Nov 2022 10:42:18 +0800 Subject: [PATCH] Implement call Fast JIT function from LLVM JIT jitted code (#1714) Basically implement the Multi-tier JIT engine. And update document and wamr-test-suites script. --- build-scripts/config_common.cmake | 18 +- core/iwasm/common/wasm_runtime_common.c | 3 + core/iwasm/compilation/aot.h | 4 + core/iwasm/compilation/aot_emit_function.c | 48 +- .../fast-jit/cg/x86-64/jit_codegen_x86_64.cpp | 923 +++++++++++++++++- core/iwasm/fast-jit/fe/jit_emit_function.c | 4 +- core/iwasm/fast-jit/jit_codecache.c | 3 +- core/iwasm/fast-jit/jit_codegen.h | 11 +- core/iwasm/fast-jit/jit_compiler.c | 147 ++- core/iwasm/fast-jit/jit_compiler.h | 21 +- core/iwasm/fast-jit/jit_frontend.c | 6 + core/iwasm/fast-jit/jit_frontend.h | 11 + core/iwasm/fast-jit/jit_ir.def | 2 +- core/iwasm/interpreter/wasm.h | 71 +- core/iwasm/interpreter/wasm_interp_classic.c | 94 +- core/iwasm/interpreter/wasm_loader.c | 502 +++++++--- core/iwasm/interpreter/wasm_runtime.c | 79 +- core/iwasm/interpreter/wasm_runtime.h | 9 + doc/build_wamr.md | 10 +- tests/wamr-test-suites/test_wamr.sh | 30 +- 20 files changed, 1775 insertions(+), 221 deletions(-) diff --git a/build-scripts/config_common.cmake b/build-scripts/config_common.cmake index 1f340fc7e..b4ae5c668 100644 --- a/build-scripts/config_common.cmake +++ b/build-scripts/config_common.cmake @@ -83,6 +83,13 @@ if (NOT WAMR_BUILD_AOT EQUAL 1) endif () endif () +if (WAMR_BUILD_FAST_JIT EQUAL 1) + if (NOT WAMR_BUILD_LAZY_JIT EQUAL 0) + # Enable Lazy JIT by default + set (WAMR_BUILD_LAZY_JIT 1) + endif () +endif () + if (WAMR_BUILD_JIT EQUAL 1) if (NOT WAMR_BUILD_LAZY_JIT EQUAL 0) # Enable Lazy JIT by default @@ -134,7 +141,12 @@ else () message (" WAMR AOT disabled") endif () if (WAMR_BUILD_FAST_JIT EQUAL 1) - message (" WAMR Fast JIT enabled") + if (WAMR_BUILD_LAZY_JIT EQUAL 1) + add_definitions("-DWASM_ENABLE_LAZY_JIT=1") + message (" WAMR Fast JIT enabled with Lazy Compilation") + else () + message (" WAMR Fast JIT enabled with Eager Compilation") + endif () else () message (" WAMR Fast JIT disabled") endif () @@ -149,6 +161,10 @@ if (WAMR_BUILD_JIT EQUAL 1) else () message (" WAMR LLVM ORC JIT disabled") endif () +if (WAMR_BUILD_FAST_JIT EQUAL 1 AND WAMR_BUILD_JIT EQUAL 1 + AND WAMR_BUILD_LAZY_JIT EQUAL 1) + message (" Multi-tier JIT enabled") +endif () if (WAMR_BUILD_LIBC_BUILTIN EQUAL 1) message (" Libc builtin enabled") else () diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 0cbaaf5c4..b5ab5eeae 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -2142,7 +2142,10 @@ static const char *exception_msgs[] = { "wasm auxiliary stack underflow", /* EXCE_AUX_STACK_UNDERFLOW */ "out of bounds table access", /* EXCE_OUT_OF_BOUNDS_TABLE_ACCESS */ "wasm operand stack overflow", /* EXCE_OPERAND_STACK_OVERFLOW */ +#if WASM_ENABLE_FAST_JIT != 0 + "failed to compile fast jit function", /* EXCE_FAILED_TO_COMPILE_FAST_JIT_FUNC */ "", /* EXCE_ALREADY_THROWN */ +#endif }; /* clang-format on */ diff --git a/core/iwasm/compilation/aot.h b/core/iwasm/compilation/aot.h index 8e8598c7b..c7989851c 100644 --- a/core/iwasm/compilation/aot.h +++ b/core/iwasm/compilation/aot.h @@ -309,6 +309,8 @@ aot_get_imp_tbl_data_slots(const AOTImportTable *tbl, bool is_jit_mode) #if WASM_ENABLE_MULTI_MODULE != 0 if (is_jit_mode) return tbl->table_max_size; +#else + (void)is_jit_mode; #endif return tbl->possible_grow ? tbl->table_max_size : tbl->table_init_size; } @@ -319,6 +321,8 @@ aot_get_tbl_data_slots(const AOTTable *tbl, bool is_jit_mode) #if WASM_ENABLE_MULTI_MODULE != 0 if (is_jit_mode) return tbl->table_max_size; +#else + (void)is_jit_mode; #endif return tbl->possible_grow ? tbl->table_max_size : tbl->table_init_size; } diff --git a/core/iwasm/compilation/aot_emit_function.c b/core/iwasm/compilation/aot_emit_function.c index cbcdebffc..c59fadc95 100644 --- a/core/iwasm/compilation/aot_emit_function.c +++ b/core/iwasm/compilation/aot_emit_function.c @@ -789,7 +789,53 @@ aot_compile_op_call(AOTCompContext *comp_ctx, AOTFuncContext *func_ctx, func = func_ctx->func; } else { - func = func_ctxes[func_idx - import_func_count]->func; + if (!comp_ctx->is_jit_mode) { + func = func_ctxes[func_idx - import_func_count]->func; + } + else { +#if !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0) + func = func_ctxes[func_idx - import_func_count]->func; +#else + /* JIT tier-up, load func ptr from func_ptrs[func_idx] */ + LLVMValueRef func_ptr, func_idx_const; + LLVMTypeRef func_ptr_type; + + if (!(func_idx_const = I32_CONST(func_idx))) { + aot_set_last_error("llvm build const failed."); + goto fail; + } + + if (!(func_ptr = LLVMBuildInBoundsGEP2( + comp_ctx->builder, OPQ_PTR_TYPE, + func_ctx->func_ptrs, &func_idx_const, 1, + "func_ptr_tmp"))) { + aot_set_last_error("llvm build inbounds gep failed."); + goto fail; + } + + if (!(func_ptr = + LLVMBuildLoad2(comp_ctx->builder, OPQ_PTR_TYPE, + func_ptr, "func_ptr"))) { + aot_set_last_error("llvm build load failed."); + goto fail; + } + + if (!(func_ptr_type = LLVMPointerType( + func_ctxes[func_idx - import_func_count] + ->func_type, + 0))) { + aot_set_last_error("construct func ptr type failed."); + goto fail; + } + + if (!(func = LLVMBuildBitCast(comp_ctx->builder, func_ptr, + func_ptr_type, + "indirect_func"))) { + aot_set_last_error("llvm build bit cast failed."); + goto fail; + } +#endif /* end of !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0) */ + } } } diff --git a/core/iwasm/fast-jit/cg/x86-64/jit_codegen_x86_64.cpp b/core/iwasm/fast-jit/cg/x86-64/jit_codegen_x86_64.cpp index d10be6f32..48d1486d5 100644 --- a/core/iwasm/fast-jit/cg/x86-64/jit_codegen_x86_64.cpp +++ b/core/iwasm/fast-jit/cg/x86-64/jit_codegen_x86_64.cpp @@ -6,6 +6,7 @@ #include "jit_codegen.h" #include "jit_codecache.h" #include "jit_compiler.h" +#include "jit_frontend.h" #include "jit_dump.h" #include @@ -21,6 +22,9 @@ using namespace asmjit; static char *code_block_switch_to_jitted_from_interp = NULL; static char *code_block_return_to_interp_from_jitted = NULL; +#if WASM_ENABLE_LAZY_JIT != 0 +static char *code_block_compile_fast_jit_and_then_call = NULL; +#endif typedef enum { REG_EBP_IDX = 0, @@ -107,16 +111,17 @@ x86::Xmm regs_float[] = { int jit_codegen_interp_jitted_glue(void *exec_env, JitInterpSwitchInfo *info, - void *target) + uint32 func_idx, void *target) { - typedef int32 (*F)(const void *exec_env, void *info, const void *target); + typedef int32 (*F)(const void *exec_env, void *info, uint32 func_idx, + const void *target); union { F f; void *v; } u; u.v = code_block_switch_to_jitted_from_interp; - return u.f(exec_env, info, target); + return u.f(exec_env, info, func_idx, target); } #define PRINT_LINE() LOG_VERBOSE("\n", __LINE__) @@ -5870,6 +5875,7 @@ lower_callbc(JitCompContext *cc, x86::Assembler &a, bh_list *jmp_info_list, JitReg xmm0_f64_hreg = jit_reg_new(JIT_REG_KIND_F64, 0); JitReg ret_reg = *(jit_insn_opnd(insn, 0)); JitReg func_reg = *(jit_insn_opnd(insn, 2)); + JitReg func_idx = *(jit_insn_opnd(insn, 3)); JitReg src_reg; int32 func_reg_no; @@ -5880,6 +5886,15 @@ lower_callbc(JitCompContext *cc, x86::Assembler &a, bh_list *jmp_info_list, func_reg_no = jit_reg_no(func_reg); CHECK_I64_REG_NO(func_reg_no); + CHECK_KIND(func_idx, JIT_REG_KIND_I32); + if (jit_reg_is_const(func_idx)) { + imm.setValue(jit_cc_get_const_I32(cc, func_idx)); + a.mov(regs_i64[REG_RDX_IDX], imm); + } + else { + a.movzx(regs_i64[REG_RDX_IDX], regs_i32[jit_reg_no(func_idx)]); + } + node = (JmpInfo *)jit_malloc(sizeof(JmpInfo)); if (!node) GOTO_FAIL; @@ -6762,6 +6777,761 @@ fail: return return_value; } +#if WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 + +#define MAX_REG_INTS 6 +#define MAX_REG_FLOATS 8 + +void * +jit_codegen_compile_call_to_llvm_jit(const WASMType *func_type) +{ + const JitHardRegInfo *hreg_info = jit_codegen_get_hreg_info(); + x86::Gp reg_lp = x86::r10, reg_res = x86::r12; + x86::Gp reg_tmp_i64 = x86::r11, reg_tmp_i32 = x86::r11d; + /* the index of integer argument registers */ + uint8 reg_idx_of_int_args[] = { REG_RDI_IDX, REG_RSI_IDX, REG_RDX_IDX, + REG_RCX_IDX, REG_R8_IDX, REG_R9_IDX }; + uint32 n_ints = 0, n_fps = 0, n_stacks = 0, n_pushed; + uint32 int_reg_idx = 0, fp_reg_idx = 0, stack_arg_idx = 0; + uint32 off_to_lp = 0, off_to_res = 0, code_size, i; + uint32 param_count = func_type->param_count; + uint32 result_count = func_type->result_count; + uint32 ext_result_count; + char *code_buf, *stream; + Imm imm; + + JitErrorHandler err_handler; + Environment env(Arch::kX64); + CodeHolder code; + code.init(env); + code.setErrorHandler(&err_handler); + x86::Assembler a(&code); + + /* Load the llvm jit function pointer */ + { + /* r11 = exec_env->module_inst */ + x86::Mem m1(regs_i64[hreg_info->exec_env_hreg_index], + (uint32)offsetof(WASMExecEnv, module_inst)); + a.mov(reg_tmp_i64, m1); + /* r11 = module_inst->func_ptrs */ + x86::Mem m2(reg_tmp_i64, + (uint32)offsetof(WASMModuleInstance, func_ptrs)); + a.mov(reg_tmp_i64, m2); + /* rax = func_ptrs[func_idx] */ + x86::Mem m3(reg_tmp_i64, x86::rdx, 3, 0); + a.mov(x86::rax, m3); + } + + n_ints++; /* exec_env */ + + for (i = 0; i < param_count; i++) { + switch (func_type->types[i]) { + case VALUE_TYPE_I32: + case VALUE_TYPE_I64: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + if (n_ints < MAX_REG_INTS) + n_ints++; + else + n_stacks++; + break; + case VALUE_TYPE_F32: + case VALUE_TYPE_F64: + if (n_fps < MAX_REG_FLOATS) + n_fps++; + else + n_stacks++; + break; + } + } + + ext_result_count = result_count > 1 ? result_count - 1 : 0; + + if (ext_result_count > 0) { + if (n_ints + ext_result_count <= MAX_REG_INTS) { + /* extra result pointers can be stored into int registers */ + n_ints += ext_result_count; + } + else { + /* part or all extra result pointers must be stored into stack */ + n_stacks += n_ints + ext_result_count - MAX_REG_INTS; + n_ints = MAX_REG_INTS; + } + } + + n_pushed = n_stacks; + if (n_stacks & 1) { + /* Align stack on 16 bytes */ + n_pushed++; + } + if (n_pushed > 0) { + imm.setValue(n_pushed * 8); + a.sub(x86::rsp, imm); + } + + /* r10 = outs_area->lp */ + { + x86::Mem m(regs_i64[hreg_info->exec_env_hreg_index], + (uint32)offsetof(WASMExecEnv, wasm_stack.s.top)); + a.mov(reg_lp, m); + a.add(reg_lp, (uint32)offsetof(WASMInterpFrame, lp)); + } + + /* rdi = exec_env */ + a.mov(regs_i64[reg_idx_of_int_args[int_reg_idx++]], + regs_i64[hreg_info->exec_env_hreg_index]); + + for (i = 0; i < param_count; i++) { + x86::Mem m_src(reg_lp, off_to_lp); + + switch (func_type->types[i]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + { + if (int_reg_idx < MAX_REG_INTS) { + a.mov(regs_i32[reg_idx_of_int_args[int_reg_idx]], m_src); + int_reg_idx++; + } + else { + a.mov(reg_tmp_i32, m_src); + x86::Mem m_dst(x86::rsp, stack_arg_idx * 8); + a.mov(m_dst, reg_tmp_i32); + stack_arg_idx++; + } + off_to_lp += 4; + break; + } + case VALUE_TYPE_I64: + { + if (int_reg_idx < MAX_REG_INTS) { + a.mov(regs_i64[reg_idx_of_int_args[int_reg_idx]], m_src); + int_reg_idx++; + } + else { + a.mov(reg_tmp_i64, m_src); + x86::Mem m_dst(x86::rsp, stack_arg_idx * 8); + a.mov(m_dst, reg_tmp_i64); + stack_arg_idx++; + } + off_to_lp += 8; + break; + } + case VALUE_TYPE_F32: + { + if (fp_reg_idx < MAX_REG_FLOATS) { + a.movss(regs_float[fp_reg_idx], m_src); + fp_reg_idx++; + } + else { + a.mov(reg_tmp_i32, m_src); + x86::Mem m_dst(x86::rsp, stack_arg_idx * 8); + a.mov(m_dst, reg_tmp_i32); + stack_arg_idx++; + } + off_to_lp += 4; + break; + } + case VALUE_TYPE_F64: + { + if (fp_reg_idx < MAX_REG_FLOATS) { + a.movsd(regs_float[fp_reg_idx], m_src); + fp_reg_idx++; + } + else { + a.mov(reg_tmp_i64, m_src); + x86::Mem m_dst(x86::rsp, stack_arg_idx * 8); + a.mov(m_dst, reg_tmp_i64); + stack_arg_idx++; + } + off_to_lp += 8; + break; + } + } + } + + if (result_count > 0) { + switch (func_type->types[param_count]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + case VALUE_TYPE_F32: + off_to_res = 4; + break; + case VALUE_TYPE_I64: + case VALUE_TYPE_F64: + off_to_res = 8; + break; + } + + /* r12 = cur_frame->sp */ + x86::Mem m(x86::rbp, (uint32)offsetof(WASMInterpFrame, sp)); + a.mov(reg_res, m); + + for (i = 0; i < ext_result_count; i++) { + x86::Mem m(reg_res, off_to_res); + + if (int_reg_idx < MAX_REG_INTS) { + a.lea(regs_i64[reg_idx_of_int_args[int_reg_idx]], m); + int_reg_idx++; + } + else { + a.lea(reg_tmp_i64, m); + x86::Mem m_dst(x86::rsp, stack_arg_idx * 8); + a.mov(m_dst, reg_tmp_i64); + stack_arg_idx++; + } + + switch (func_type->types[param_count + 1 + i]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + case VALUE_TYPE_F32: + off_to_res += 4; + break; + case VALUE_TYPE_I64: + case VALUE_TYPE_F64: + off_to_res += 8; + break; + } + } + } + + bh_assert(int_reg_idx == n_ints); + bh_assert(fp_reg_idx == n_fps); + bh_assert(stack_arg_idx == n_stacks); + + /* Call the llvm jit function */ + a.call(x86::rax); + + /* Check if there was exception thrown */ + { + /* r11 = exec_env->module_inst */ + x86::Mem m1(regs_i64[hreg_info->exec_env_hreg_index], + (uint32)offsetof(WASMExecEnv, module_inst)); + a.mov(reg_tmp_i64, m1); + /* module_inst->cur_exception */ + x86::Mem m2(reg_tmp_i64, + (uint32)offsetof(WASMModuleInstance, cur_exception)); + /* bl = module_inst->cur_exception[0] */ + a.mov(x86::bl, m2); + + /* cur_exception[0] == 0 ? */ + Imm imm((uint8)0); + a.cmp(x86::bl, imm); + /* If yes, jump to `Get function result and return` */ + imm.setValue(INT32_MAX); + a.je(imm); + + char *stream = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + + /* If no, set eax to JIT_INTERP_ACTION_THROWN, and + jump to code_block_return_to_interp_from_jitted to + return to interpreter */ + imm.setValue(JIT_INTERP_ACTION_THROWN); + a.mov(x86::eax, imm); + imm.setValue(code_block_return_to_interp_from_jitted); + a.mov(x86::rsi, imm); + a.jmp(x86::rsi); + + char *stream_new = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + + *(int32 *)(stream - 4) = (uint32)(stream_new - stream); + } + + /* Get function result and return */ + + if (result_count > 0 && func_type->types[param_count] != VALUE_TYPE_F32 + && func_type->types[param_count] != VALUE_TYPE_F64) { + a.mov(x86::rdx, x86::rax); + } + + if (off_to_res > 0) { + imm.setValue(off_to_res); + a.add(reg_res, imm); + /* cur_frame->sp = r12 */ + x86::Mem m(x86::rbp, (uint32)offsetof(WASMInterpFrame, sp)); + a.mov(m, reg_res); + } + + if (n_pushed > 0) { + imm.setValue(n_pushed * 8); + a.add(x86::rsp, imm); + } + + /* Return to the caller */ + { + /* eax = action = JIT_INTERP_ACTION_NORMAL */ + Imm imm(0); + a.mov(x86::eax, imm); + + uint32 jitted_return_addr_offset = + jit_frontend_get_jitted_return_addr_offset(); + x86::Mem m(x86::rbp, jitted_return_addr_offset); + a.jmp(m); + } + + if (err_handler.err) + return NULL; + + code_buf = (char *)code.sectionById(0)->buffer().data(); + code_size = code.sectionById(0)->buffer().size(); + stream = (char *)jit_code_cache_alloc(code_size); + if (!stream) + return NULL; + + bh_memcpy_s(stream, code_size, code_buf, code_size); + +#if 0 + dump_native(stream, code_size); +#endif + + return stream; +} + +static WASMInterpFrame * +fast_jit_alloc_frame(WASMExecEnv *exec_env, uint32 param_cell_num, + uint32 ret_cell_num) +{ + WASMModuleInstance *module_inst = + (WASMModuleInstance *)exec_env->module_inst; + WASMInterpFrame *frame; + uint32 size_frame1 = wasm_interp_interp_frame_size(ret_cell_num); + uint32 size_frame2 = wasm_interp_interp_frame_size(param_cell_num); + + /** + * Check whether we can allocate two frames: the first is an implied + * frame to store the function results from jit function to call, + * the second is the frame for the jit function + */ + if ((uint8 *)exec_env->wasm_stack.s.top + size_frame1 + size_frame2 + > exec_env->wasm_stack.s.top_boundary) { + wasm_set_exception(module_inst, "wasm operand stack overflow"); + return NULL; + } + + /* Allocate the frame */ + frame = (WASMInterpFrame *)exec_env->wasm_stack.s.top; + exec_env->wasm_stack.s.top += size_frame1; + + frame->function = NULL; + frame->ip = NULL; + frame->sp = frame->lp; + frame->prev_frame = wasm_exec_env_get_cur_frame(exec_env); + frame->jitted_return_addr = + (uint8 *)code_block_return_to_interp_from_jitted; + + wasm_exec_env_set_cur_frame(exec_env, frame); + + return frame; +} + +void * +jit_codegen_compile_call_to_fast_jit(const WASMModule *module, uint32 func_idx) +{ + uint32 func_idx_non_import = func_idx - module->import_function_count; + WASMType *func_type = module->functions[func_idx_non_import]->func_type; + /* the index of integer argument registers */ + uint8 reg_idx_of_int_args[] = { REG_RDI_IDX, REG_RSI_IDX, REG_RDX_IDX, + REG_RCX_IDX, REG_R8_IDX, REG_R9_IDX }; + uint32 int_reg_idx, fp_reg_idx, stack_arg_idx; + uint32 switch_info_offset, exec_env_offset, stack_arg_offset; + uint32 int_reg_offset, frame_lp_offset; + uint32 switch_info_size, code_size, i; + uint32 param_count = func_type->param_count; + uint32 result_count = func_type->result_count; + uint32 ext_result_count = result_count > 1 ? result_count - 1 : 0; + uint32 param_cell_num = func_type->param_cell_num; + uint32 ret_cell_num = + func_type->ret_cell_num > 2 ? func_type->ret_cell_num : 2; + char *code_buf, *stream; + Imm imm; + + JitErrorHandler err_handler; + Environment env(Arch::kX64); + CodeHolder code; + code.init(env); + code.setErrorHandler(&err_handler); + x86::Assembler a(&code); + + /** + * Push JitInterpSwitchInfo and make stack 16-byte aligned: + * the size pushed must be odd multiples of 8, as the stack pointer + * %rsp must be aligned to a 16-byte boundary before making a call, + * and when a function (including this llvm jit function) gets + * control, the %rsp is not 16-byte aligned (call instruction will + * push the ret address to stack). + */ + switch_info_size = align_uint((uint32)sizeof(JitInterpSwitchInfo), 16) + 8; + imm.setValue((uint64)switch_info_size); + a.sub(x86::rsp, imm); + + /* Push all integer argument registers since we will use them as + temporarily registers to load/store data */ + for (i = 0; i < MAX_REG_INTS; i++) { + a.push(regs_i64[reg_idx_of_int_args[MAX_REG_INTS - 1 - i]]); + } + + /* We don't push float/double register since we don't use them here */ + + /** + * Layout of the stack now: + * stack arguments + * ret address of the caller + * switch info + * int registers: r9, r8, rcx, rdx, rsi + * exec_env: rdi + */ + + /* offset of the first stack argument to the stack pointer, + add 8 to skip the ret address of the caller */ + stack_arg_offset = switch_info_size + 8 * MAX_REG_INTS + 8; + /* offset of jit interp switch info to the stack pointer */ + switch_info_offset = 8 * MAX_REG_INTS; + /* offset of the first int register to the stack pointer */ + int_reg_offset = 8; + /* offset of exec_env to the stack pointer */ + exec_env_offset = 0; + + /* Call fast_jit_alloc_frame to allocate the stack frame to + receive the results of the fast jit function to call */ + + /* rdi = exec_env, has been already set as exec_env is + the first argument of LLVM JIT function */ + /* rsi = param_cell_num */ + imm.setValue(param_cell_num); + a.mov(x86::rsi, imm); + /* rdx = ret_cell_num */ + imm.setValue(ret_cell_num); + a.mov(x86::rdx, imm); + /* call fast_jit_alloc_frame */ + imm.setValue((uint64)(uintptr_t)fast_jit_alloc_frame); + a.mov(x86::rax, imm); + a.call(x86::rax); + + /* Check the return value, note now rax is the allocated frame */ + { + /* Did fast_jit_alloc_frame return NULL? */ + Imm imm((uint64)0); + a.cmp(x86::rax, imm); + /* If no, jump to `Copy arguments to frame lp area` */ + imm.setValue(INT32_MAX); + a.jne(imm); + + char *stream = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + + /* If yes, set eax to 0, return to caller */ + + /* Pop all integer arument registers */ + for (i = 0; i < MAX_REG_INTS; i++) { + a.pop(regs_i64[reg_idx_of_int_args[i]]); + } + /* Pop jit interp switch info */ + imm.setValue((uint64)switch_info_size); + a.add(x86::rsp, imm); + + /* Return to the caller, don't use leave as we didn't + `push rbp` and `mov rbp, rsp` */ + a.ret(); + + /* Patch the offset of jne instruction */ + char *stream_new = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + *(int32 *)(stream - 4) = (int32)(stream_new - stream); + } + + int_reg_idx = 1; /* skip exec_env */ + fp_reg_idx = 0; + stack_arg_idx = 0; + + /* Offset of the dest arguments to outs area */ + frame_lp_offset = wasm_interp_interp_frame_size(ret_cell_num) + + (uint32)offsetof(WASMInterpFrame, lp); + + /* Copy arguments to frame lp area */ + for (i = 0; i < func_type->param_count; i++) { + x86::Mem m_dst(x86::rax, frame_lp_offset); + switch (func_type->types[i]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + if (int_reg_idx < MAX_REG_INTS) { + /* Copy i32 argument from int register */ + x86::Mem m_src(x86::rsp, int_reg_offset); + a.mov(x86::esi, m_src); + a.mov(m_dst, x86::esi); + int_reg_offset += 8; + int_reg_idx++; + } + else { + /* Copy i32 argument from stack */ + x86::Mem m_src(x86::rsp, stack_arg_offset); + a.mov(x86::esi, m_src); + a.mov(m_dst, x86::esi); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 4; + break; + case VALUE_TYPE_I64: + if (int_reg_idx < MAX_REG_INTS) { + /* Copy i64 argument from int register */ + x86::Mem m_src(x86::rsp, int_reg_offset); + a.mov(x86::rsi, m_src); + a.mov(m_dst, x86::rsi); + int_reg_offset += 8; + int_reg_idx++; + } + else { + /* Copy i64 argument from stack */ + x86::Mem m_src(x86::rsp, stack_arg_offset); + a.mov(x86::rsi, m_src); + a.mov(m_dst, x86::rsi); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 8; + break; + case VALUE_TYPE_F32: + if (fp_reg_idx < MAX_REG_FLOATS) { + /* Copy f32 argument from fp register */ + a.movss(m_dst, regs_float[fp_reg_idx++]); + } + else { + /* Copy f32 argument from stack */ + x86::Mem m_src(x86::rsp, stack_arg_offset); + a.mov(x86::esi, m_src); + a.mov(m_dst, x86::esi); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 4; + break; + case VALUE_TYPE_F64: + if (fp_reg_idx < MAX_REG_FLOATS) { + /* Copy f64 argument from fp register */ + a.movsd(m_dst, regs_float[fp_reg_idx++]); + } + else { + /* Copy f64 argument from stack */ + x86::Mem m_src(x86::rsp, stack_arg_offset); + a.mov(x86::rsi, m_src); + a.mov(m_dst, x86::rsi); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 8; + break; + default: + bh_assert(0); + } + } + + /* Call the fast jit function */ + { + /* info = rsp + switch_info_offset */ + a.lea(x86::rsi, x86::ptr(x86::rsp, switch_info_offset)); + /* info.frame = frame = rax, or return of fast_jit_alloc_frame */ + x86::Mem m1(x86::rsi, (uint32)offsetof(JitInterpSwitchInfo, frame)); + a.mov(m1, x86::rax); + + /* Call code_block_switch_to_jitted_from_interp + with argument (exec_env, info, func_idx, pc) */ + /* rdi = exec_env */ + a.mov(x86::rdi, x86::ptr(x86::rsp, exec_env_offset)); + /* rsi = info, has been set */ + /* rdx = func_idx */ + imm.setValue(func_idx); + a.mov(x86::rdx, imm); + /* module_inst = exec_env->module_inst */ + a.mov(x86::rcx, + x86::ptr(x86::rdi, (uint32)offsetof(WASMExecEnv, module_inst))); + /* fast_jit_func_ptrs = module_inst->fast_jit_func_ptrs */ + a.mov(x86::rcx, + x86::ptr(x86::rcx, (uint32)offsetof(WASMModuleInstance, + fast_jit_func_ptrs))); + imm.setValue(func_idx_non_import); + a.mov(x86::rax, imm); + x86::Mem m3(x86::rcx, x86::rax, 3, 0); + /* rcx = module_inst->fast_jit_func_ptrs[func_idx_non_import] */ + a.mov(x86::rcx, m3); + + imm.setValue( + (uint64)(uintptr_t)code_block_switch_to_jitted_from_interp); + a.mov(x86::rax, imm); + a.call(x86::rax); + } + + /* No need to check exception thrown here as it will be checked + in the caller */ + + /* Copy function results */ + if (result_count > 0) { + frame_lp_offset = offsetof(WASMInterpFrame, lp); + + switch (func_type->types[param_count]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + a.mov(x86::eax, x86::edx); + frame_lp_offset += 4; + break; + case VALUE_TYPE_I64: + a.mov(x86::rax, x86::rdx); + frame_lp_offset += 8; + break; + case VALUE_TYPE_F32: + /* The first result has been put to xmm0 */ + frame_lp_offset += 4; + break; + case VALUE_TYPE_F64: + /* The first result has been put to xmm0 */ + frame_lp_offset += 8; + break; + default: + bh_assert(0); + } + + /* Copy extra results from exec_env->cur_frame */ + if (ext_result_count > 0) { + /* rdi = exec_env */ + a.mov(x86::rdi, x86::ptr(x86::rsp, exec_env_offset)); + /* rsi = exec_env->cur_frame */ + a.mov(x86::rsi, + x86::ptr(x86::rdi, (uint32)offsetof(WASMExecEnv, cur_frame))); + + for (i = 0; i < ext_result_count; i++) { + switch (func_type->types[param_count + 1 + i]) { + case VALUE_TYPE_I32: +#if WASM_ENABLE_REF_TYPES != 0 + case VALUE_TYPE_FUNCREF: + case VALUE_TYPE_EXTERNREF: +#endif + case VALUE_TYPE_F32: + { + /* Copy 32-bit result */ + a.mov(x86::ecx, x86::ptr(x86::rsi, frame_lp_offset)); + if (int_reg_idx < MAX_REG_INTS) { + x86::Mem m1(x86::rsp, + exec_env_offset + int_reg_idx * 8); + a.mov(x86::rdx, m1); + x86::Mem m2(x86::rdx, 0); + a.mov(m2, x86::ecx); + int_reg_idx++; + } + else { + x86::Mem m1(x86::rsp, stack_arg_offset); + a.mov(x86::rdx, m1); + x86::Mem m2(x86::rdx, 0); + a.mov(m2, x86::ecx); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 4; + break; + } + case VALUE_TYPE_I64: + case VALUE_TYPE_F64: + { + /* Copy 64-bit result */ + a.mov(x86::rcx, x86::ptr(x86::rsi, frame_lp_offset)); + if (int_reg_idx < MAX_REG_INTS) { + x86::Mem m1(x86::rsp, + exec_env_offset + int_reg_idx * 8); + a.mov(x86::rdx, m1); + x86::Mem m2(x86::rdx, 0); + a.mov(m2, x86::rcx); + int_reg_idx++; + } + else { + x86::Mem m1(x86::rsp, stack_arg_offset); + a.mov(x86::rdx, m1); + x86::Mem m2(x86::rdx, 0); + a.mov(m2, x86::rcx); + stack_arg_offset += 8; + stack_arg_idx++; + } + frame_lp_offset += 8; + break; + } + default: + bh_assert(0); + } + } + } + } + + /* Free the frame allocated */ + + /* rdi = exec_env */ + a.mov(x86::rdi, x86::ptr(x86::rsp, exec_env_offset)); + /* rsi = exec_env->cur_frame */ + a.mov(x86::rsi, + x86::ptr(x86::rdi, (uint32)offsetof(WASMExecEnv, cur_frame))); + /* rdx = exec_env->cur_frame->prev_frame */ + a.mov(x86::rdx, + x86::ptr(x86::rsi, (uint32)offsetof(WASMInterpFrame, prev_frame))); + /* exec_env->wasm_stack.s.top = cur_frame */ + { + x86::Mem m(x86::rdi, offsetof(WASMExecEnv, wasm_stack.s.top)); + a.mov(m, x86::rsi); + } + /* exec_env->cur_frame = prev_frame */ + { + x86::Mem m(x86::rdi, offsetof(WASMExecEnv, cur_frame)); + a.mov(m, x86::rdx); + } + + /* Pop all integer arument registers */ + for (i = 0; i < MAX_REG_INTS; i++) { + a.pop(regs_i64[reg_idx_of_int_args[i]]); + } + /* Pop jit interp switch info */ + imm.setValue((uint64)switch_info_size); + a.add(x86::rsp, imm); + + /* Return to the caller, don't use leave as we didn't + `push rbp` and `mov rbp, rsp` */ + a.ret(); + + if (err_handler.err) { + return NULL; + } + + code_buf = (char *)code.sectionById(0)->buffer().data(); + code_size = code.sectionById(0)->buffer().size(); + stream = (char *)jit_code_cache_alloc(code_size); + if (!stream) + return NULL; + + bh_memcpy_s(stream, code_size, code_buf, code_size); + +#if 0 + printf("Code of call to fast jit of func %u:\n", func_idx); + dump_native(stream, code_size); + printf("\n"); +#endif + + return stream; +} + +#endif /* end of WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 */ + bool jit_codegen_lower(JitCompContext *cc) { @@ -6803,6 +7573,8 @@ jit_codegen_init() code.setErrorHandler(&err_handler); x86::Assembler a(&code); + /* Initialize code_block_switch_to_jitted_from_interp */ + /* push callee-save registers */ a.push(x86::rbp); a.push(x86::rbx); @@ -6822,9 +7594,11 @@ jit_codegen_init() /* exec_env_reg = exec_env */ a.mov(regs_i64[hreg_info->exec_env_hreg_index], x86::rdi); /* fp_reg = info->frame */ - a.mov(x86::rbp, x86::ptr(x86::rsi, 0)); - /* jmp target */ - a.jmp(x86::rdx); + a.mov(x86::rbp, x86::ptr(x86::rsi, offsetof(JitInterpSwitchInfo, frame))); + /* rdx = func_idx, is already set in the func_idx argument of + jit_codegen_interp_jitted_glue */ + /* jmp target, rcx = pc */ + a.jmp(x86::rcx); if (err_handler.err) return false; @@ -6842,26 +7616,25 @@ jit_codegen_init() dump_native(stream, code_size); #endif - a.setOffset(0); + /* Initialize code_block_return_to_interp_from_jitted */ - /* TODO: mask floating-point exception */ - /* TODO: floating-point parameters */ + a.setOffset(0); /* pop info */ a.pop(x86::rsi); /* info->frame = fp_reg */ { - x86::Mem m(x86::rsi, 0); + x86::Mem m(x86::rsi, offsetof(JitInterpSwitchInfo, frame)); a.mov(m, x86::rbp); } - /* info->out.ret.ival[0, 1] = rcx */ + /* info->out.ret.ival[0, 1] = rdx */ { - x86::Mem m(x86::rsi, 8); + x86::Mem m(x86::rsi, offsetof(JitInterpSwitchInfo, out.ret.ival)); a.mov(m, x86::rdx); } /* info->out.ret.fval[0, 1] = xmm0 */ { - x86::Mem m(x86::rsi, 16); + x86::Mem m(x86::rsi, offsetof(JitInterpSwitchInfo, out.ret.fval)); a.movsd(m, x86::xmm0); } @@ -6884,12 +7657,125 @@ jit_codegen_init() goto fail1; bh_memcpy_s(stream, code_size, code_buf, code_size); - code_block_return_to_interp_from_jitted = stream; + code_block_return_to_interp_from_jitted = + jit_globals->return_to_interp_from_jitted = stream; + +#if 0 + dump_native(stream, code_size); +#endif + +#if WASM_ENABLE_LAZY_JIT != 0 + /* Initialize code_block_compile_fast_jit_and_then_call */ + + a.setOffset(0); + + /* Use rbx, r12, r13 to save func_dix, module_inst and module, + as they are callee-save registers */ + + /* Backup func_idx: rbx = rdx = func_idx, note that rdx has + been prepared in the caller: + callbc or code_block_switch_to_jitted_from_interp */ + a.mov(x86::rbx, x86::rdx); + /* r12 = module_inst = exec_env->module_inst */ + { + x86::Mem m(regs_i64[hreg_info->exec_env_hreg_index], + (uint32)offsetof(WASMExecEnv, module_inst)); + a.mov(x86::r12, m); + } + /* rdi = r13 = module_inst->module */ + { + x86::Mem m(x86::r12, (uint32)offsetof(WASMModuleInstance, module)); + a.mov(x86::rdi, m); + a.mov(x86::r13, x86::rdi); + } + /* rsi = rdx = func_idx */ + a.mov(x86::rsi, x86::rdx); + /* Call jit_compiler_compile(module, func_idx) */ + { + Imm imm((uint64)(uintptr_t)jit_compiler_compile); + a.mov(x86::rax, imm); + a.call(x86::rax); + } + + /* Check if failed to compile the jit function */ + { + /* Did jit_compiler_compile return false? */ + Imm imm((uint8)0); + a.cmp(x86::al, imm); + /* If no, jump to `Load compiled func ptr and call it` */ + imm.setValue(INT32_MAX); + a.jne(imm); + + char *stream = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + + /* If yes, call jit_set_exception_with_id to throw exception, + and then set eax to JIT_INTERP_ACTION_THROWN, and jump to + code_block_return_to_interp_from_jitted to return */ + + /* rdi = module_inst */ + a.mov(x86::rdi, x86::r12); + /* rsi = EXCE_FAILED_TO_COMPILE_FAST_JIT_FUNC */ + imm.setValue(EXCE_FAILED_TO_COMPILE_FAST_JIT_FUNC); + a.mov(x86::rsi, imm); + /* Call jit_set_exception_with_id */ + imm.setValue((uint64)(uintptr_t)jit_set_exception_with_id); + a.mov(x86::rax, imm); + a.call(x86::rax); + /* Return to the caller */ + imm.setValue(JIT_INTERP_ACTION_THROWN); + a.mov(x86::eax, imm); + imm.setValue(code_block_return_to_interp_from_jitted); + a.mov(x86::rsi, imm); + a.jmp(x86::rsi); + + /* Patch the offset of jne instruction */ + char *stream_new = (char *)a.code()->sectionById(0)->buffer().data() + + a.code()->sectionById(0)->buffer().size(); + *(int32 *)(stream - 4) = (int32)(stream_new - stream); + } + + /* Load compiled func ptr and call it */ + { + /* rsi = module->import_function_count */ + x86::Mem m1(x86::r13, + (uint32)offsetof(WASMModule, import_function_count)); + a.movzx(x86::rsi, m1); + /* rbx = rbx - module->import_function_count */ + a.sub(x86::rbx, x86::rsi); + /* rax = module->fast_jit_func_ptrs */ + x86::Mem m2(x86::r13, (uint32)offsetof(WASMModule, fast_jit_func_ptrs)); + a.mov(x86::rax, m2); + /* rax = fast_jit_func_ptrs[rbx] */ + x86::Mem m3(x86::rax, x86::rbx, 3, 0); + a.mov(x86::rax, m3); + a.jmp(x86::rax); + } + + if (err_handler.err) + goto fail2; + + code_buf = (char *)code.sectionById(0)->buffer().data(); + code_size = code.sectionById(0)->buffer().size(); + stream = (char *)jit_code_cache_alloc(code_size); + if (!stream) + goto fail2; + + bh_memcpy_s(stream, code_size, code_buf, code_size); + code_block_compile_fast_jit_and_then_call = + jit_globals->compile_fast_jit_and_then_call = stream; + +#if 0 + dump_native(stream, code_size); +#endif +#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */ - jit_globals->return_to_interp_from_jitted = - code_block_return_to_interp_from_jitted; return true; +#if WASM_ENABLE_LAZY_JIT != 0 +fail2: + jit_code_cache_free(code_block_return_to_interp_from_jitted); +#endif fail1: jit_code_cache_free(code_block_switch_to_jitted_from_interp); return false; @@ -6898,8 +7784,11 @@ fail1: void jit_codegen_destroy() { - jit_code_cache_free(code_block_switch_to_jitted_from_interp); +#if WASM_ENABLE_LAZY_JIT != 0 + jit_code_cache_free(code_block_compile_fast_jit_and_then_call); +#endif jit_code_cache_free(code_block_return_to_interp_from_jitted); + jit_code_cache_free(code_block_switch_to_jitted_from_interp); } /* clang-format off */ diff --git a/core/iwasm/fast-jit/fe/jit_emit_function.c b/core/iwasm/fast-jit/fe/jit_emit_function.c index 5b0728749..6724217d3 100644 --- a/core/iwasm/fast-jit/fe/jit_emit_function.c +++ b/core/iwasm/fast-jit/fe/jit_emit_function.c @@ -409,7 +409,7 @@ jit_compile_op_call(JitCompContext *cc, uint32 func_idx, bool tail_call) res = create_first_res_reg(cc, func_type); - GEN_INSN(CALLBC, res, 0, jitted_code); + GEN_INSN(CALLBC, res, 0, jitted_code, NEW_CONST(I32, func_idx)); if (!post_return(cc, func_type, res, true)) { goto fail; @@ -700,7 +700,7 @@ jit_compile_op_call_indirect(JitCompContext *cc, uint32 type_idx, goto fail; } } - GEN_INSN(CALLBC, res, 0, jitted_code); + GEN_INSN(CALLBC, res, 0, jitted_code, func_idx); /* Store res into current frame, so that post_return in block func_return can get the value */ n = cc->jit_frame->sp - cc->jit_frame->lp; diff --git a/core/iwasm/fast-jit/jit_codecache.c b/core/iwasm/fast-jit/jit_codecache.c index 4c899ad9d..66c2d033a 100644 --- a/core/iwasm/fast-jit/jit_codecache.c +++ b/core/iwasm/fast-jit/jit_codecache.c @@ -58,8 +58,7 @@ jit_pass_register_jitted_code(JitCompContext *cc) { uint32 jit_func_idx = cc->cur_wasm_func_idx - cc->cur_wasm_module->import_function_count; - cc->cur_wasm_func->fast_jit_jitted_code = cc->jitted_addr_begin; cc->cur_wasm_module->fast_jit_func_ptrs[jit_func_idx] = - cc->jitted_addr_begin; + cc->cur_wasm_func->fast_jit_jitted_code = cc->jitted_addr_begin; return true; } diff --git a/core/iwasm/fast-jit/jit_codegen.h b/core/iwasm/fast-jit/jit_codegen.h index 666a239a6..735cddab6 100644 --- a/core/iwasm/fast-jit/jit_codegen.h +++ b/core/iwasm/fast-jit/jit_codegen.h @@ -65,6 +65,14 @@ jit_codegen_gen_native(JitCompContext *cc); bool jit_codegen_lower(JitCompContext *cc); +#if WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 +void * +jit_codegen_compile_call_to_llvm_jit(const WASMType *func_type); + +void * +jit_codegen_compile_call_to_fast_jit(const WASMModule *module, uint32 func_idx); +#endif + /** * Dump native code in the given range to assembly. * @@ -75,7 +83,8 @@ void jit_codegen_dump_native(void *begin_addr, void *end_addr); int -jit_codegen_interp_jitted_glue(void *self, JitInterpSwitchInfo *info, void *pc); +jit_codegen_interp_jitted_glue(void *self, JitInterpSwitchInfo *info, + uint32 func_idx, void *pc); #ifdef __cplusplus } diff --git a/core/iwasm/fast-jit/jit_compiler.c b/core/iwasm/fast-jit/jit_compiler.c index c10a40994..67dcb7b51 100644 --- a/core/iwasm/fast-jit/jit_compiler.c +++ b/core/iwasm/fast-jit/jit_compiler.c @@ -10,9 +10,9 @@ #include "../interpreter/wasm.h" typedef struct JitCompilerPass { - /* Name of the pass. */ + /* Name of the pass */ const char *name; - /* The entry of the compiler pass. */ + /* The entry of the compiler pass */ bool (*run)(JitCompContext *cc); } JitCompilerPass; @@ -30,7 +30,7 @@ static JitCompilerPass compiler_passes[] = { #undef REG_PASS }; -/* Number of compiler passes. */ +/* Number of compiler passes */ #define COMPILER_PASS_NUM (sizeof(compiler_passes) / sizeof(compiler_passes[0])) #if WASM_ENABLE_FAST_JIT_DUMP == 0 @@ -43,14 +43,17 @@ static const uint8 compiler_passes_with_dump[] = { }; #endif -/* The exported global data of JIT compiler. */ +/* The exported global data of JIT compiler */ static JitGlobals jit_globals = { #if WASM_ENABLE_FAST_JIT_DUMP == 0 .passes = compiler_passes_without_dump, #else .passes = compiler_passes_with_dump, #endif - .return_to_interp_from_jitted = NULL + .return_to_interp_from_jitted = NULL, +#if WASM_ENABLE_LAZY_JIT != 0 + .compile_fast_jit_and_then_call = NULL, +#endif }; /* clang-format on */ @@ -60,7 +63,7 @@ apply_compiler_passes(JitCompContext *cc) const uint8 *p = jit_globals.passes; for (; *p; p++) { - /* Set the pass NO. */ + /* Set the pass NO */ cc->cur_pass_no = p - jit_globals.passes; bh_assert(*p < COMPILER_PASS_NUM); @@ -120,37 +123,53 @@ jit_compiler_get_pass_name(unsigned i) bool jit_compiler_compile(WASMModule *module, uint32 func_idx) { - JitCompContext *cc; + JitCompContext *cc = NULL; char *last_error; - bool ret = true; + bool ret = false; + uint32 i = func_idx - module->import_function_count; + uint32 j = i % WASM_ORC_JIT_BACKEND_THREAD_NUM; - /* Initialize compilation context. */ - if (!(cc = jit_calloc(sizeof(*cc)))) - return false; + /* Lock to avoid duplicated compilation by other threads */ + os_mutex_lock(&module->fast_jit_thread_locks[j]); + + if (jit_compiler_is_compiled(module, func_idx)) { + /* Function has been compiled */ + os_mutex_unlock(&module->fast_jit_thread_locks[j]); + return true; + } + + /* Initialize the compilation context */ + if (!(cc = jit_calloc(sizeof(*cc)))) { + goto fail; + } if (!jit_cc_init(cc, 64)) { - jit_free(cc); - return false; + goto fail; } cc->cur_wasm_module = module; - cc->cur_wasm_func = - module->functions[func_idx - module->import_function_count]; + cc->cur_wasm_func = module->functions[i]; cc->cur_wasm_func_idx = func_idx; cc->mem_space_unchanged = (!cc->cur_wasm_func->has_op_memory_grow && !cc->cur_wasm_func->has_op_func_call) || (!module->possible_memory_grow); - /* Apply compiler passes. */ + /* Apply compiler passes */ if (!apply_compiler_passes(cc) || jit_get_last_error(cc)) { last_error = jit_get_last_error(cc); os_printf("fast jit compilation failed: %s\n", last_error ? last_error : "unknown error"); - ret = false; + goto fail; } - /* Delete the compilation context. */ - jit_cc_delete(cc); + ret = true; + +fail: + /* Destroy the compilation context */ + if (cc) + jit_cc_delete(cc); + + os_mutex_unlock(&module->fast_jit_thread_locks[j]); return ret; } @@ -169,8 +188,92 @@ jit_compiler_compile_all(WASMModule *module) return true; } -int -jit_interp_switch_to_jitted(void *exec_env, JitInterpSwitchInfo *info, void *pc) +bool +jit_compiler_is_compiled(const WASMModule *module, uint32 func_idx) { - return jit_codegen_interp_jitted_glue(exec_env, info, pc); + uint32 i = func_idx - module->import_function_count; + + bh_assert(func_idx >= module->import_function_count + && func_idx + < module->import_function_count + module->function_count); + +#if WASM_ENABLE_LAZY_JIT == 0 + return module->fast_jit_func_ptrs[i] ? true : false; +#else + return module->fast_jit_func_ptrs[i] + != jit_globals.compile_fast_jit_and_then_call + ? true + : false; +#endif +} + +#if WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 +bool +jit_compiler_set_call_to_llvm_jit(WASMModule *module, uint32 func_idx) +{ + uint32 i = func_idx - module->import_function_count; + uint32 j = i % WASM_ORC_JIT_BACKEND_THREAD_NUM; + WASMType *func_type = module->functions[i]->func_type; + uint32 k = + ((uint32)(uintptr_t)func_type >> 3) % WASM_ORC_JIT_BACKEND_THREAD_NUM; + void *func_ptr = NULL; + + /* Compile code block of call_to_llvm_jit_from_fast_jit of + this kind of function type if it hasn't been compiled */ + if (!(func_ptr = func_type->call_to_llvm_jit_from_fast_jit)) { + os_mutex_lock(&module->fast_jit_thread_locks[k]); + if (!(func_ptr = func_type->call_to_llvm_jit_from_fast_jit)) { + if (!(func_ptr = func_type->call_to_llvm_jit_from_fast_jit = + jit_codegen_compile_call_to_llvm_jit(func_type))) { + os_mutex_unlock(&module->fast_jit_thread_locks[k]); + return false; + } + } + os_mutex_unlock(&module->fast_jit_thread_locks[k]); + } + + /* Switch current fast jit func ptr to the code block */ + os_mutex_lock(&module->fast_jit_thread_locks[j]); + module->fast_jit_func_ptrs[i] = func_ptr; + os_mutex_unlock(&module->fast_jit_thread_locks[j]); + return true; +} + +bool +jit_compiler_set_call_to_fast_jit(WASMModule *module, uint32 func_idx) +{ + void *func_ptr = NULL; + + func_ptr = jit_codegen_compile_call_to_fast_jit(module, func_idx); + if (func_ptr) { + jit_compiler_set_llvm_jit_func_ptr(module, func_idx, func_ptr); + } + + return func_ptr ? true : false; +} + +void +jit_compiler_set_llvm_jit_func_ptr(WASMModule *module, uint32 func_idx, + void *func_ptr) +{ + WASMModuleInstance *instance; + uint32 i = func_idx - module->import_function_count; + + module->functions[i]->llvm_jit_func_ptr = module->func_ptrs[i] = func_ptr; + + os_mutex_lock(&module->instance_list_lock); + instance = module->instance_list; + while (instance) { + instance->func_ptrs[func_idx] = func_ptr; + instance = instance->e->next; + } + os_mutex_unlock(&module->instance_list_lock); +} +#endif /* end of WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 */ + +int +jit_interp_switch_to_jitted(void *exec_env, JitInterpSwitchInfo *info, + uint32 func_idx, void *pc) +{ + return jit_codegen_interp_jitted_glue(exec_env, info, func_idx, pc); } diff --git a/core/iwasm/fast-jit/jit_compiler.h b/core/iwasm/fast-jit/jit_compiler.h index 602494db9..9a49cffdd 100644 --- a/core/iwasm/fast-jit/jit_compiler.h +++ b/core/iwasm/fast-jit/jit_compiler.h @@ -18,6 +18,9 @@ typedef struct JitGlobals { /* Compiler pass sequence, the last element must be 0 */ const uint8 *passes; char *return_to_interp_from_jitted; +#if WASM_ENABLE_LAZY_JIT != 0 + char *compile_fast_jit_and_then_call; +#endif } JitGlobals; /** @@ -87,8 +90,24 @@ jit_compiler_compile(WASMModule *module, uint32 func_idx); bool jit_compiler_compile_all(WASMModule *module); +bool +jit_compiler_is_compiled(const WASMModule *module, uint32 func_idx); + +#if WASM_ENABLE_LAZY_JIT != 0 && WASM_ENABLE_JIT != 0 +bool +jit_compiler_set_call_to_llvm_jit(WASMModule *module, uint32 func_idx); + +bool +jit_compiler_set_call_to_fast_jit(WASMModule *module, uint32 func_idx); + +void +jit_compiler_set_llvm_jit_func_ptr(WASMModule *module, uint32 func_idx, + void *func_ptr); +#endif + int -jit_interp_switch_to_jitted(void *self, JitInterpSwitchInfo *info, void *pc); +jit_interp_switch_to_jitted(void *self, JitInterpSwitchInfo *info, + uint32 func_idx, void *pc); /* * Pass declarations: diff --git a/core/iwasm/fast-jit/jit_frontend.c b/core/iwasm/fast-jit/jit_frontend.c index 50cc3717f..e9c8925e8 100644 --- a/core/iwasm/fast-jit/jit_frontend.c +++ b/core/iwasm/fast-jit/jit_frontend.c @@ -2263,6 +2263,12 @@ jit_frontend_translate_func(JitCompContext *cc) return basic_block_entry; } +uint32 +jit_frontend_get_jitted_return_addr_offset() +{ + return (uint32)offsetof(WASMInterpFrame, jitted_return_addr); +} + #if 0 #if WASM_ENABLE_THREAD_MGR != 0 bool diff --git a/core/iwasm/fast-jit/jit_frontend.h b/core/iwasm/fast-jit/jit_frontend.h index d706c90b8..fce8ecfd2 100644 --- a/core/iwasm/fast-jit/jit_frontend.h +++ b/core/iwasm/fast-jit/jit_frontend.h @@ -13,6 +13,10 @@ #include "../aot/aot_runtime.h" #endif +#ifdef __cplusplus +extern "C" { +#endif + #if WASM_ENABLE_AOT == 0 typedef enum IntCond { INT_EQZ = 0, @@ -143,6 +147,9 @@ jit_frontend_translate_func(JitCompContext *cc); bool jit_frontend_lower(JitCompContext *cc); +uint32 +jit_frontend_get_jitted_return_addr_offset(); + uint32 jit_frontend_get_global_data_offset(const WASMModule *module, uint32 global_idx); @@ -483,4 +490,8 @@ set_local_f64(JitFrame *frame, int n, JitReg val) #define PUSH_FUNCREF(v) PUSH(v, VALUE_TYPE_FUNCREF) #define PUSH_EXTERNREF(v) PUSH(v, VALUE_TYPE_EXTERNREF) +#ifdef __cplusplus +} +#endif + #endif diff --git a/core/iwasm/fast-jit/jit_ir.def b/core/iwasm/fast-jit/jit_ir.def index d16843328..8a4396da5 100644 --- a/core/iwasm/fast-jit/jit_ir.def +++ b/core/iwasm/fast-jit/jit_ir.def @@ -196,7 +196,7 @@ INSN(LOOKUPSWITCH, LookupSwitch, 1, 0) /* Call and return instructions */ INSN(CALLNATIVE, VReg, 2, 1) -INSN(CALLBC, Reg, 3, 2) +INSN(CALLBC, Reg, 4, 2) INSN(RETURNBC, Reg, 3, 0) INSN(RETURN, Reg, 1, 0) diff --git a/core/iwasm/interpreter/wasm.h b/core/iwasm/interpreter/wasm.h index 9c6a2ce3d..d15f338fa 100644 --- a/core/iwasm/interpreter/wasm.h +++ b/core/iwasm/interpreter/wasm.h @@ -124,6 +124,12 @@ typedef struct WASMType { uint16 param_cell_num; uint16 ret_cell_num; uint16 ref_count; +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 + /* Code block to call llvm jit functions of this + kind of function type from fast jit jitted code */ + void *call_to_llvm_jit_from_fast_jit; +#endif /* types of params and results */ uint8 types[1]; } WASMType; @@ -256,13 +262,6 @@ struct WASMFunction { uint32 const_cell_num; #endif -#if WASM_ENABLE_FAST_JIT != 0 - void *fast_jit_jitted_code; -#endif -#if WASM_ENABLE_JIT != 0 - void *llvm_jit_func_ptr; -#endif - #if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 \ || WASM_ENABLE_WAMR_COMPILER != 0 /* Whether function has opcode memory.grow */ @@ -278,6 +277,13 @@ struct WASMFunction { /* Whether function has opcode set_global_aux_stack */ bool has_op_set_global_aux_stack; #endif + +#if WASM_ENABLE_FAST_JIT != 0 + void *fast_jit_jitted_code; +#if WASM_ENABLE_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0 + void *llvm_jit_func_ptr; +#endif +#endif }; struct WASMGlobal { @@ -378,18 +384,22 @@ typedef struct WASMCustomSection { } WASMCustomSection; #endif -#if WASM_ENABLE_JIT != 0 +#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 struct AOTCompData; struct AOTCompContext; /* Orc JIT thread arguments */ typedef struct OrcJitThreadArg { +#if WASM_ENABLE_JIT != 0 struct AOTCompContext *comp_ctx; +#endif struct WASMModule *module; uint32 group_idx; } OrcJitThreadArg; #endif +struct WASMModuleInstance; + struct WASMModule { /* Module type, for module loaded from WASM bytecode binary, this field is Wasm_Module_Bytecode; @@ -496,18 +506,22 @@ struct WASMModule { uint64 load_size; #endif -#if WASM_ENABLE_DEBUG_INTERP != 0 +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \ + && WASM_ENABLE_LAZY_JIT != 0) /** - * Count how many instances reference this module. When source - * debugging feature enabled, the debugger may modify the code - * section of the module, so we need to report a warning if user - * create several instances based on the same module + * List of instances referred to this module. When source debugging + * feature is enabled, the debugger may modify the code section of + * the module, so we need to report a warning if user create several + * instances based on the same module. Sub instances created by + * lib-pthread or spawn API won't be added into the list. * - * Sub_instances created by lib-pthread or spawn API will not - * influence or check the ref count + * Also add the instance to the list for Fast JIT to LLVM JIT + * tier-up, since we need to lazily update the LLVM func pointers + * in the instance. */ - uint32 ref_count; - korp_mutex ref_count_lock; + struct WASMModuleInstance *instance_list; + korp_mutex instance_list_lock; #endif #if WASM_ENABLE_CUSTOM_NAME_SECTION != 0 @@ -522,6 +536,9 @@ struct WASMModule { #if WASM_ENABLE_FAST_JIT != 0 /* func pointers of Fast JITed (un-imported) functions */ void **fast_jit_func_ptrs; + /* locks for Fast JIT lazy compilation */ + korp_mutex fast_jit_thread_locks[WASM_ORC_JIT_BACKEND_THREAD_NUM]; + bool fast_jit_thread_locks_inited[WASM_ORC_JIT_BACKEND_THREAD_NUM]; #endif #if WASM_ENABLE_JIT != 0 @@ -531,9 +548,27 @@ struct WASMModule { void **func_ptrs; /* whether the func pointers are compiled */ bool *func_ptrs_compiled; - bool orcjit_stop_compiling; +#endif + +#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 + /* backend compilation threads */ korp_tid orcjit_threads[WASM_ORC_JIT_BACKEND_THREAD_NUM]; + /* backend thread arguments */ OrcJitThreadArg orcjit_thread_args[WASM_ORC_JIT_BACKEND_THREAD_NUM]; + /* whether to stop the compilation of backend threads */ + bool orcjit_stop_compiling; +#endif + +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 + /* wait lock/cond for the synchronization of + the llvm jit initialization */ + korp_mutex tierup_wait_lock; + korp_cond tierup_wait_cond; + bool tierup_wait_lock_inited; + korp_tid llvm_jit_init_thread; + /* whether the llvm jit is initialized */ + bool llvm_jit_inited; #endif }; diff --git a/core/iwasm/interpreter/wasm_interp_classic.c b/core/iwasm/interpreter/wasm_interp_classic.c index 4726636c9..33fb9db14 100644 --- a/core/iwasm/interpreter/wasm_interp_classic.c +++ b/core/iwasm/interpreter/wasm_interp_classic.c @@ -3886,28 +3886,48 @@ wasm_interp_call_func_bytecode(WASMModuleInstance *module, #if WASM_ENABLE_FAST_JIT != 0 static void -fast_jit_call_func_bytecode(WASMExecEnv *exec_env, +fast_jit_call_func_bytecode(WASMModuleInstance *module_inst, + WASMExecEnv *exec_env, WASMFunctionInstance *function, WASMInterpFrame *frame) { JitGlobals *jit_globals = jit_compiler_get_jit_globals(); JitInterpSwitchInfo info; + WASMModule *module = module_inst->module; WASMType *func_type = function->u.func->func_type; uint8 type = func_type->result_count ? func_type->types[func_type->param_count] : VALUE_TYPE_VOID; + uint32 func_idx = (uint32)(function - module_inst->e->functions); + uint32 func_idx_non_import = func_idx - module->import_function_count; + int32 action; #if WASM_ENABLE_REF_TYPES != 0 if (type == VALUE_TYPE_EXTERNREF || type == VALUE_TYPE_FUNCREF) type = VALUE_TYPE_I32; #endif +#if WASM_ENABLE_LAZY_JIT != 0 + if (!jit_compiler_compile(module, func_idx)) { + wasm_set_exception(module_inst, "failed to compile fast jit function"); + return; + } +#endif + bh_assert(jit_compiler_is_compiled(module, func_idx)); + + /* Switch to jitted code to call the jit function */ info.out.ret.last_return_type = type; info.frame = frame; frame->jitted_return_addr = (uint8 *)jit_globals->return_to_interp_from_jitted; - jit_interp_switch_to_jitted(exec_env, &info, - function->u.func->fast_jit_jitted_code); + action = jit_interp_switch_to_jitted( + exec_env, &info, func_idx, + module_inst->fast_jit_func_ptrs[func_idx_non_import]); + bh_assert(action == JIT_INTERP_ACTION_NORMAL + || (action == JIT_INTERP_ACTION_THROWN + && wasm_runtime_get_exception(exec_env->module_inst))); + + /* Get the return values form info.out.ret */ if (func_type->result_count) { switch (type) { case VALUE_TYPE_I32: @@ -3931,8 +3951,10 @@ fast_jit_call_func_bytecode(WASMExecEnv *exec_env, break; } } + (void)action; + (void)func_idx; } -#endif +#endif /* end of WASM_ENABLE_FAST_JIT != 0 */ #if WASM_ENABLE_JIT != 0 static bool @@ -3962,11 +3984,14 @@ llvm_jit_call_func_bytecode(WASMModuleInstance *module_inst, WASMType *func_type = function->u.func->func_type; uint32 result_count = func_type->result_count; uint32 ext_ret_count = result_count > 1 ? result_count - 1 : 0; + uint32 func_idx = (uint32)(function - module_inst->e->functions); bool ret; #if (WASM_ENABLE_DUMP_CALL_STACK != 0) || (WASM_ENABLE_PERF_PROFILING != 0) if (!llvm_jit_alloc_frame(exec_env, function - module_inst->e->functions)) { - wasm_set_exception(module_inst, "wasm operand stack overflow"); + /* wasm operand stack overflow has been thrown, + no need to throw again */ + return false; } #endif @@ -4007,8 +4032,8 @@ llvm_jit_call_func_bytecode(WASMModuleInstance *module_inst, } ret = wasm_runtime_invoke_native( - exec_env, function->u.func->llvm_jit_func_ptr, func_type, NULL, - NULL, argv1, argc, argv); + exec_env, module_inst->func_ptrs[func_idx], func_type, NULL, NULL, + argv1, argc, argv); if (!ret || wasm_get_exception(module_inst)) { if (clear_wasi_proc_exit_exception(module_inst)) @@ -4058,8 +4083,8 @@ llvm_jit_call_func_bytecode(WASMModuleInstance *module_inst, } else { ret = wasm_runtime_invoke_native( - exec_env, function->u.func->llvm_jit_func_ptr, func_type, NULL, - NULL, argv, argc, argv); + exec_env, module_inst->func_ptrs[func_idx], func_type, NULL, NULL, + argv, argc, argv); if (clear_wasi_proc_exit_exception(module_inst)) ret = true; @@ -4067,7 +4092,7 @@ llvm_jit_call_func_bytecode(WASMModuleInstance *module_inst, return ret && !wasm_get_exception(module_inst) ? true : false; } } -#endif +#endif /* end of WASM_ENABLE_JIT != 0 */ void wasm_interp_call_wasm(WASMModuleInstance *module_inst, WASMExecEnv *exec_env, @@ -4134,18 +4159,63 @@ wasm_interp_call_wasm(WASMModuleInstance *module_inst, WASMExecEnv *exec_env, } } else { -#if WASM_ENABLE_JIT != 0 +#if WASM_ENABLE_LAZY_JIT != 0 + + /* Fast JIT to LLVM JIT tier-up is enabled */ +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 + /* Fast JIT and LLVM JIT are both enabled, call llvm jit function + if it is compiled, else call fast jit function */ + uint32 func_idx = (uint32)(function - module_inst->e->functions); + if (module_inst->module->func_ptrs_compiled + [func_idx - module_inst->module->import_function_count]) { + llvm_jit_call_func_bytecode(module_inst, exec_env, function, argc, + argv); + /* For llvm jit, the results have been stored in argv, + no need to copy them from stack frame again */ + copy_argv_from_frame = false; + } + else { + fast_jit_call_func_bytecode(module_inst, exec_env, function, frame); + } +#elif WASM_ENABLE_JIT != 0 + /* Only LLVM JIT is enabled */ llvm_jit_call_func_bytecode(module_inst, exec_env, function, argc, argv); /* For llvm jit, the results have been stored in argv, no need to copy them from stack frame again */ copy_argv_from_frame = false; #elif WASM_ENABLE_FAST_JIT != 0 - fast_jit_call_func_bytecode(exec_env, function, frame); + /* Only Fast JIT is enabled */ + fast_jit_call_func_bytecode(module_inst, exec_env, function, frame); #else + /* Both Fast JIT and LLVM JIT are disabled */ wasm_interp_call_func_bytecode(module_inst, exec_env, function, frame); #endif + +#else /* else of WASM_ENABLE_LAZY_JIT != 0 */ + + /* Fast JIT to LLVM JIT tier-up is enabled */ +#if WASM_ENABLE_JIT != 0 + /* LLVM JIT is enabled */ + llvm_jit_call_func_bytecode(module_inst, exec_env, function, argc, + argv); + /* For llvm jit, the results have been stored in argv, + no need to copy them from stack frame again */ + copy_argv_from_frame = false; +#elif WASM_ENABLE_FAST_JIT != 0 + /* Fast JIT is enabled */ + fast_jit_call_func_bytecode(module_inst, exec_env, function, frame); +#else + /* Both Fast JIT and LLVM JIT are disabled */ + wasm_interp_call_func_bytecode(module_inst, exec_env, function, frame); +#endif + +#endif /* end of WASM_ENABLE_LAZY_JIT != 0 */ + (void)wasm_interp_call_func_bytecode; +#if WASM_ENABLE_FAST_JIT != 0 + (void)fast_jit_call_func_bytecode; +#endif } /* Output the return value to the caller */ diff --git a/core/iwasm/interpreter/wasm_loader.c b/core/iwasm/interpreter/wasm_loader.c index c706b6c7d..978c66cd3 100644 --- a/core/iwasm/interpreter/wasm_loader.c +++ b/core/iwasm/interpreter/wasm_loader.c @@ -428,6 +428,12 @@ destroy_wasm_type(WASMType *type) return; } +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 + if (type->call_to_llvm_jit_from_fast_jit) + jit_code_cache_free(type->call_to_llvm_jit_from_fast_jit); +#endif + wasm_runtime_free(type); } @@ -2925,93 +2931,78 @@ calculate_global_data_offset(WASMModule *module) module->global_data_size = data_offset; } -#if WASM_ENABLE_JIT != 0 -static void * -orcjit_thread_callback(void *arg) -{ - LLVMOrcJITTargetAddress func_addr = 0; - OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg; - AOTCompContext *comp_ctx = thread_arg->comp_ctx; - WASMModule *module = thread_arg->module; - uint32 group_idx = thread_arg->group_idx; - uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM; - uint32 func_count = module->function_count; - uint32 i, j; - typedef void (*F)(void); - LLVMErrorRef error; - char func_name[48]; - union { - F f; - void *v; - } u; - - /* Compile jit functions of this group */ - for (i = group_idx; i < func_count; - i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) { - snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i, - "_wrapper"); - LOG_DEBUG("compile func %s", func_name); - error = - LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name); - if (error != LLVMErrorSuccess) { - char *err_msg = LLVMGetErrorMessage(error); - os_printf("failed to compile orc jit function: %s", err_msg); - LLVMDisposeErrorMessage(err_msg); - continue; - } - - /* Call the jit wrapper function to trigger its compilation, so as - to compile the actual jit functions, since we add the latter to - function list in the PartitionFunction callback */ - u.v = (void *)func_addr; - u.f(); - - for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) { - if (i + j * group_stride < func_count) - module->func_ptrs_compiled[i + j * group_stride] = true; - } - - if (module->orcjit_stop_compiling) { - break; - } - } - - return NULL; -} - -static void -orcjit_stop_compile_threads(WASMModule *module) -{ - uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args) - / sizeof(OrcJitThreadArg)); - - module->orcjit_stop_compiling = true; - for (i = 0; i < thread_num; i++) { - if (module->orcjit_threads[i]) - os_thread_join(module->orcjit_threads[i], NULL); - } -} - +#if WASM_ENABLE_FAST_JIT != 0 static bool -compile_llvm_jit_functions(WASMModule *module, char *error_buf, - uint32 error_buf_size) +init_fast_jit_functions(WASMModule *module, char *error_buf, + uint32 error_buf_size) +{ +#if WASM_ENABLE_LAZY_JIT != 0 + JitGlobals *jit_globals = jit_compiler_get_jit_globals(); +#endif + uint32 i; + + if (!module->function_count) + return true; + + if (!(module->fast_jit_func_ptrs = + loader_malloc(sizeof(void *) * module->function_count, error_buf, + error_buf_size))) { + return false; + } + +#if WASM_ENABLE_LAZY_JIT != 0 + for (i = 0; i < module->function_count; i++) { + module->fast_jit_func_ptrs[i] = + jit_globals->compile_fast_jit_and_then_call; + } +#endif + + for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) { + if (os_mutex_init(&module->fast_jit_thread_locks[i]) != 0) { + set_error_buf(error_buf, error_buf_size, + "init fast jit thread lock failed"); + return false; + } + module->fast_jit_thread_locks_inited[i] = true; + } + + return true; +} +#endif /* end of WASM_ENABLE_FAST_JIT != 0 */ + +#if WASM_ENABLE_JIT != 0 +static bool +init_llvm_jit_functions_stage1(WASMModule *module, char *error_buf, + uint32 error_buf_size) { AOTCompOption option = { 0 }; char *aot_last_error; uint64 size; - uint32 thread_num, i; - if (module->function_count > 0) { - size = sizeof(void *) * (uint64)module->function_count - + sizeof(bool) * (uint64)module->function_count; - if (!(module->func_ptrs = - loader_malloc(size, error_buf, error_buf_size))) { - return false; - } - module->func_ptrs_compiled = - (bool *)((uint8 *)module->func_ptrs - + sizeof(void *) * module->function_count); + if (module->function_count == 0) + return true; + +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LLVM_JIT != 0 + if (os_mutex_init(&module->tierup_wait_lock) != 0) { + set_error_buf(error_buf, error_buf_size, "init jit tierup lock failed"); + return false; } + if (os_cond_init(&module->tierup_wait_cond) != 0) { + set_error_buf(error_buf, error_buf_size, "init jit tierup cond failed"); + os_mutex_destroy(&module->tierup_wait_lock); + return false; + } + module->tierup_wait_lock_inited = true; +#endif + + size = sizeof(void *) * (uint64)module->function_count + + sizeof(bool) * (uint64)module->function_count; + if (!(module->func_ptrs = loader_malloc(size, error_buf, error_buf_size))) { + return false; + } + module->func_ptrs_compiled = + (bool *)((uint8 *)module->func_ptrs + + sizeof(void *) * module->function_count); module->comp_data = aot_create_comp_data(module); if (!module->comp_data) { @@ -3052,6 +3043,19 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf, return false; } + return true; +} + +static bool +init_llvm_jit_functions_stage2(WASMModule *module, char *error_buf, + uint32 error_buf_size) +{ + char *aot_last_error; + uint32 i; + + if (module->function_count == 0) + return true; + if (!aot_compile_wasm(module->comp_ctx)) { aot_last_error = aot_get_last_error(); bh_assert(aot_last_error != NULL); @@ -3059,7 +3063,7 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf, return false; } - bh_print_time("Begin to lookup jit functions"); + bh_print_time("Begin to lookup llvm jit functions"); for (i = 0; i < module->function_count; i++) { LLVMOrcJITTargetAddress func_addr = 0; @@ -3084,17 +3088,206 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf, * loading/storing at the same time. */ module->func_ptrs[i] = (void *)func_addr; - module->functions[i]->llvm_jit_func_ptr = (void *)func_addr; } + bh_print_time("End lookup llvm jit functions"); + + return true; +} +#endif /* end of WASM_ENABLE_JIT != 0 */ + +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 +static void * +init_llvm_jit_functions_stage2_callback(void *arg) +{ + WASMModule *module = (WASMModule *)arg; + char error_buf[128]; + uint32 error_buf_size = (uint32)sizeof(error_buf); + + if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) { + module->orcjit_stop_compiling = true; + return NULL; + } + + os_mutex_lock(&module->tierup_wait_lock); + module->llvm_jit_inited = true; + os_cond_broadcast(&module->tierup_wait_cond); + os_mutex_unlock(&module->tierup_wait_lock); + + return NULL; +} +#endif + +#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 +/* The callback function to compile jit functions */ +static void * +orcjit_thread_callback(void *arg) +{ + OrcJitThreadArg *thread_arg = (OrcJitThreadArg *)arg; +#if WASM_ENABLE_JIT != 0 + AOTCompContext *comp_ctx = thread_arg->comp_ctx; +#endif + WASMModule *module = thread_arg->module; + uint32 group_idx = thread_arg->group_idx; + uint32 group_stride = WASM_ORC_JIT_BACKEND_THREAD_NUM; + uint32 func_count = module->function_count; + uint32 i; + +#if WASM_ENABLE_FAST_JIT != 0 + /* Compile fast jit funcitons of this group */ + for (i = group_idx; i < func_count; i += group_stride) { + if (!jit_compiler_compile(module, i + module->import_function_count)) { + os_printf("failed to compile fast jit function %u\n", i); + break; + } + + if (module->orcjit_stop_compiling) { + return NULL; + } + } +#endif + +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 + /* For JIT tier-up, set each llvm jit func to call_to_fast_jit */ + for (i = group_idx; i < func_count; + i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) { + uint32 j; + + for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) { + if (i + j * group_stride < func_count) { + if (!jit_compiler_set_call_to_fast_jit( + module, + i + j * group_stride + module->import_function_count)) { + os_printf( + "failed to compile call_to_fast_jit for func %u\n", + i + j * group_stride + module->import_function_count); + module->orcjit_stop_compiling = true; + return NULL; + } + } + if (module->orcjit_stop_compiling) { + return NULL; + } + } + } + + /* Wait until init_llvm_jit_functions_stage2 finishes */ + os_mutex_lock(&module->tierup_wait_lock); + while (!module->llvm_jit_inited) { + os_cond_reltimedwait(&module->tierup_wait_cond, + &module->tierup_wait_lock, 10); + if (module->orcjit_stop_compiling) { + /* init_llvm_jit_functions_stage2 failed */ + os_mutex_unlock(&module->tierup_wait_lock); + return NULL; + } + } + os_mutex_unlock(&module->tierup_wait_lock); +#endif + +#if WASM_ENABLE_JIT != 0 + /* Compile llvm jit functions of this group */ + for (i = group_idx; i < func_count; + i += group_stride * WASM_ORC_JIT_COMPILE_THREAD_NUM) { + LLVMOrcJITTargetAddress func_addr = 0; + LLVMErrorRef error; + char func_name[48]; + typedef void (*F)(void); + union { + F f; + void *v; + } u; + uint32 j; + + snprintf(func_name, sizeof(func_name), "%s%d%s", AOT_FUNC_PREFIX, i, + "_wrapper"); + LOG_DEBUG("compile llvm jit func %s", func_name); + error = + LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, func_name); + if (error != LLVMErrorSuccess) { + char *err_msg = LLVMGetErrorMessage(error); + os_printf("failed to compile llvm jit function %u: %s", i, err_msg); + LLVMDisposeErrorMessage(err_msg); + break; + } + + /* Call the jit wrapper function to trigger its compilation, so as + to compile the actual jit functions, since we add the latter to + function list in the PartitionFunction callback */ + u.v = (void *)func_addr; + u.f(); + + for (j = 0; j < WASM_ORC_JIT_COMPILE_THREAD_NUM; j++) { + if (i + j * group_stride < func_count) { + module->func_ptrs_compiled[i + j * group_stride] = true; +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0 + snprintf(func_name, sizeof(func_name), "%s%d", AOT_FUNC_PREFIX, + i + j * group_stride); + error = LLVMOrcLLLazyJITLookup(comp_ctx->orc_jit, &func_addr, + func_name); + if (error != LLVMErrorSuccess) { + char *err_msg = LLVMGetErrorMessage(error); + os_printf("failed to compile llvm jit function %u: %s", i, + err_msg); + LLVMDisposeErrorMessage(err_msg); + /* Ignore current llvm jit func, as its func ptr is + previous set to call_to_fast_jit, which also works */ + continue; + } + + jit_compiler_set_llvm_jit_func_ptr( + module, + i + j * group_stride + module->import_function_count, + (void *)func_addr); + + /* Try to switch to call this llvm jit funtion instead of + fast jit function from fast jit jitted code */ + jit_compiler_set_call_to_llvm_jit( + module, + i + j * group_stride + module->import_function_count); +#endif + } + } + + if (module->orcjit_stop_compiling) { + break; + } + } +#endif + + return NULL; +} + +static void +orcjit_stop_compile_threads(WASMModule *module) +{ + uint32 i, thread_num = (uint32)(sizeof(module->orcjit_thread_args) + / sizeof(OrcJitThreadArg)); + + module->orcjit_stop_compiling = true; + for (i = 0; i < thread_num; i++) { + if (module->orcjit_threads[i]) + os_thread_join(module->orcjit_threads[i], NULL); + } +} + +static bool +compile_jit_functions(WASMModule *module, char *error_buf, + uint32 error_buf_size) +{ + uint32 thread_num = + (uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg)); + uint32 i, j; + bh_print_time("Begin to compile jit functions"); - thread_num = - (uint32)(sizeof(module->orcjit_thread_args) / sizeof(OrcJitThreadArg)); - /* Create threads to compile the jit functions */ - for (i = 0; i < thread_num; i++) { + for (i = 0; i < thread_num && i < module->function_count; i++) { +#if WASM_ENABLE_JIT != 0 module->orcjit_thread_args[i].comp_ctx = module->comp_ctx; +#endif module->orcjit_thread_args[i].module = module; module->orcjit_thread_args[i].group_idx = i; @@ -3102,8 +3295,6 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf, (void *)&module->orcjit_thread_args[i], APP_THREAD_STACK_SIZE_DEFAULT) != 0) { - uint32 j; - set_error_buf(error_buf, error_buf_size, "create orcjit compile thread failed"); /* Terminate the threads created */ @@ -3118,15 +3309,39 @@ compile_llvm_jit_functions(WASMModule *module, char *error_buf, #if WASM_ENABLE_LAZY_JIT == 0 /* Wait until all jit functions are compiled for eager mode */ for (i = 0; i < thread_num; i++) { - os_thread_join(module->orcjit_threads[i], NULL); + if (module->orcjit_threads[i]) + os_thread_join(module->orcjit_threads[i], NULL); + } + +#if WASM_ENABLE_FAST_JIT != 0 + /* Ensure all the fast-jit functions are compiled */ + for (i = 0; i < module->function_count; i++) { + if (!jit_compiler_is_compiled(module, + i + module->import_function_count)) { + set_error_buf(error_buf, error_buf_size, + "failed to compile fast jit function"); + return false; + } } #endif +#if WASM_ENABLE_JIT != 0 + /* Ensure all the llvm-jit functions are compiled */ + for (i = 0; i < module->function_count; i++) { + if (!module->func_ptrs_compiled[i]) { + set_error_buf(error_buf, error_buf_size, + "failed to compile llvm jit function"); + return false; + } + } +#endif +#endif /* end of WASM_ENABLE_LAZY_JIT == 0 */ + bh_print_time("End compile jit functions"); return true; } -#endif /* end of WASM_ENABLE_JIT != 0 */ +#endif /* end of WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 */ static bool wasm_loader_prepare_bytecode(WASMModule *module, WASMFunction *func, @@ -3538,23 +3753,41 @@ load_from_sections(WASMModule *module, WASMSection *sections, calculate_global_data_offset(module); #if WASM_ENABLE_FAST_JIT != 0 - if (module->function_count - && !(module->fast_jit_func_ptrs = - loader_malloc(sizeof(void *) * module->function_count, - error_buf, error_buf_size))) { - return false; - } - if (!jit_compiler_compile_all(module)) { - set_error_buf(error_buf, error_buf_size, "fast jit compilation failed"); + if (!init_fast_jit_functions(module, error_buf, error_buf_size)) { return false; } #endif #if WASM_ENABLE_JIT != 0 - if (!compile_llvm_jit_functions(module, error_buf, error_buf_size)) { + if (!init_llvm_jit_functions_stage1(module, error_buf, error_buf_size)) { return false; } -#endif /* end of WASM_ENABLE_JIT != 0 */ +#if !(WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0) + if (!init_llvm_jit_functions_stage2(module, error_buf, error_buf_size)) { + return false; + } +#else + /* Run aot_compile_wasm in a backend thread, so as not to block the main + thread fast jit execution, since applying llvm optimizations in + aot_compile_wasm may cost a lot of time. + Create thread with enough native stack to apply llvm optimizations */ + if (os_thread_create(&module->llvm_jit_init_thread, + init_llvm_jit_functions_stage2_callback, + (void *)module, APP_THREAD_STACK_SIZE_DEFAULT * 8) + != 0) { + set_error_buf(error_buf, error_buf_size, + "create orcjit compile thread failed"); + return false; + } +#endif +#endif + +#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 + /* Create threads to compile the jit functions */ + if (!compile_jit_functions(module, error_buf, error_buf_size)) { + return false; + } +#endif #if WASM_ENABLE_MEMORY_TRACING != 0 wasm_runtime_dump_module_mem_consumption((WASMModuleCommon *)module); @@ -3567,9 +3800,7 @@ create_module(char *error_buf, uint32 error_buf_size) { WASMModule *module = loader_malloc(sizeof(WASMModule), error_buf, error_buf_size); -#if WASM_ENABLE_FAST_INTERP == 0 bh_list_status ret; -#endif if (!module) { return NULL; @@ -3584,19 +3815,31 @@ create_module(char *error_buf, uint32 error_buf_size) module->br_table_cache_list = &module->br_table_cache_list_head; ret = bh_list_init(module->br_table_cache_list); bh_assert(ret == BH_LIST_SUCCESS); - (void)ret; #endif #if WASM_ENABLE_MULTI_MODULE != 0 module->import_module_list = &module->import_module_list_head; + ret = bh_list_init(module->import_module_list); + bh_assert(ret == BH_LIST_SUCCESS); #endif + #if WASM_ENABLE_DEBUG_INTERP != 0 - bh_list_init(&module->fast_opcode_list); - if (os_mutex_init(&module->ref_count_lock) != 0) { + ret = bh_list_init(&module->fast_opcode_list); + bh_assert(ret == BH_LIST_SUCCESS); +#endif + +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \ + && WASM_ENABLE_LAZY_JIT != 0) + if (os_mutex_init(&module->instance_list_lock) != 0) { + set_error_buf(error_buf, error_buf_size, + "init instance list lock failed"); wasm_runtime_free(module); return NULL; } #endif + + (void)ret; return module; } @@ -3964,10 +4207,19 @@ wasm_loader_unload(WASMModule *module) if (!module) return; -#if WASM_ENABLE_JIT != 0 - /* Stop LLVM JIT compilation firstly to avoid accessing +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0 + module->orcjit_stop_compiling = true; + if (module->llvm_jit_init_thread) + os_thread_join(module->llvm_jit_init_thread, NULL); +#endif + +#if WASM_ENABLE_FAST_JIT != 0 || WASM_ENABLE_JIT != 0 + /* Stop Fast/LLVM JIT compilation firstly to avoid accessing module internal data after they were freed */ orcjit_stop_compile_threads(module); +#endif + +#if WASM_ENABLE_JIT != 0 if (module->func_ptrs) wasm_runtime_free(module->func_ptrs); if (module->comp_ctx) @@ -3976,6 +4228,13 @@ wasm_loader_unload(WASMModule *module) aot_destroy_comp_data(module->comp_data); #endif +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT && WASM_ENABLE_LAZY_JIT != 0 + if (module->tierup_wait_lock_inited) { + os_mutex_destroy(&module->tierup_wait_lock); + os_cond_destroy(&module->tierup_wait_cond); + } +#endif + if (module->types) { for (i = 0; i < module->type_count; i++) { if (module->types[i]) @@ -3997,6 +4256,18 @@ wasm_loader_unload(WASMModule *module) wasm_runtime_free(module->functions[i]->code_compiled); if (module->functions[i]->consts) wasm_runtime_free(module->functions[i]->consts); +#endif +#if WASM_ENABLE_FAST_JIT != 0 + if (module->functions[i]->fast_jit_jitted_code) { + jit_code_cache_free( + module->functions[i]->fast_jit_jitted_code); + } +#if WASM_ENABLE_JIT != 0 && WASM_ENABLE_LAZY_JIT != 0 + if (module->functions[i]->llvm_jit_func_ptr) { + jit_code_cache_free( + module->functions[i]->llvm_jit_func_ptr); + } +#endif #endif wasm_runtime_free(module->functions[i]); } @@ -4084,7 +4355,12 @@ wasm_loader_unload(WASMModule *module) wasm_runtime_free(fast_opcode); fast_opcode = next; } - os_mutex_destroy(&module->ref_count_lock); +#endif + +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \ + && WASM_ENABLE_LAZY_JIT != 0) + os_mutex_destroy(&module->instance_list_lock); #endif #if WASM_ENABLE_LOAD_CUSTOM_SECTION != 0 @@ -4093,12 +4369,14 @@ wasm_loader_unload(WASMModule *module) #if WASM_ENABLE_FAST_JIT != 0 if (module->fast_jit_func_ptrs) { - for (i = 0; i < module->function_count; i++) { - if (module->fast_jit_func_ptrs[i]) - jit_code_cache_free(module->fast_jit_func_ptrs[i]); - } wasm_runtime_free(module->fast_jit_func_ptrs); } + + for (i = 0; i < WASM_ORC_JIT_BACKEND_THREAD_NUM; i++) { + if (module->fast_jit_thread_locks_inited[i]) { + os_mutex_destroy(&module->fast_jit_thread_locks[i]); + } + } #endif wasm_runtime_free(module); diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index f702a0e84..b7babf6b8 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -19,6 +19,9 @@ #if WASM_ENABLE_DEBUG_INTERP != 0 #include "../libraries/debug-engine/debug_engine.h" #endif +#if WASM_ENABLE_FAST_JIT != 0 +#include "../fast-jit/jit_compiler.h" +#endif #if WASM_ENABLE_JIT != 0 #include "../aot/aot_runtime.h" #endif @@ -1414,30 +1417,10 @@ wasm_instantiate(WASMModule *module, bool is_sub_inst, uint32 stack_size, extra_info_offset = (uint32)total_size; total_size += sizeof(WASMModuleInstanceExtra); -#if WASM_ENABLE_DEBUG_INTERP != 0 - if (!is_sub_inst) { - os_mutex_lock(&module->ref_count_lock); - if (module->ref_count != 0) { - LOG_WARNING( - "warning: multiple instances referencing the same module may " - "cause unexpected behaviour during debugging"); - } - module->ref_count++; - os_mutex_unlock(&module->ref_count_lock); - } -#endif - /* Allocate the memory for module instance with memory instances, global data, table data appended at the end */ if (!(module_inst = runtime_malloc(total_size, error_buf, error_buf_size))) { -#if WASM_ENABLE_DEBUG_INTERP != 0 - if (!is_sub_inst) { - os_mutex_lock(&module->ref_count_lock); - module->ref_count--; - os_mutex_unlock(&module->ref_count_lock); - } -#endif return NULL; } @@ -1827,6 +1810,33 @@ wasm_instantiate(WASMModule *module, bool is_sub_inst, uint32 stack_size, } #endif +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0) + if (!is_sub_inst) { + /* Add module instance into module's instance list */ + os_mutex_lock(&module->instance_list_lock); +#if WASM_ENABLE_DEBUG_INTERP != 0 + if (module->instance_list) { + LOG_WARNING( + "warning: multiple instances referencing to the same module " + "may cause unexpected behaviour during debugging"); + } +#endif +#if WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0 + /* Copy llvm func ptrs again in case that they were updated + after the module instance was created */ + bh_memcpy_s(module_inst->func_ptrs + module->import_function_count, + sizeof(void *) * module->function_count, module->func_ptrs, + sizeof(void *) * module->function_count); +#endif + module_inst->e->next = module->instance_list; + module->instance_list = module_inst; + os_mutex_unlock(&module->instance_list_lock); + } +#endif + if (module->start_function != (uint32)-1) { /* TODO: fix start function can be import function issue */ if (module->start_function >= module->import_function_count) @@ -1864,8 +1874,8 @@ wasm_instantiate(WASMModule *module, bool is_sub_inst, uint32 stack_size, wasm_runtime_dump_module_inst_mem_consumption( (WASMModuleInstanceCommon *)module_inst); #endif - (void)global_data_end; + (void)global_data_end; return module_inst; fail: @@ -1935,11 +1945,28 @@ wasm_deinstantiate(WASMModuleInstance *module_inst, bool is_sub_inst) } #endif -#if WASM_ENABLE_DEBUG_INTERP != 0 +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT != 0 \ + && WASM_ENABLE_LAZY_JIT != 0) if (!is_sub_inst) { - os_mutex_lock(&module_inst->module->ref_count_lock); - module_inst->module->ref_count--; - os_mutex_unlock(&module_inst->module->ref_count_lock); + WASMModule *module = module_inst->module; + WASMModuleInstance *instance_prev = NULL, *instance; + os_mutex_lock(&module->instance_list_lock); + + instance = module->instance_list; + while (instance) { + if (instance == module_inst) { + if (!instance_prev) + module->instance_list = instance->e->next; + else + instance_prev->e->next = instance->e->next; + break; + } + instance_prev = instance; + instance = instance->e->next; + } + + os_mutex_unlock(&module->instance_list_lock); } #endif @@ -2803,7 +2830,7 @@ fast_jit_call_indirect(WASMExecEnv *exec_env, uint32 tbl_idx, uint32 elem_idx, return call_indirect(exec_env, tbl_idx, elem_idx, argc, argv, true, type_idx); } -#endif +#endif /* end of WASM_ENABLE_FAST_JIT != 0 */ #if WASM_ENABLE_JIT != 0 || WASM_ENABLE_WAMR_COMPILER != 0 diff --git a/core/iwasm/interpreter/wasm_runtime.h b/core/iwasm/interpreter/wasm_runtime.h index 3146b18fd..6b0f33fdb 100644 --- a/core/iwasm/interpreter/wasm_runtime.h +++ b/core/iwasm/interpreter/wasm_runtime.h @@ -59,7 +59,10 @@ typedef enum WASMExceptionID { EXCE_AUX_STACK_UNDERFLOW, EXCE_OUT_OF_BOUNDS_TABLE_ACCESS, EXCE_OPERAND_STACK_OVERFLOW, +#if WASM_ENABLE_FAST_JIT != 0 + EXCE_FAILED_TO_COMPILE_FAST_JIT_FUNC, EXCE_ALREADY_THROWN, +#endif EXCE_NUM, } WASMExceptionID; @@ -221,6 +224,12 @@ typedef struct WASMModuleInstanceExtra { #if WASM_ENABLE_MEMORY_PROFILING != 0 uint32 max_aux_stack_used; #endif + +#if WASM_ENABLE_DEBUG_INTERP != 0 \ + || (WASM_ENABLE_FAST_JIT != 0 && WASM_ENABLE_JIT \ + && WASM_ENABLE_LAZY_JIT != 0) + WASMModuleInstance *next; +#endif } WASMModuleInstanceExtra; struct AOTFuncPerfProfInfo; diff --git a/doc/build_wamr.md b/doc/build_wamr.md index 18404e778..7b89f9e7f 100644 --- a/doc/build_wamr.md +++ b/doc/build_wamr.md @@ -240,7 +240,7 @@ make By default in Linux, the `fast interpreter`, `AOT` and `Libc WASI` are enabled, and JIT is disabled. And the build target is set to X86_64 or X86_32 depending on the platform's bitwidth. -There are total 5 running modes supported: fast interpreter, classi interpreter, AOT, LLVM JIT and Fast JIT. +There are total 6 running modes supported: fast interpreter, classi interpreter, AOT, LLVM JIT, Fast JIT and Multi-tier JIT. (1) To run a wasm file with `fast interpreter` mode - build iwasm with default build and then: ```Bash @@ -301,6 +301,14 @@ make ``` The Fast JIT is a lightweight JIT engine with quick startup, small footprint and good portability, and gains ~50% performance of AOT. +(6) To enable the `Multi-tier JIT` mode: +``` Bash +mkdir build && cd build +cmake .. -DWAMR_BUILD_FAST_JTI=1 -DWAMR_BUILD_JIT=1 +make +``` +The Multi-tier JIT is a two level JIT tier-up engine, which launchs Fast JIT to run the wasm module as soon as possible and creates backend threads to compile the LLVM JIT functions at the same time, and when the LLVM JIT functions are compiled, the runtime will switch the extecution from the Fast JIT jitted code to LLVM JIT jitted code gradually, so as to gain the best performance. + Linux SGX (Intel Software Guard Extension) ------------------------- diff --git a/tests/wamr-test-suites/test_wamr.sh b/tests/wamr-test-suites/test_wamr.sh index 84185a748..a4743c561 100755 --- a/tests/wamr-test-suites/test_wamr.sh +++ b/tests/wamr-test-suites/test_wamr.sh @@ -16,7 +16,7 @@ function help() echo "-c clean previous test results, not start test" echo "-s {suite_name} test only one suite (spec)" echo "-m set compile target of iwasm(x86_64\x86_32\armv7_vfp\thumbv7_vfp\riscv64_lp64d\riscv64_lp64)" - echo "-t set compile type of iwasm(classic-interp\fast-interp\jit\aot\fast-jit)" + echo "-t set compile type of iwasm(classic-interp\fast-interp\jit\aot\fast-jit\multi-tier-jit)" echo "-M enable multi module feature" echo "-p enable multi thread feature" echo "-S enable SIMD feature" @@ -29,7 +29,7 @@ function help() OPT_PARSED="" WABT_BINARY_RELEASE="NO" #default type -TYPE=("classic-interp" "fast-interp" "jit" "aot" "fast-jit") +TYPE=("classic-interp" "fast-interp" "jit" "aot" "fast-jit" "multi-tier-jit") #default target TARGET="X86_64" ENABLE_MULTI_MODULE=0 @@ -80,7 +80,8 @@ do t) echo "set compile type of wamr " ${OPTARG} if [[ ${OPTARG} != "classic-interp" && ${OPTARG} != "fast-interp" \ - && ${OPTARG} != "jit" && ${OPTARG} != "aot" && ${OPTARG} != "fast-jit" ]]; then + && ${OPTARG} != "jit" && ${OPTARG} != "aot" + && ${OPTARG} != "fast-jit" && ${OPTARG} != "multi-tier-jit" ]]; then echo "*----- please varify a type of compile when using -t! -----*" help exit 1 @@ -201,6 +202,12 @@ readonly FAST_JIT_COMPILE_FLAGS="\ -DWAMR_BUILD_FAST_JIT=1 \ -DWAMR_BUILD_SPEC_TEST=1" +readonly MULTI_TIER_JIT_COMPILE_FLAGS="\ + -DWAMR_BUILD_TARGET=${TARGET} \ + -DWAMR_BUILD_INTERP=1 -DWAMR_BUILD_FAST_INTERP=0 \ + -DWAMR_BUILD_FAST_JIT=1 -DWAMR_BUILD_JIT=1 \ + -DWAMR_BUILD_SPEC_TEST=1" + readonly COMPILE_FLAGS=( "${CLASSIC_INTERP_COMPILE_FLAGS}" "${FAST_INTERP_COMPILE_FLAGS}" @@ -208,6 +215,7 @@ readonly COMPILE_FLAGS=( "${ORC_LAZY_JIT_COMPILE_FLAGS}" "${AOT_COMPILE_FLAGS}" "${FAST_JIT_COMPILE_FLAGS}" + "${MULTI_TIER_JIT_COMPILE_FLAGS}" ) # TODO: with libiwasm.so only @@ -397,6 +405,10 @@ function spec_test() echo "fast-jit doesn't support multi-thread feature yet, skip it" return fi + if [[ $1 == 'multi-tier-jit' ]]; then + echo "multi-tier-jit doesn't support multi-thread feature yet, skip it" + return + fi fi if [[ ${ENABLE_XIP} == 1 ]]; then @@ -641,7 +653,7 @@ function trigger() "fast-jit") echo "work in fast-jit mode" - # jit + # fast-jit BUILD_FLAGS="$FAST_JIT_COMPILE_FLAGS $EXTRA_COMPILE_FLAGS" build_iwasm_with_cfg $BUILD_FLAGS for suite in "${TEST_CASE_ARR[@]}"; do @@ -649,6 +661,16 @@ function trigger() done ;; + "multi-tier-jit") + echo "work in multi-tier-jit mode" + # multi-tier-jit + BUILD_FLAGS="$MULTI_TIER_JIT_COMPILE_FLAGS $EXTRA_COMPILE_FLAGS" + build_iwasm_with_cfg $BUILD_FLAGS + for suite in "${TEST_CASE_ARR[@]}"; do + $suite"_test" multi-tier-jit + done + ;; + *) echo "unexpected mode, do nothing" ;;