Move the error checks to an earlier stage.

This commit is contained in:
zhanheng1 2026-02-06 11:13:12 +08:00
parent ccee1941c2
commit 5357fb5f21
4 changed files with 102 additions and 80 deletions

View File

@ -1804,8 +1804,8 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args)
bool
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
const char **model_names, const char **encoding,
const char **target, uint32_t n_graphs,
const char **model_names, const uint32_t **encoding,
const uint32_t **target, uint32_t n_graphs,
const char **graph_paths)
{
if (!registry || !model_names || !encoding || !target || !graph_paths) {
@ -1832,8 +1832,8 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
for (uint32_t i = 0; i < registry->n_graphs; i++) {
registry->graph_paths[i] = bh_strdup(graph_paths[i]);
registry->model_names[i] = bh_strdup(model_names[i]);
registry->encoding[i] = bh_strdup(encoding[i]);
registry->target[i] = bh_strdup(target[i]);
registry->encoding[i] = encoding[i];
registry->target[i] = target[i];
}
return true;
@ -1860,13 +1860,13 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry)
wasm_runtime_free(registry->graph_paths[i]);
if (registry->model_names[i])
wasm_runtime_free(registry->model_names[i]);
if (registry->encoding[i])
wasm_runtime_free(registry->encoding[i]);
if (registry->target[i])
wasm_runtime_free(registry->target[i]);
}
if (registry->loaded)
wasm_runtime_free(registry->loaded);
if (registry->encoding)
wasm_runtime_free(registry->encoding);
if (registry->target)
wasm_runtime_free(registry->target);
if (registry->loaded)
wasm_runtime_free(registry->loaded);
wasm_runtime_free(registry);
}
}
@ -1881,16 +1881,13 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
wasi_nn_registry->n_graphs = registry->n_graphs;
if (registry->model_names)
wasi_nn_registry->model_names = bh_strdup(registry->model_names);
if (registry->encoding)
wasi_nn_registry->encoding = bh_strdup(registry->encoding);
if (registry->target)
wasi_nn_registry->target = bh_strdup(registry->target);
if (registry->loaded)
wasi_nn_registry->loaded = bh_strdup(registry->loaded);
if (registry->graph_paths)
wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths);
for (uint32_t i = 0; i < registry->n_graphs; i++) {
registry->graph_paths[i] = bh_strdup(registry->graph_paths[i]);
registry->model_names[i] = bh_strdup(registry->model_names[i]);
registry->encoding[i] = registry->encoding[i];
registry->target[i] = registry->target[i];
wasi_nn_registry->loaded = registry->loaded;
}
}
#endif

View File

@ -548,11 +548,11 @@ typedef struct WASMModuleInstMemConsumption {
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct WASINNRegistry {
char **model_names;
char **encoding;
char **target;
uint32_t **encoding;
uint32_t **target;
uint32_t n_graphs;
uint32_t *loaded;
uint32_t **loaded;
char **graph_paths;
} WASINNRegistry;
#endif
@ -805,8 +805,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
WASM_RUNTIME_API_EXTERN bool
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
const char **model_names, const char **encoding,
const char **target, uint32_t n_graphs,
const char **model_names, const uint32_t **encoding,
const uint32_t **target, uint32_t n_graphs,
const char **graph_paths);
#endif

View File

