From cef88deedb0e1c3978f3437aff44af708cdf3a0f Mon Sep 17 00:00:00 2001 From: Xu Jinyang <72930776+AuYang261@users.noreply.github.com> Date: Thu, 21 Mar 2024 21:05:34 +0800 Subject: [PATCH] Add `wasi_ephemeral_nn` module support (#3241) Add `wasi_ephemeral_nn` module support with optional cmake variable, which was mentioned in #3229. --- build-scripts/config_common.cmake | 4 ++ core/config.h | 4 ++ core/iwasm/common/wasm_native.c | 7 ++- .../wasi-nn/src/utils/wasi_nn_app_native.c | 50 ++++++++++++++++--- .../wasi-nn/src/utils/wasi_nn_app_native.h | 14 ++++++ core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 35 +++++++++++++ doc/build_wamr.md | 3 ++ 7 files changed, 109 insertions(+), 8 deletions(-) diff --git a/build-scripts/config_common.cmake b/build-scripts/config_common.cmake index 773ff2837..4569648a1 100644 --- a/build-scripts/config_common.cmake +++ b/build-scripts/config_common.cmake @@ -430,6 +430,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1) if (DEFINED WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH) add_definitions (-DWASM_WASI_NN_EXTERNAL_DELEGATE_PATH="${WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH}") endif () + if (WAMR_BUILD_WASI_EPHEMERAL_NN EQUAL 1) + message (" WASI-NN: WASI-Ephemeral-NN enabled") + add_definitions (-DWASM_ENABLE_WASI_EPHEMERAL_NN=1) + endif() endif () if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1) add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1) diff --git a/core/config.h b/core/config.h index be9ebfc3a..616c1f6e7 100644 --- a/core/config.h +++ b/core/config.h @@ -152,6 +152,10 @@ #define WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE 0 #endif +#ifndef WASM_ENABLE_WASI_EPHEMERAL_NN +#define WASM_ENABLE_WASI_EPHEMERAL_NN 0 +#endif + /* Default disable libc emcc */ #ifndef WASM_ENABLE_LIBC_EMCC #define WASM_ENABLE_LIBC_EMCC 0 diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 3cf7451e3..14b295ee7 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -567,7 +567,12 @@ wasm_native_init() #if WASM_ENABLE_WASI_NN != 0 n_native_symbols = get_wasi_nn_export_apis(&native_symbols); - if (!wasm_native_register_natives("wasi_nn", native_symbols, +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#define wasi_nn_module_name "wasi_ephemeral_nn" +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ +#define wasi_nn_module_name "wasi_nn" +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ + if (!wasm_native_register_natives(wasi_nn_module_name, native_symbols, n_native_symbols)) goto fail; #endif diff --git a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c index 28dfbad4e..44aef5359 100644 --- a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c +++ b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c @@ -23,24 +23,47 @@ graph_builder_app_native(wasm_module_inst_t instance, return success; } +/** + * builder_array_wasm is consisted of {builder_wasm, size} + */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +error +graph_builder_array_app_native(wasm_module_inst_t instance, + graph_builder_wasm *builder_wasm, uint32_t size, + graph_builder_array *builder_array) +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ error graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder_array_wasm *builder_array_wasm, graph_builder_array *builder_array) +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#define array_size size +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ +#define array_size builder_array_wasm->size + if (!wasm_runtime_validate_native_addr( instance, builder_array_wasm, (uint64)sizeof(graph_builder_array_wasm))) { NN_ERR_PRINTF("builder_array_wasm is invalid"); return invalid_argument; } +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ - NN_DBG_PRINTF("Graph builder array contains %d elements", - builder_array_wasm->size); + NN_DBG_PRINTF("Graph builder array contains %d elements", array_size); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (!wasm_runtime_validate_native_addr(instance, builder_wasm, + (uint64)array_size + * sizeof(graph_builder_wasm))) { + NN_ERR_PRINTF("builder_wasm is invalid"); + return invalid_argument; + } +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ if (!wasm_runtime_validate_app_addr( instance, (uint64)builder_array_wasm->buf_offset, - (uint64)builder_array_wasm->size * sizeof(graph_builder_wasm))) { + (uint64)array_size * sizeof(graph_builder_wasm))) { NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid"); return invalid_argument; } @@ -48,13 +71,14 @@ graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder_wasm *builder_wasm = (graph_builder_wasm *)wasm_runtime_addr_app_to_native( instance, (uint64)builder_array_wasm->buf_offset); +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ graph_builder *builder = (graph_builder *)wasm_runtime_malloc( - builder_array_wasm->size * sizeof(graph_builder)); + array_size * sizeof(graph_builder)); if (builder == NULL) return missing_memory; - for (uint32_t i = 0; i < builder_array_wasm->size; ++i) { + for (uint32_t i = 0; i < array_size; ++i) { error res; if (success != (res = graph_builder_app_native(instance, &builder_wasm[i], @@ -68,23 +92,31 @@ graph_builder_array_app_native(wasm_module_inst_t instance, } builder_array->buf = builder; - builder_array->size = builder_array_wasm->size; + builder_array->size = array_size; return success; +#undef array_size } static error tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements, tensor_wasm *input_tensor_wasm, tensor_data *data) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#define data_size input_tensor_wasm->data_size +#else +#define data_size total_elements +#endif + if (!wasm_runtime_validate_app_addr(instance, (uint64)input_tensor_wasm->data_offset, - (uint64)total_elements)) { + (uint64)data_size)) { NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid"); return invalid_argument; } *data = (tensor_data)wasm_runtime_addr_app_to_native( instance, (uint64)input_tensor_wasm->data_offset); return success; +#undef data_size } static error @@ -92,6 +124,9 @@ tensor_dimensions_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor_wasm, tensor_dimensions **dimensions) { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + tensor_dimensions_wasm *dimensions_wasm = &input_tensor_wasm->dimensions; +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ if (!wasm_runtime_validate_app_addr( instance, (uint64)input_tensor_wasm->dimensions_offset, (uint64)sizeof(tensor_dimensions_wasm))) { @@ -102,6 +137,7 @@ tensor_dimensions_app_native(wasm_module_inst_t instance, tensor_dimensions_wasm *dimensions_wasm = (tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native( instance, (uint64)input_tensor_wasm->dimensions_offset); +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ if (!wasm_runtime_validate_app_addr(instance, (uint64)dimensions_wasm->buf_offset, diff --git a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h index 15154bd31..f0930a883 100644 --- a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h +++ b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.h @@ -34,15 +34,29 @@ typedef struct { } tensor_dimensions_wasm; typedef struct { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + tensor_dimensions_wasm dimensions; + tensor_type type; + uint32_t data_offset; + uint32_t data_size; +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ uint32_t dimensions_offset; tensor_type type; uint32_t data_offset; +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ } tensor_wasm; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +error +graph_builder_array_app_native(wasm_module_inst_t instance, + graph_builder_wasm *builder_wasm, uint32_t size, + graph_builder_array *builder_array); +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ error graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder_array_wasm *builder, graph_builder_array *builder_native); +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ error tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8e17deed4..1fbb94428 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -189,9 +189,16 @@ is_model_initialized(WASINNContext *wasi_nn_ctx) /* WASI-NN implementation */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +error +wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder, + uint32_t builder_wasm_size, graph_encoding encoding, + execution_target target, graph *g) +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ error wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, graph_encoding encoding, execution_target target, graph *g) +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ { NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding, target); @@ -206,10 +213,17 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, error res; graph_builder_array builder_native = { 0 }; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (success + != (res = graph_builder_array_app_native( + instance, builder, builder_wasm_size, &builder_native))) + return res; +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ if (success != (res = graph_builder_array_app_native(instance, builder, &builder_native))) return res; +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ if (!wasm_runtime_validate_native_addr(instance, g, (uint64)sizeof(graph))) { @@ -315,10 +329,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx) return res; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +error +wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, + uint32_t index, tensor_data output_tensor, + uint32_t output_tensor_len, uint32_t *output_tensor_size) +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ error wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, uint32_t index, tensor_data output_tensor, uint32_t *output_tensor_size) +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ { NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx, index); @@ -337,8 +358,14 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, return invalid_argument; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + res = lookup[wasi_nn_ctx->current_encoding].get_output( + wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, &output_tensor_len); + *output_tensor_size = output_tensor_len; +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ res = lookup[wasi_nn_ctx->current_encoding].get_output( wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size); +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]", res, *output_tensor_size); return res; @@ -352,11 +379,19 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, /* clang-format on */ static NativeSymbol native_symbols_wasi_nn[] = { +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + REG_NATIVE_FUNC(load, "(*iii*)i"), + REG_NATIVE_FUNC(init_execution_context, "(i*)i"), + REG_NATIVE_FUNC(set_input, "(ii*)i"), + REG_NATIVE_FUNC(compute, "(i)i"), + REG_NATIVE_FUNC(get_output, "(ii*i*)i"), +#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ REG_NATIVE_FUNC(load, "(*ii*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(compute, "(i)i"), REG_NATIVE_FUNC(get_output, "(ii**)i"), +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ }; uint32_t diff --git a/doc/build_wamr.md b/doc/build_wamr.md index 1331f9601..a9ffca204 100644 --- a/doc/build_wamr.md +++ b/doc/build_wamr.md @@ -107,6 +107,9 @@ cmake -DWAMR_BUILD_PLATFORM=linux -DWAMR_BUILD_TARGET=ARM - **WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH**=Path to the external delegate shared library (e.g. `libedgetpu.so.1.0` for Coral USB) +#### **Enable lib wasi-nn with `wasi_ephemeral_nn` module support** +- **WAMR_BUILD_WASI_EPHEMERAL_NN**=1/0, default to disable if not set + #### **Disable boundary check with hardware trap** - **WAMR_DISABLE_HW_BOUND_CHECK**=1/0, default to enable if not set and supported by platform > Note: by default only platform [linux/darwin/android/windows/vxworks 64-bit](https://github.com/bytecodealliance/wasm-micro-runtime/blob/5fb5119239220b0803e7045ca49b0a29fe65e70e/core/shared/platform/linux/platform_internal.h#L81) will enable the boundary check with hardware trap feature, for 32-bit platforms it's automatically disabled even when the flag is set to 0, and the wamrc tool will generate AOT code without boundary check instructions in all 64-bit targets except SGX to improve performance. The boundary check includes linear memory access boundary and native stack access boundary, if `WAMR_DISABLE_STACK_HW_BOUND_CHECK` below isn't set.