From 5357fb5f21dc957f43c4f44aef48b96a2a37a5f9 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 11:13:12 +0800 Subject: [PATCH] Move the error checks to an earlier stage. --- core/iwasm/common/wasm_runtime_common.c | 37 +++++---- core/iwasm/common/wasm_runtime_common.h | 10 +-- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 46 +---------- product-mini/platforms/common/libc_wasi.c | 89 +++++++++++++++++++--- 4 files changed, 102 insertions(+), 80 deletions(-) diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index d99467e7f..ccd569cdc 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1804,8 +1804,8 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, + const char **model_names, const uint32_t **encoding, + const uint32_t **target, uint32_t n_graphs, const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { @@ -1832,8 +1832,8 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, for (uint32_t i = 0; i < registry->n_graphs; i++) { registry->graph_paths[i] = bh_strdup(graph_paths[i]); registry->model_names[i] = bh_strdup(model_names[i]); - registry->encoding[i] = bh_strdup(encoding[i]); - registry->target[i] = bh_strdup(target[i]); + registry->encoding[i] = encoding[i]; + registry->target[i] = target[i]; } return true; @@ -1860,13 +1860,13 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry) wasm_runtime_free(registry->graph_paths[i]); if (registry->model_names[i]) wasm_runtime_free(registry->model_names[i]); - if (registry->encoding[i]) - wasm_runtime_free(registry->encoding[i]); - if (registry->target[i]) - wasm_runtime_free(registry->target[i]); } - if (registry->loaded) - wasm_runtime_free(registry->loaded); + if (registry->encoding) + wasm_runtime_free(registry->encoding); + if (registry->target) + wasm_runtime_free(registry->target); + if (registry->loaded) + wasm_runtime_free(registry->loaded); wasm_runtime_free(registry); } } @@ -1881,16 +1881,13 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( wasi_nn_registry->n_graphs = registry->n_graphs; - if (registry->model_names) - wasi_nn_registry->model_names = bh_strdup(registry->model_names); - if (registry->encoding) - wasi_nn_registry->encoding = bh_strdup(registry->encoding); - if (registry->target) - wasi_nn_registry->target = bh_strdup(registry->target); - if (registry->loaded) - wasi_nn_registry->loaded = bh_strdup(registry->loaded); - if (registry->graph_paths) - wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths); + for (uint32_t i = 0; i < registry->n_graphs; i++) { + registry->graph_paths[i] = bh_strdup(registry->graph_paths[i]); + registry->model_names[i] = bh_strdup(registry->model_names[i]); + registry->encoding[i] = registry->encoding[i]; + registry->target[i] = registry->target[i]; + wasi_nn_registry->loaded = registry->loaded; + } } #endif diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 80b0ea05f..23aa45126 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -548,11 +548,11 @@ typedef struct WASMModuleInstMemConsumption { #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 typedef struct WASINNRegistry { char **model_names; - char **encoding; - char **target; + uint32_t **encoding; + uint32_t **target; uint32_t n_graphs; - uint32_t *loaded; + uint32_t **loaded; char **graph_paths; } WASINNRegistry; #endif @@ -805,8 +805,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry( WASM_RUNTIME_API_EXTERN bool wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, + const char **model_names, const uint32_t **encoding, + const uint32_t **target, uint32_t n_graphs, const char **graph_paths); #endif diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 79465a8da..d0bfe0f2e 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -211,46 +211,6 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ -static graph_encoding -str2encoding(char *str_encoding) -{ - if (!str_encoding) { - NN_ERR_PRINTF("Got empty string encoding"); - return -1; - } - - if (!strcmp(str_encoding, "openvino")) - return openvino; - else if (!strcmp(str_encoding, "tensorflowlite")) - return tensorflowlite; - else if (!strcmp(str_encoding, "ggml")) - return ggml; - else if (!strcmp(str_encoding, "onnx")) - return onnx; - else - return unknown_backend; - // return autodetect; -} - -static execution_target -str2target(char *str_target) -{ - if (!str_target) { - NN_ERR_PRINTF("Got empty string target"); - return -1; - } - - if (!strcmp(str_target, "cpu")) - return cpu; - else if (!strcmp(str_target, "gpu")) - return gpu; - else if (!strcmp(str_target, "tpu")) - return tpu; - else - return unsupported_target; - // return autodetect; -} - static graph_encoding choose_a_backend() { @@ -630,10 +590,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, is_loaded = wasi_nn_registry->loaded[model_idx]; char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; - graph_encoding encoding = - str2encoding(wasi_nn_registry->encoding[model_idx]); - execution_target target = - str2target(wasi_nn_registry->target[model_idx]); + graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]); + execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index 3559bd1f6..b7fa53bbc 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -21,20 +21,78 @@ typedef struct { uint32 ns_lookup_pool_size; } libc_wasi_parse_context_t; -typedef struct { - const char *model_names[10]; - const char *encoding[10]; - const char *target[10]; - const char *graph_paths[10]; - uint32 n_graphs; -} wasi_nn_parse_context_t; - typedef enum { LIBC_WASI_PARSE_RESULT_OK = 0, LIBC_WASI_PARSE_RESULT_NEED_HELP, LIBC_WASI_PARSE_RESULT_BAD_PARAM } libc_wasi_parse_result_t; +typedef struct { + const char *model_names[10]; + const uint32_t *encoding[10]; + const uint32_t *target[10]; + const char *graph_paths[10]; + uint32 n_graphs; +} wasi_nn_parse_context_t; + +typedef enum { + wasi_nn_openvino = 0, + wasi_nn_onnx, + wasi_nn_tensorflow, + wasi_nn_pytorch, + wasi_nn_tensorflowlite, + wasi_nn_ggml, + wasi_nn_autodetect, + wasi_nn_unknown_backend, +} wasi_nn_encoding; + +typedef enum wasi_nn_target { + wasi_nn_cpu = 0, + wasi_nn_gpu, + wasi_nn_tpu, + wasi_nn_unsupported_target, +} wasi_nn_target; + +static wasi_nn_encoding +str2encoding(char *str_encoding) +{ + if (!str_encoding) { + printf("Got empty string encoding"); + return -1; + } + + if (!strcmp(str_encoding, "openvino")) + return wasi_nn_openvino; + else if (!strcmp(str_encoding, "tensorflowlite")) + return wasi_nn_tensorflowlite; + else if (!strcmp(str_encoding, "ggml")) + return wasi_nn_ggml; + else if (!strcmp(str_encoding, "onnx")) + return wasi_nn_onnx; + else + return wasi_nn_unknown_backend; + // return autodetect; +} + +static wasi_nn_target +str2target(char *str_target) +{ + if (!str_target) { + printf("Got empty string target"); + return -1; + } + + if (!strcmp(str_target, "cpu")) + return wasi_nn_cpu; + else if (!strcmp(str_target, "gpu")) + return wasi_nn_gpu; + else if (!strcmp(str_target, "tpu")) + return wasi_nn_tpu; + else + return wasi_nn_unsupported_target; + // return autodetect; +} + static void libc_wasi_print_help(void) { @@ -223,10 +281,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx) } ctx->model_names[ctx->n_graphs] = tokens[0]; - ctx->encoding[ctx->n_graphs] = tokens[1]; - ctx->target[ctx->n_graphs] = tokens[2]; - ctx->graph_paths[ctx->n_graphs++] = tokens[3]; + ctx->encoding[ctx->n_graphs] = (uint32_t)str2encoding(tokens[1]); + ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]); + ctx->graph_paths[ctx->n_graphs] = tokens[3]; + if ((!ctx->model_names[ctx->n_graphs]) || + (ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) || + (ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) { + ret = LIBC_WASI_PARSE_RESULT_NEED_HELP; + printf("Invalid arguments for wasi-nn.\n"); + goto fail; + } + + ctx->n_graphs++; fail: return ret;