@ -211,46 +211,6 @@ wasi_nn_destroy()
* - model file format
* - on device ML framework
*/
static graph_encoding
str2encoding(char *str_encoding)
{
if (!str_encoding) {
NN_ERR_PRINTF("Got empty string encoding");
return -1;
}
if (!strcmp(str_encoding, "openvino"))
return openvino;
else if (!strcmp(str_encoding, "tensorflowlite"))
return tensorflowlite;
else if (!strcmp(str_encoding, "ggml"))
return ggml;
else if (!strcmp(str_encoding, "onnx"))
return onnx;
else
return unknown_backend;
// return autodetect;
}
static execution_target
str2target(char *str_target)
{
if (!str_target) {
NN_ERR_PRINTF("Got empty string target");
return -1;
}
if (!strcmp(str_target, "cpu"))
return cpu;
else if (!strcmp(str_target, "gpu"))
return gpu;
else if (!strcmp(str_target, "tpu"))
return tpu;
else
return unsupported_target;
// return autodetect;
}
static graph_encoding
choose_a_backend()
{
@ -630,10 +590,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
is_loaded = wasi_nn_registry->loaded[model_idx];
char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx];
graph_encoding encoding =
str2encoding(wasi_nn_registry->encoding[model_idx]);
execution_target target =
str2target(wasi_nn_registry->target[model_idx]);
graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]);
execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]);
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
res = ensure_backend(instance, encoding, wasi_nn_ctx);

View File

@ -21,20 +21,78 @@ typedef struct {
uint32 ns_lookup_pool_size;
} libc_wasi_parse_context_t;
typedef struct {
const char *model_names[10];
const char *encoding[10];
const char *target[10];
const char *graph_paths[10];
uint32 n_graphs;
} wasi_nn_parse_context_t;
typedef enum {
LIBC_WASI_PARSE_RESULT_OK = 0,
LIBC_WASI_PARSE_RESULT_NEED_HELP,
LIBC_WASI_PARSE_RESULT_BAD_PARAM
} libc_wasi_parse_result_t;
typedef struct {
const char *model_names[10];
const uint32_t *encoding[10];
const uint32_t *target[10];
const char *graph_paths[10];
uint32 n_graphs;
} wasi_nn_parse_context_t;
typedef enum {
wasi_nn_openvino = 0,
wasi_nn_onnx,
wasi_nn_tensorflow,
wasi_nn_pytorch,
wasi_nn_tensorflowlite,
wasi_nn_ggml,
wasi_nn_autodetect,
wasi_nn_unknown_backend,
} wasi_nn_encoding;
typedef enum wasi_nn_target {
wasi_nn_cpu = 0,
wasi_nn_gpu,
wasi_nn_tpu,
wasi_nn_unsupported_target,
} wasi_nn_target;
static wasi_nn_encoding
str2encoding(char *str_encoding)
{
if (!str_encoding) {
printf("Got empty string encoding");
return -1;
}
if (!strcmp(str_encoding, "openvino"))
return wasi_nn_openvino;
else if (!strcmp(str_encoding, "tensorflowlite"))
return wasi_nn_tensorflowlite;
else if (!strcmp(str_encoding, "ggml"))
return wasi_nn_ggml;
else if (!strcmp(str_encoding, "onnx"))
return wasi_nn_onnx;
else
return wasi_nn_unknown_backend;
// return autodetect;
}
static wasi_nn_target
str2target(char *str_target)
{
if (!str_target) {
printf("Got empty string target");
return -1;
}
if (!strcmp(str_target, "cpu"))
return wasi_nn_cpu;
else if (!strcmp(str_target, "gpu"))
return wasi_nn_gpu;
else if (!strcmp(str_target, "tpu"))
return wasi_nn_tpu;
else
return wasi_nn_unsupported_target;
// return autodetect;
}
static void
libc_wasi_print_help(void)
{
@ -223,10 +281,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx)
}
ctx->model_names[ctx->n_graphs] = tokens[0];
ctx->encoding[ctx->n_graphs] = tokens[1];
ctx->target[ctx->n_graphs] = tokens[2];
ctx->graph_paths[ctx->n_graphs++] = tokens[3];
ctx->encoding[ctx->n_graphs] = (uint32_t)str2encoding(tokens[1]);
ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]);
ctx->graph_paths[ctx->n_graphs] = tokens[3];
if ((!ctx->model_names[ctx->n_graphs]) ||
(ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) ||
(ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) {
ret = LIBC_WASI_PARSE_RESULT_NEED_HELP;
printf("Invalid arguments for wasi-nn.\n");
goto fail;
}
ctx->n_graphs++;
fail:
return ret;