Add wasi_ephemeral_nn module support (#3241)

Add `wasi_ephemeral_nn` module support with optional cmake variable,
which was mentioned in #3229.
This commit is contained in:
Xu Jinyang 2024-03-21 21:05:34 +08:00 committed by GitHub
parent e003ee1e29
commit cef88deedb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 109 additions and 8 deletions

View File

@ -430,6 +430,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
if (DEFINED WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH) if (DEFINED WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH)
add_definitions (-DWASM_WASI_NN_EXTERNAL_DELEGATE_PATH="${WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH}") add_definitions (-DWASM_WASI_NN_EXTERNAL_DELEGATE_PATH="${WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH}")
endif () endif ()
if (WAMR_BUILD_WASI_EPHEMERAL_NN EQUAL 1)
message (" WASI-NN: WASI-Ephemeral-NN enabled")
add_definitions (-DWASM_ENABLE_WASI_EPHEMERAL_NN=1)
endif()
endif () endif ()
if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1) if (WAMR_BUILD_ALLOC_WITH_USER_DATA EQUAL 1)
add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1) add_definitions(-DWASM_MEM_ALLOC_WITH_USER_DATA=1)

View File

@ -152,6 +152,10 @@
#define WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE 0 #define WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE 0
#endif #endif
#ifndef WASM_ENABLE_WASI_EPHEMERAL_NN
#define WASM_ENABLE_WASI_EPHEMERAL_NN 0
#endif
/* Default disable libc emcc */ /* Default disable libc emcc */
#ifndef WASM_ENABLE_LIBC_EMCC #ifndef WASM_ENABLE_LIBC_EMCC
#define WASM_ENABLE_LIBC_EMCC 0 #define WASM_ENABLE_LIBC_EMCC 0

View File

@ -567,7 +567,12 @@ wasm_native_init()
#if WASM_ENABLE_WASI_NN != 0 #if WASM_ENABLE_WASI_NN != 0
n_native_symbols = get_wasi_nn_export_apis(&native_symbols); n_native_symbols = get_wasi_nn_export_apis(&native_symbols);
if (!wasm_native_register_natives("wasi_nn", native_symbols, #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#define wasi_nn_module_name "wasi_ephemeral_nn"
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
#define wasi_nn_module_name "wasi_nn"
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
if (!wasm_native_register_natives(wasi_nn_module_name, native_symbols,
n_native_symbols)) n_native_symbols))
goto fail; goto fail;
#endif #endif

View File

