From db025e457afd3846815e34b385ecd79bc3262f8d Mon Sep 17 00:00:00 2001 From: "liang.he" Date: Mon, 17 Jun 2024 14:58:09 +0800 Subject: [PATCH] sync up with latest wasi-nn spec (#3530) --- core/iwasm/libraries/wasi-nn/README.md | 21 ++- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 4 + .../libraries/wasi-nn/include/wasi_nn_types.h | 40 ++++- .../wasi-nn/src/utils/wasi_nn_app_native.c | 4 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 169 +++++++++++------- .../libraries/wasi-nn/src/wasi_nn_private.h | 1 + .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 52 ++++-- .../wasi-nn/src/wasi_nn_tensorflowlite.hpp | 4 +- core/iwasm/libraries/wasi-nn/test/utils.c | 15 +- 9 files changed, 209 insertions(+), 101 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/README.md b/core/iwasm/libraries/wasi-nn/README.md index 757088237..e08c6457b 100644 --- a/core/iwasm/libraries/wasi-nn/README.md +++ b/core/iwasm/libraries/wasi-nn/README.md @@ -21,9 +21,21 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 ... ### Wasm -The definition of functions provided by WASI-NN (Wasm imports) is in the header file _core/iwasm/libraries/wasi-nn/wasi_nn.h_. +The definition of functions provided by WASI-NN (Wasm imports) is in the header file [wasi_nn.h](_core/iwasm/libraries/wasi-nn/wasi_nn.h_). By only including this file in a WASM application you will bind WASI-NN into your module. -By only including this file in a WASM application you will bind WASI-NN into your module. +For some historical reasons, there are two sets of functions in the header file. The first set is the original one, and the second set is the new one. The new set is recommended to use. In code, `WASM_ENABLE_WASI_EPHEMERAL_NN` is used to control which set of functions to use. If `WASM_ENABLE_WASI_EPHEMERAL_NN` is defined, the new set of functions will be used. Otherwise, the original set of functions will be used. + +There is a big difference between the two sets of functions, `tensor_type`. + +``` c +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef enum { fp16 = 0, fp32, fp64, bf16, u8, i32, i64 } tensor_type; +#else +typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type; +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ +``` + +It is required to recompile the Wasm application if you want to switch between the two sets of functions. ## Tests @@ -41,7 +53,10 @@ Build the runtime image for your execution target type. - `tpu` ```bash -EXECUTION_TYPE=cpu docker build -t wasi-nn-${EXECUTION_TYPE} -f core/iwasm/libraries/wasi-nn/test/Dockerfile.${EXECUTION_TYPE} . +$ pwd +/wasm-micro-runtime + +$ EXECUTION_TYPE=cpu docker build -t wasi-nn-${EXECUTION_TYPE} -f core/iwasm/libraries/wasi-nn/test/Dockerfile.${EXECUTION_TYPE} . ``` ### Build wasm app diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index 66e2ee02a..ad1f37deb 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -29,6 +29,10 @@ load(graph_builder_array *builder, graph_encoding encoding, execution_target target, graph *g) __attribute__((import_module("wasi_nn"))); +wasi_nn_error +load_by_name(const char *name, graph *g) + __attribute__((import_module("wasi_nn"))); + /** * INFERENCE * diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 2b9759057..75f14eb70 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -18,7 +18,9 @@ extern "C" { * */ -// Error codes returned by functions in this API. +// sync up with +// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L136 Error +// codes returned by functions in this API. typedef enum { // No error occurred. success = 0, @@ -26,12 +28,21 @@ typedef enum { invalid_argument, // Invalid encoding. invalid_encoding, - // Caller module is missing a memory export. - missing_memory, - // Device or resource busy. - busy, + // The operation timed out. + timeout, // Runtime Error. runtime_error, + // Unsupported operation. + unsupported_operation, + // Graph is too large. + too_large, + // Graph not found. + not_found, + // The operation is insecure or has insufficient privilege to be performed. + // e.g., cannot access a hardware feature requested + security, + // The operation failed for an unspecified reason. + unknown, } wasi_nn_error; /** @@ -48,8 +59,14 @@ typedef struct { uint32_t size; } tensor_dimensions; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +// sync up with +// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L27 // The type of the elements in a tensor. +typedef enum { fp16 = 0, fp32, fp64, bf16, u8, i32, i64 } tensor_type; +#else typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type; +#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ // The tensor data. // @@ -96,6 +113,8 @@ typedef struct { // An execution graph for performing inference (i.e., a model). typedef uint32_t graph; +// sync up with +// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75 // Describes the encoding of the graph. This allows the API to be implemented by // various backends that encode (i.e., serialize) their graph IR with different // formats. @@ -105,7 +124,8 @@ typedef enum { tensorflow, pytorch, tensorflowlite, - backend_amount + ggml, + autodetect, } graph_encoding; // Define where the graph should be executed. @@ -118,6 +138,7 @@ typedef uint32_t graph_execution_context; typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *); typedef wasi_nn_error (*INIT_EXECUTION_CONTEXT)(void *, graph, graph_execution_context *); typedef wasi_nn_error (*SET_INPUT)(void *, graph_execution_context, uint32_t, @@ -126,11 +147,12 @@ typedef wasi_nn_error (*COMPUTE)(void *, graph_execution_context); typedef wasi_nn_error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t, tensor_data, uint32_t *); /* wasi-nn general APIs */ -typedef void (*BACKEND_INITIALIZE)(void **); -typedef void (*BACKEND_DEINITIALIZE)(void *); +typedef wasi_nn_error (*BACKEND_INITIALIZE)(void **); +typedef wasi_nn_error (*BACKEND_DEINITIALIZE)(void *); typedef struct { LOAD load; + LOAD_BY_NAME load_by_name; INIT_EXECUTION_CONTEXT init_execution_context; SET_INPUT set_input; COMPUTE compute; @@ -140,7 +162,7 @@ typedef struct { } api_function; bool -wasi_nn_register_backend(graph_encoding backend_code, api_function apis); +wasi_nn_register_backend(api_function apis); #ifdef __cplusplus } diff --git a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c index b1a7c327d..07516f34d 100644 --- a/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c +++ b/core/iwasm/libraries/wasi-nn/src/utils/wasi_nn_app_native.c @@ -76,7 +76,7 @@ graph_builder_array_app_native(wasm_module_inst_t instance, graph_builder *builder = (graph_builder *)wasm_runtime_malloc( array_size * sizeof(graph_builder)); if (builder == NULL) - return missing_memory; + return too_large; for (uint32_t i = 0; i < array_size; ++i) { wasi_nn_error res; @@ -149,7 +149,7 @@ tensor_dimensions_app_native(wasm_module_inst_t instance, *dimensions = (tensor_dimensions *)wasm_runtime_malloc(sizeof(tensor_dimensions)); if (dimensions == NULL) - return missing_memory; + return too_large; (*dimensions)->size = dimensions_wasm->size; (*dimensions)->buf = (uint32_t *)wasm_runtime_addr_app_to_native( diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 3e6bd853f..de931b41b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -16,12 +16,26 @@ #include "logger.h" #include "bh_platform.h" +#include "wasi_nn_types.h" #include "wasm_export.h" #define HASHMAP_INITIAL_SIZE 20 /* Global variables */ -static api_function lookup[backend_amount] = { 0 }; +// if using `load_by_name`, there is no known `encoding` at the time of loading +// so, just keep one `api_function` is enough +static api_function lookup = { 0 }; + +#define call_wasi_nn_func(wasi_error, func, ...) \ + do { \ + if (lookup.func) { \ + wasi_error = lookup.func(__VA_ARGS__); \ + } \ + else { \ + NN_ERR_PRINTF("Error: %s is not registered", #func); \ + wasi_error = unsupported_operation; \ + } \ + } while (0) static HashMap *hashmap; @@ -73,16 +87,16 @@ wasi_nn_initialize_context() return NULL; } wasi_nn_ctx->is_model_loaded = false; + /* only one backend can be registered */ - { - unsigned i; - for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) { - if (lookup[i].init) { - lookup[i].init(&wasi_nn_ctx->backend_ctx); - break; - } - } + wasi_nn_error res; + call_wasi_nn_func(res, init, &wasi_nn_ctx->backend_ctx); + if (res != success) { + NN_ERR_PRINTF("Error while initializing backend"); + wasm_runtime_free(wasi_nn_ctx); + return NULL; } + return wasi_nn_ctx; } @@ -90,6 +104,7 @@ static bool wasi_nn_initialize() { NN_DBG_PRINTF("Initializing wasi-nn"); + // hashmap { instance: wasi_nn_ctx } hashmap = bh_hash_map_create(HASHMAP_INITIAL_SIZE, true, hash_func, key_equal_func, key_destroy_func, value_destroy_func); @@ -133,42 +148,26 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx) NN_DBG_PRINTF("Freeing wasi-nn"); NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded); NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->current_encoding); - /* only one backend can be registered */ - { - unsigned i; - for (i = 0; i < sizeof(lookup) / sizeof(lookup[0]); i++) { - if (lookup[i].deinit) { - lookup[i].deinit(wasi_nn_ctx->backend_ctx); - break; - } - } - } - wasm_runtime_free(wasi_nn_ctx); -} -static void -wasi_nn_ctx_destroy_helper(void *instance, void *wasi_nn_ctx, void *user_data) -{ - wasi_nn_ctx_destroy((WASINNContext *)wasi_nn_ctx); + /* only one backend can be registered */ + wasi_nn_error res; + call_wasi_nn_func(res, deinit, wasi_nn_ctx->backend_ctx); + if (res != success) { + NN_ERR_PRINTF("Error while destroyging backend"); + } + + wasm_runtime_free(wasi_nn_ctx); } void wasi_nn_destroy() { - bh_hash_map_traverse(hashmap, wasi_nn_ctx_destroy_helper, NULL); + // destroy hashmap will destroy keys and values bh_hash_map_destroy(hashmap); } /* Utils */ -static bool -is_encoding_implemented(graph_encoding encoding) -{ - return lookup[encoding].load && lookup[encoding].init_execution_context - && lookup[encoding].set_input && lookup[encoding].compute - && lookup[encoding].get_output; -} - static wasi_nn_error is_model_initialized(WASINNContext *wasi_nn_ctx) { @@ -195,13 +194,9 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, NN_DBG_PRINTF("Running wasi_nn_load [encoding=%d, target=%d]...", encoding, target); - if (!is_encoding_implemented(encoding)) { - NN_ERR_PRINTF("Encoding not supported."); - return invalid_encoding; - } - wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - bh_assert(instance); + if (!instance) + return runtime_error; wasi_nn_error res; graph_builder_array builder_native = { 0 }; @@ -225,10 +220,11 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, } WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - res = lookup[encoding].load(wasi_nn_ctx->backend_ctx, &builder_native, - encoding, target, g); - + call_wasi_nn_func(res, load, wasi_nn_ctx->backend_ctx, &builder_native, + encoding, target, g); NN_DBG_PRINTF("wasi_nn_load finished with status %d [graph=%d]", res, *g); + if (res != success) + goto fail; wasi_nn_ctx->current_encoding = encoding; wasi_nn_ctx->is_model_loaded = true; @@ -241,6 +237,39 @@ fail: return res; } +wasi_nn_error +wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, + graph *g) +{ + NN_DBG_PRINTF("Running wasi_nn_load_by_name ..."); + + wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); + if (!instance) { + return runtime_error; + } + + if (!wasm_runtime_validate_native_addr(instance, name, name_len)) { + return invalid_argument; + } + + if (!wasm_runtime_validate_native_addr(instance, g, + (uint64)sizeof(graph))) { + return invalid_argument; + } + + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); + wasi_nn_error res; + call_wasi_nn_func(res, load_by_name, wasi_nn_ctx->backend_ctx, name, + name_len, g); + NN_DBG_PRINTF("wasi_nn_load_by_name finished with status %d", *g); + if (res != success) + return res; + + wasi_nn_ctx->current_encoding = autodetect; + wasi_nn_ctx->is_model_loaded = true; + return success; +} + wasi_nn_error wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g, graph_execution_context *ctx) @@ -248,7 +277,10 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g, NN_DBG_PRINTF("Running wasi_nn_init_execution_context [graph=%d]...", g); wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - bh_assert(instance); + if (!instance) { + return runtime_error; + } + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); wasi_nn_error res; @@ -261,9 +293,8 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g, return invalid_argument; } - res = lookup[wasi_nn_ctx->current_encoding].init_execution_context( - wasi_nn_ctx->backend_ctx, g, ctx); - + call_wasi_nn_func(res, init_execution_context, wasi_nn_ctx->backend_ctx, g, + ctx); NN_DBG_PRINTF( "wasi_nn_init_execution_context finished with status %d [ctx=%d]", res, *ctx); @@ -278,7 +309,10 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, index); wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - bh_assert(instance); + if (!instance) { + return runtime_error; + } + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); wasi_nn_error res; @@ -291,9 +325,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, &input_tensor_native))) return res; - res = lookup[wasi_nn_ctx->current_encoding].set_input( - wasi_nn_ctx->backend_ctx, ctx, index, &input_tensor_native); - + call_wasi_nn_func(res, set_input, wasi_nn_ctx->backend_ctx, ctx, index, + &input_tensor_native); // XXX: Free intermediate structure pointers if (input_tensor_native.dimensions) wasm_runtime_free(input_tensor_native.dimensions); @@ -308,15 +341,17 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx) NN_DBG_PRINTF("Running wasi_nn_compute [ctx=%d]...", ctx); wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - bh_assert(instance); + if (!instance) { + return runtime_error; + } + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); wasi_nn_error res; if (success != (res = is_model_initialized(wasi_nn_ctx))) return res; - res = lookup[wasi_nn_ctx->current_encoding].compute( - wasi_nn_ctx->backend_ctx, ctx); + call_wasi_nn_func(res, compute, wasi_nn_ctx->backend_ctx, ctx); NN_DBG_PRINTF("wasi_nn_compute finished with status %d", res); return res; } @@ -337,7 +372,10 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, index); wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); - bh_assert(instance); + if (!instance) { + return runtime_error; + } + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); wasi_nn_error res; @@ -351,14 +389,12 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, } #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - res = lookup[wasi_nn_ctx->current_encoding].get_output( - wasi_nn_ctx->backend_ctx, ctx, index, output_tensor, - &output_tensor_len); + call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index, + output_tensor, &output_tensor_len); *output_tensor_size = output_tensor_len; #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ - res = lookup[wasi_nn_ctx->current_encoding].get_output( - wasi_nn_ctx->backend_ctx, ctx, index, output_tensor, - output_tensor_size); + call_wasi_nn_func(res, get_output, wasi_nn_ctx->backend_ctx, ctx, index, + output_tensor, output_tensor_size); #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ NN_DBG_PRINTF("wasi_nn_get_output finished with status %d [data_size=%d]", res, *output_tensor_size); @@ -375,6 +411,7 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, static NativeSymbol native_symbols_wasi_nn[] = { #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 REG_NATIVE_FUNC(load, "(*iii*)i"), + REG_NATIVE_FUNC(load_by_name, "(*i*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(compute, "(i)i"), @@ -429,15 +466,9 @@ deinit_native_lib() } __attribute__((used)) bool -wasi_nn_register_backend(graph_encoding backend_code, api_function apis) +wasi_nn_register_backend(api_function apis) { NN_DBG_PRINTF("--|> wasi_nn_register_backend"); - - if (backend_code >= sizeof(lookup) / sizeof(lookup[0])) { - NN_ERR_PRINTF("Invalid backend code"); - return false; - } - - lookup[backend_code] = apis; + lookup = apis; return true; } \ No newline at end of file diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 53902807f..df5080dea 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -11,6 +11,7 @@ typedef struct { bool is_model_loaded; + // Optional graph_encoding current_encoding; void *backend_ctx; } WASINNContext; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 2618128eb..606aca243 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -7,6 +7,7 @@ #include "logger.h" #include "bh_platform.h" +#include "wasi_nn_types.h" #include "wasm_export.h" #include @@ -144,7 +145,7 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder, tfl_ctx->models[*g].model_pointer = (char *)wasm_runtime_malloc(size); if (tfl_ctx->models[*g].model_pointer == NULL) { NN_ERR_PRINTF("Error when allocating memory for model."); - return missing_memory; + return too_large; } bh_memcpy_s(tfl_ctx->models[*g].model_pointer, size, builder->buf[0].buf, @@ -159,7 +160,7 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder, NN_ERR_PRINTF("Loading model error."); wasm_runtime_free(tfl_ctx->models[*g].model_pointer); tfl_ctx->models[*g].model_pointer = NULL; - return missing_memory; + return too_large; } // Save target @@ -167,6 +168,30 @@ tensorflowlite_load(void *tflite_ctx, graph_builder_array *builder, return success; } +wasi_nn_error +tensorflowlite_load_by_name(void *tflite_ctx, const char *filename, + uint32_t filename_len, graph *g) +{ + TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; + + wasi_nn_error res = initialize_g(tfl_ctx, g); + if (success != res) + return res; + + // Load model + tfl_ctx->models[*g].model = + std::move(tflite::FlatBufferModel::BuildFromFile(filename, NULL)); + + if (tfl_ctx->models[*g].model == NULL) { + NN_ERR_PRINTF("Loading model error."); + return too_large; + } + + // Use CPU as default + tfl_ctx->models[*g].target = cpu; + return success; +} + wasi_nn_error tensorflowlite_init_execution_context(void *tflite_ctx, graph g, graph_execution_context *ctx) @@ -187,7 +212,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g, tflite_builder(&tfl_ctx->interpreters[*ctx].interpreter); if (tfl_ctx->interpreters[*ctx].interpreter == NULL) { NN_ERR_PRINTF("Error when generating the interpreter."); - return missing_memory; + return too_large; } bool use_default = false; @@ -207,7 +232,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g, if (tfl_ctx->delegate == NULL) { NN_ERR_PRINTF("Error when generating GPU delegate."); use_default = true; - return missing_memory; + return too_large; } if (tfl_ctx->interpreters[*ctx] .interpreter->ModifyGraphWithDelegate(tfl_ctx->delegate) @@ -232,7 +257,7 @@ tensorflowlite_init_execution_context(void *tflite_ctx, graph g, if (tfl_ctx->delegate == NULL) { NN_ERR_PRINTF("Error when generating External delegate."); use_default = true; - return missing_memory; + return too_large; } if (tfl_ctx->interpreters[*ctx] .interpreter->ModifyGraphWithDelegate(tfl_ctx->delegate) @@ -276,7 +301,7 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx, auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index); if (tensor == NULL) { NN_ERR_PRINTF("Missing memory"); - return missing_memory; + return too_large; } uint32_t model_tensor_size = 1; @@ -363,7 +388,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx, auto tensor = tfl_ctx->interpreters[ctx].interpreter->output_tensor(index); if (tensor == NULL) { NN_ERR_PRINTF("Missing memory"); - return missing_memory; + return too_large; } uint32_t model_tensor_size = 1; @@ -372,7 +397,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx, if (*output_tensor_size < model_tensor_size) { NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return missing_memory; + return too_large; } if (tensor->quantization.type == kTfLiteNoQuantization) { @@ -409,13 +434,13 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx, return success; } -void +wasi_nn_error tensorflowlite_initialize(void **tflite_ctx) { TFLiteContext *tfl_ctx = new TFLiteContext(); if (tfl_ctx == NULL) { NN_ERR_PRINTF("Error when allocating memory for tensorflowlite."); - return; + return runtime_error; } NN_DBG_PRINTF("Initializing models."); @@ -433,9 +458,10 @@ tensorflowlite_initialize(void **tflite_ctx) tfl_ctx->delegate = NULL; *tflite_ctx = (void *)tfl_ctx; + return success; } -void +wasi_nn_error tensorflowlite_destroy(void *tflite_ctx) { /* @@ -485,6 +511,7 @@ tensorflowlite_destroy(void *tflite_ctx) os_mutex_destroy(&tfl_ctx->g_lock); delete tfl_ctx; NN_DBG_PRINTF("Memory free'd."); + return success; } __attribute__((constructor(200))) void @@ -492,6 +519,7 @@ tflite_register_backend() { api_function apis = { .load = tensorflowlite_load, + .load_by_name = tensorflowlite_load_by_name, .init_execution_context = tensorflowlite_init_execution_context, .set_input = tensorflowlite_set_input, .compute = tensorflowlite_compute, @@ -499,5 +527,5 @@ tflite_register_backend() .init = tensorflowlite_initialize, .deinit = tensorflowlite_destroy, }; - wasi_nn_register_backend(tensorflowlite, apis); + wasi_nn_register_backend(apis); } \ No newline at end of file diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp index 6eea38be9..630e741c0 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.hpp @@ -32,10 +32,10 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, tensor_data output_tensor, uint32_t *output_tensor_size); -void +wasi_nn_error tensorflowlite_initialize(void **tflite_ctx); -void +wasi_nn_error tensorflowlite_destroy(void *tflite_ctx); #ifdef __cplusplus diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index f19ec0f8e..c499adc5b 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -23,14 +23,14 @@ wasm_load(char *model_name, graph *g, execution_target target) buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); if (buffer == NULL) { fclose(pFile); - return missing_memory; + return too_large; } result = fread(buffer, 1, MAX_MODEL_SIZE, pFile); if (result <= 0) { fclose(pFile); free(buffer); - return missing_memory; + return too_large; } graph_builder_array arr; @@ -40,7 +40,7 @@ wasm_load(char *model_name, graph *g, execution_target target) if (arr.buf == NULL) { fclose(pFile); free(buffer); - return missing_memory; + return too_large; } arr.buf[0].size = result; @@ -54,6 +54,13 @@ wasm_load(char *model_name, graph *g, execution_target target) return res; } +wasi_nn_error +wasm_load_by_name(const char *model_name, graph *g) +{ + wasm_nn_error res = load_by_name(model_name, g); + return res; +} + wasi_nn_error wasm_init_execution_context(graph g, graph_execution_context *ctx) { @@ -67,7 +74,7 @@ wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim) dims.size = INPUT_TENSOR_DIMS; dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); if (dims.buf == NULL) - return missing_memory; + return too_large; tensor tensor; tensor.dimensions = &dims;