diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index ad1f37deb..c8d1217a7 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding, __attribute__((import_module("wasi_nn"))); wasi_nn_error -load_by_name(const char *name, graph *g) +load_by_name(const char *name, uint32_t name_len, graph *g) __attribute__((import_module("wasi_nn"))); /** diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 4697e931b..75f362c76 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = { REG_NATIVE_FUNC(get_output, "(ii*i*)i"), #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ REG_NATIVE_FUNC(load, "(*ii*)i"), + REG_NATIVE_FUNC(load_by_name, "(*i*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(compute, "(i)i"), diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index f63d57e07..09e12f0d2 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g) NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST); return runtime_error; } - if (tfl_ctx->models[g].model_pointer == NULL) { - NN_ERR_PRINTF("Context (model) non-initialized."); - return runtime_error; - } if (tfl_ctx->models[g].model == NULL) { - NN_ERR_PRINTF("Context (tflite model) non-initialized."); + NN_ERR_PRINTF("Context (model) non-initialized."); return runtime_error; } return success; @@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx) NN_DBG_PRINTF("Freeing memory."); for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) { tfl_ctx->models[i].model.reset(); - if (tfl_ctx->models[i].model_pointer) { - if (tfl_ctx->delegate) { - switch (tfl_ctx->models[i].target) { - case gpu: - { + if (tfl_ctx->delegate) { + switch (tfl_ctx->models[i].target) { + case gpu: + { #if WASM_ENABLE_WASI_NN_GPU != 0 - TfLiteGpuDelegateV2Delete(tfl_ctx->delegate); + TfLiteGpuDelegateV2Delete(tfl_ctx->delegate); #else - NN_ERR_PRINTF("GPU delegate delete but not enabled."); + NN_ERR_PRINTF("GPU delegate delete but not enabled."); #endif - break; - } - case tpu: - { -#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0 - TfLiteExternalDelegateDelete(tfl_ctx->delegate); -#else - NN_ERR_PRINTF( - "External delegate delete but not enabled."); -#endif - break; - } - default: - break; + break; } + case tpu: + { +#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0 + TfLiteExternalDelegateDelete(tfl_ctx->delegate); +#else + NN_ERR_PRINTF("External delegate delete but not enabled."); +#endif + break; + } + default: + break; } + } + if (tfl_ctx->models[i].model_pointer) { wasm_runtime_free(tfl_ctx->models[i].model_pointer); } tfl_ctx->models[i].model_pointer = NULL; diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 9e43ec985..690c37f0e 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target) wasi_nn_error wasm_load_by_name(const char *model_name, graph *g) { - wasi_nn_error res = load_by_name(model_name, g); + wasi_nn_error res = load_by_name(model_name, strlen(model_name), g); return res; } @@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size, uint32_t num_output_tensors) { graph graph; - if (wasm_load(model_name, &graph, target) != success) { + + if (wasm_load_by_name(model_name, &graph) != success) { NN_ERR_PRINTF("Error when loading model."); exit(1); }