diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 2ba4a5778..2b49c052e 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -486,7 +486,7 @@ wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm) void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, - WASINNRegistry *wasi_nn_ctx) + WASINNRegistry *wasi_nn_ctx) { wasm_native_set_context(module_inst_comm, g_wasi_nn_registry_key, wasi_nn_ctx); diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index ccd569cdc..61b8e7796 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1804,9 +1804,11 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const uint32_t **encoding, - const uint32_t **target, uint32_t n_graphs, - const char **graph_paths) + const char **model_names, + const uint32_t **encoding, + const uint32_t **target, + uint32_t n_graphs, + const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { return false; @@ -1818,11 +1820,16 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, } registry->n_graphs = n_graphs; - registry->target = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->encoding = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->loaded = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->model_names = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->graph_paths = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->target = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->encoding = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->loaded = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->model_names = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->graph_paths = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs); @@ -1858,15 +1865,15 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry) for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { wasm_runtime_free(registry->graph_paths[i]); - if (registry->model_names[i]) - wasm_runtime_free(registry->model_names[i]); + if (registry->model_names[i]) + wasm_runtime_free(registry->model_names[i]); } - if (registry->encoding) - wasm_runtime_free(registry->encoding); - if (registry->target) - wasm_runtime_free(registry->target); - if (registry->loaded) - wasm_runtime_free(registry->loaded); + if (registry->encoding) + wasm_runtime_free(registry->encoding); + if (registry->target) + wasm_runtime_free(registry->target); + if (registry->loaded) + wasm_runtime_free(registry->loaded); wasm_runtime_free(registry); } } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 23aa45126..bc6cbbbf1 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -805,9 +805,11 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( WASM_RUNTIME_API_EXTERN bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const uint32_t **encoding, - const uint32_t **target, uint32_t n_graphs, - const char **graph_paths); + const char **model_names, + const uint32_t **encoding, + const uint32_t **target, + uint32_t n_graphs, + const char **graph_paths); #endif /* See wasm_export.h for description */ @@ -1465,14 +1467,14 @@ wasm_runtime_check_and_update_last_used_shared_heap( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst, - WASINNRegistry *wasi_ctx); + WASINNRegistry *wasi_ctx); WASM_RUNTIME_API_EXTERN WASINNRegistry * wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, - WASINNRegistry *wasi_nn_ctx); + WASINNRegistry *wasi_nn_ctx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 8f4a61cdc..885e160c0 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -802,7 +802,7 @@ wasm_runtime_get_wasi_nn_registry(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN void wasm_runtime_set_wasi_nn_registry(wasm_module_inst_t module_inst, - struct WASINNRegistry *wasi_ctx); + struct WASINNRegistry *wasi_ctx); WASM_RUNTIME_API_EXTERN int wasm_runtime_wasi_nn_registry_create(struct WASINNRegistry **registryp); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index d0bfe0f2e..8effa8fd3 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -578,8 +578,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; - uint32_t global_n_graphs = - wasi_nn_registry->n_graphs; + uint32_t global_n_graphs = wasi_nn_registry->n_graphs; for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { char *model_name = wasi_nn_registry->model_names[model_idx]; @@ -590,8 +589,10 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, is_loaded = wasi_nn_registry->loaded[model_idx]; char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; - graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]); - execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]); + graph_encoding encoding = + (graph_encoding)(wasi_nn_registry->encoding[model_idx]); + execution_target target = + (execution_target)(wasi_nn_registry->target[model_idx]); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 6a75c3309..6e2e5a4a9 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -18,7 +18,7 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, __attribute__((visibility("default"))) wasi_nn_error load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, - execution_target target, graph *g); + execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 64247e8d3..5bff60d6e 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -151,7 +151,7 @@ run_inference(float *input, uint32_t *input_size, uint32_t *output_size, WASI_NN_NAME(graph) graph; wasi_nn_error_t res = wasm_load_by_name(model_name, &graph); - + if (res == WASI_NN_ERROR_NAME(not_found)) { NN_INFO_PRINTF("Model %s is not loaded, you should pass its path " "through --wasi-nn-graph", diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index ac3acd347..ff14b209f 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -35,27 +35,27 @@ typedef wasi_nn_error wasi_nn_error_t; /* wasi-nn wrappers */ wasi_nn_error_t -wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); +wasm_load(char *model_name, WASI_NN_NAME(graph) * g, + WASI_NN_NAME(execution_target) target); + +wasi_nn_error_t wasm_init_execution_context( + WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx); wasi_nn_error_t -wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, + uint32_t *dim); + +wasi_nn_error_t wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); wasi_nn_error_t -wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); - -wasi_nn_error_t -wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); - -wasi_nn_error_t -wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, - uint32_t *out_size); +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, + float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(float *input, uint32_t *input_size, - uint32_t *output_size, char *model_name, - uint32_t num_output_tensors); +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, + char *model_name, uint32_t num_output_tensors); input_info create_input(int *dims); diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index b7fa53bbc..137521fbd 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -285,13 +285,13 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]); ctx->graph_paths[ctx->n_graphs] = tokens[3]; - if ((!ctx->model_names[ctx->n_graphs]) || - (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) || - (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { - ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; - printf("Invalid arguments for wasi-nn.\n"); - goto fail; - } + if ((!ctx->model_names[ctx->n_graphs]) + || (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) + || (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { + ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + printf("Invalid arguments for wasi-nn.\n"); + goto fail; + } ctx->n_graphs++; fail: @@ -305,8 +305,8 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args, wasi_nn_parse_context_t *ctx) { wasm_runtime_wasi_nn_registry_set_args(nn_registry, ctx->model_names, - ctx->encoding, ctx->target, ctx->n_graphs, - ctx->graph_paths); + ctx->encoding, ctx->target, + ctx->n_graphs, ctx->graph_paths); wasm_runtime_instantiation_args_set_wasi_nn_registry(args, nn_registry); } #endif \ No newline at end of file