@ -23,24 +23,47 @@ graph_builder_app_native(wasm_module_inst_t instance,
return success; return success;
} }
/**
* builder_array_wasm is consisted of {builder_wasm, size}
*/
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
error
graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder_wasm *builder_wasm, uint32_t size,
graph_builder_array *builder_array)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
error error
graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder_array_wasm *builder_array_wasm, graph_builder_array_wasm *builder_array_wasm,
graph_builder_array *builder_array) graph_builder_array *builder_array)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{ {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#define array_size size
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
#define array_size builder_array_wasm->size
if (!wasm_runtime_validate_native_addr( if (!wasm_runtime_validate_native_addr(
instance, builder_array_wasm, instance, builder_array_wasm,
(uint64)sizeof(graph_builder_array_wasm))) { (uint64)sizeof(graph_builder_array_wasm))) {
NN_ERR_PRINTF("builder_array_wasm is invalid"); NN_ERR_PRINTF("builder_array_wasm is invalid");
return invalid_argument; return invalid_argument;
} }
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
NN_DBG_PRINTF("Graph builder array contains %d elements", NN_DBG_PRINTF("Graph builder array contains %d elements", array_size);
builder_array_wasm->size);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (!wasm_runtime_validate_native_addr(instance, builder_wasm,
(uint64)array_size
* sizeof(graph_builder_wasm))) {
NN_ERR_PRINTF("builder_wasm is invalid");
return invalid_argument;
}
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
if (!wasm_runtime_validate_app_addr( if (!wasm_runtime_validate_app_addr(
instance, (uint64)builder_array_wasm->buf_offset, instance, (uint64)builder_array_wasm->buf_offset,
(uint64)builder_array_wasm->size * sizeof(graph_builder_wasm))) { (uint64)array_size * sizeof(graph_builder_wasm))) {
NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid"); NN_ERR_PRINTF("builder_array_wasm->buf_offset is invalid");
return invalid_argument; return invalid_argument;
} }
@ -48,13 +71,14 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder_wasm *builder_wasm = graph_builder_wasm *builder_wasm =
(graph_builder_wasm *)wasm_runtime_addr_app_to_native( (graph_builder_wasm *)wasm_runtime_addr_app_to_native(
instance, (uint64)builder_array_wasm->buf_offset); instance, (uint64)builder_array_wasm->buf_offset);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
graph_builder *builder = (graph_builder *)wasm_runtime_malloc( graph_builder *builder = (graph_builder *)wasm_runtime_malloc(
builder_array_wasm->size * sizeof(graph_builder)); array_size * sizeof(graph_builder));
if (builder == NULL) if (builder == NULL)
return missing_memory; return missing_memory;
for (uint32_t i = 0; i < builder_array_wasm->size; ++i) { for (uint32_t i = 0; i < array_size; ++i) {
error res; error res;
if (success if (success
!= (res = graph_builder_app_native(instance, &builder_wasm[i], != (res = graph_builder_app_native(instance, &builder_wasm[i],
@ -68,23 +92,31 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
} }
builder_array->buf = builder; builder_array->buf = builder;
builder_array->size = builder_array_wasm->size; builder_array->size = array_size;
return success; return success;
#undef array_size
} }
static error static error
tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements, tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
tensor_wasm *input_tensor_wasm, tensor_data *data) tensor_wasm *input_tensor_wasm, tensor_data *data)
{ {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#define data_size input_tensor_wasm->data_size
#else
#define data_size total_elements
#endif
if (!wasm_runtime_validate_app_addr(instance, if (!wasm_runtime_validate_app_addr(instance,
(uint64)input_tensor_wasm->data_offset, (uint64)input_tensor_wasm->data_offset,
(uint64)total_elements)) { (uint64)data_size)) {
NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid"); NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
return invalid_argument; return invalid_argument;
} }
*data = (tensor_data)wasm_runtime_addr_app_to_native( *data = (tensor_data)wasm_runtime_addr_app_to_native(
instance, (uint64)input_tensor_wasm->data_offset); instance, (uint64)input_tensor_wasm->data_offset);
return success; return success;
#undef data_size
} }
static error static error
@ -92,6 +124,9 @@ tensor_dimensions_app_native(wasm_module_inst_t instance,
tensor_wasm *input_tensor_wasm, tensor_wasm *input_tensor_wasm,
tensor_dimensions **dimensions) tensor_dimensions **dimensions)
{ {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
tensor_dimensions_wasm *dimensions_wasm = &input_tensor_wasm->dimensions;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
if (!wasm_runtime_validate_app_addr( if (!wasm_runtime_validate_app_addr(
instance, (uint64)input_tensor_wasm->dimensions_offset, instance, (uint64)input_tensor_wasm->dimensions_offset,
(uint64)sizeof(tensor_dimensions_wasm))) { (uint64)sizeof(tensor_dimensions_wasm))) {
@ -102,6 +137,7 @@ tensor_dimensions_app_native(wasm_module_inst_t instance,
tensor_dimensions_wasm *dimensions_wasm = tensor_dimensions_wasm *dimensions_wasm =
(tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native( (tensor_dimensions_wasm *)wasm_runtime_addr_app_to_native(
instance, (uint64)input_tensor_wasm->dimensions_offset); instance, (uint64)input_tensor_wasm->dimensions_offset);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
if (!wasm_runtime_validate_app_addr(instance, if (!wasm_runtime_validate_app_addr(instance,
(uint64)dimensions_wasm->buf_offset, (uint64)dimensions_wasm->buf_offset,

View File

@ -34,15 +34,29 @@ typedef struct {
} tensor_dimensions_wasm; } tensor_dimensions_wasm;
typedef struct { typedef struct {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
tensor_dimensions_wasm dimensions;
tensor_type type;
uint32_t data_offset;
uint32_t data_size;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
uint32_t dimensions_offset; uint32_t dimensions_offset;
tensor_type type; tensor_type type;
uint32_t data_offset; uint32_t data_offset;
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
} tensor_wasm; } tensor_wasm;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
error
graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder_wasm *builder_wasm, uint32_t size,
graph_builder_array *builder_array);
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
error error
graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder_array_app_native(wasm_module_inst_t instance,
graph_builder_array_wasm *builder, graph_builder_array_wasm *builder,
graph_builder_array *builder_native); graph_builder_array *builder_native);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
error error
tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor, tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor,

View File

@ -189,9 +189,16 @@ is_model_initialized(WASINNContext *wasi_nn_ctx)
/* WASI-NN implementation */ /* WASI-NN implementation */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_wasm *builder,
uint32_t builder_wasm_size, graph_encoding encoding,
execution_target target, graph *g)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
error error
wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
graph_encoding encoding, execution_target target, graph *g) graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{ {
NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding, NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding,
target); target);
@ -206,10 +213,17 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
error res; error res;
graph_builder_array builder_native = { 0 }; graph_builder_array builder_native = { 0 };
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (success
!= (res = graph_builder_array_app_native(
instance, builder, builder_wasm_size, &builder_native)))
return res;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
if (success if (success
!= (res = graph_builder_array_app_native(instance, builder, != (res = graph_builder_array_app_native(instance, builder,
&builder_native))) &builder_native)))
return res; return res;
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
if (!wasm_runtime_validate_native_addr(instance, g, if (!wasm_runtime_validate_native_addr(instance, g,
(uint64)sizeof(graph))) { (uint64)sizeof(graph))) {
@ -315,10 +329,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
return res; return res;
} }
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
error error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor, uint32_t index, tensor_data output_tensor,
uint32_t *output_tensor_size) uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{ {
NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx, NN_DBG_PRINTF("Running wasi_nn_get_output [ctx=%d, index=%d]...", ctx,
index); index);
@ -337,8 +358,14 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
return invalid_argument; return invalid_argument;
} }
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, &output_tensor_len);
*output_tensor_size = output_tensor_len;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
res = lookup[wasi_nn_ctx->current_encoding].get_output( res = lookup[wasi_nn_ctx->current_encoding].get_output(
wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size); wasi_nn_ctx->tflite_ctx, ctx, index, output_tensor, output_tensor_size);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]", NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]",
res, *output_tensor_size); res, *output_tensor_size);
return res; return res;
@ -352,11 +379,19 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
/* clang-format on */ /* clang-format on */
static NativeSymbol native_symbols_wasi_nn[] = { static NativeSymbol native_symbols_wasi_nn[] = {
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
REG_NATIVE_FUNC(load, "(*iii*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"), REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"), REG_NATIVE_FUNC(compute, "(i)i"),
REG_NATIVE_FUNC(get_output, "(ii**)i"), REG_NATIVE_FUNC(get_output, "(ii**)i"),
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
}; };
uint32_t uint32_t

View File

@ -107,6 +107,9 @@ cmake -DWAMR_BUILD_PLATFORM=linux -DWAMR_BUILD_TARGET=ARM
- **WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH**=Path to the external delegate shared library (e.g. `libedgetpu.so.1.0` for Coral USB) - **WAMR_BUILD_WASI_NN_EXTERNAL_DELEGATE_PATH**=Path to the external delegate shared library (e.g. `libedgetpu.so.1.0` for Coral USB)
#### **Enable lib wasi-nn with `wasi_ephemeral_nn` module support**
- **WAMR_BUILD_WASI_EPHEMERAL_NN**=1/0, default to disable if not set
#### **Disable boundary check with hardware trap** #### **Disable boundary check with hardware trap**
- **WAMR_DISABLE_HW_BOUND_CHECK**=1/0, default to enable if not set and supported by platform - **WAMR_DISABLE_HW_BOUND_CHECK**=1/0, default to enable if not set and supported by platform
> Note: by default only platform [linux/darwin/android/windows/vxworks 64-bit](https://github.com/bytecodealliance/wasm-micro-runtime/blob/5fb5119239220b0803e7045ca49b0a29fe65e70e/core/shared/platform/linux/platform_internal.h#L81) will enable the boundary check with hardware trap feature, for 32-bit platforms it's automatically disabled even when the flag is set to 0, and the wamrc tool will generate AOT code without boundary check instructions in all 64-bit targets except SGX to improve performance. The boundary check includes linear memory access boundary and native stack access boundary, if `WAMR_DISABLE_STACK_HW_BOUND_CHECK` below isn't set. > Note: by default only platform [linux/darwin/android/windows/vxworks 64-bit](https://github.com/bytecodealliance/wasm-micro-runtime/blob/5fb5119239220b0803e7045ca49b0a29fe65e70e/core/shared/platform/linux/platform_internal.h#L81) will enable the boundary check with hardware trap feature, for 32-bit platforms it's automatically disabled even when the flag is set to 0, and the wamrc tool will generate AOT code without boundary check instructions in all 64-bit targets except SGX to improve performance. The boundary check includes linear memory access boundary and native stack access boundary, if `WAMR_DISABLE_STACK_HW_BOUND_CHECK` below isn't set.