mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2026-04-18 18:18:44 +00:00
Move the error checks to an earlier stage.
This commit is contained in:
parent
ccee1941c2
commit
5357fb5f21
|
|
@ -1804,8 +1804,8 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args)
|
|||
|
||||
bool
|
||||
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
|
||||
const char **model_names, const char **encoding,
|
||||
const char **target, uint32_t n_graphs,
|
||||
const char **model_names, const uint32_t **encoding,
|
||||
const uint32_t **target, uint32_t n_graphs,
|
||||
const char **graph_paths)
|
||||
{
|
||||
if (!registry || !model_names || !encoding || !target || !graph_paths) {
|
||||
|
|
@ -1832,8 +1832,8 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
|
|||
for (uint32_t i = 0; i < registry->n_graphs; i++) {
|
||||
registry->graph_paths[i] = bh_strdup(graph_paths[i]);
|
||||
registry->model_names[i] = bh_strdup(model_names[i]);
|
||||
registry->encoding[i] = bh_strdup(encoding[i]);
|
||||
registry->target[i] = bh_strdup(target[i]);
|
||||
registry->encoding[i] = encoding[i];
|
||||
registry->target[i] = target[i];
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
@ -1860,13 +1860,13 @@ wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry)
|
|||
wasm_runtime_free(registry->graph_paths[i]);
|
||||
if (registry->model_names[i])
|
||||
wasm_runtime_free(registry->model_names[i]);
|
||||
if (registry->encoding[i])
|
||||
wasm_runtime_free(registry->encoding[i]);
|
||||
if (registry->target[i])
|
||||
wasm_runtime_free(registry->target[i]);
|
||||
}
|
||||
if (registry->loaded)
|
||||
wasm_runtime_free(registry->loaded);
|
||||
if (registry->encoding)
|
||||
wasm_runtime_free(registry->encoding);
|
||||
if (registry->target)
|
||||
wasm_runtime_free(registry->target);
|
||||
if (registry->loaded)
|
||||
wasm_runtime_free(registry->loaded);
|
||||
wasm_runtime_free(registry);
|
||||
}
|
||||
}
|
||||
|
|
@ -1881,16 +1881,13 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
|
|||
|
||||
wasi_nn_registry->n_graphs = registry->n_graphs;
|
||||
|
||||
if (registry->model_names)
|
||||
wasi_nn_registry->model_names = bh_strdup(registry->model_names);
|
||||
if (registry->encoding)
|
||||
wasi_nn_registry->encoding = bh_strdup(registry->encoding);
|
||||
if (registry->target)
|
||||
wasi_nn_registry->target = bh_strdup(registry->target);
|
||||
if (registry->loaded)
|
||||
wasi_nn_registry->loaded = bh_strdup(registry->loaded);
|
||||
if (registry->graph_paths)
|
||||
wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths);
|
||||
for (uint32_t i = 0; i < registry->n_graphs; i++) {
|
||||
registry->graph_paths[i] = bh_strdup(registry->graph_paths[i]);
|
||||
registry->model_names[i] = bh_strdup(registry->model_names[i]);
|
||||
registry->encoding[i] = registry->encoding[i];
|
||||
registry->target[i] = registry->target[i];
|
||||
wasi_nn_registry->loaded = registry->loaded;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -548,11 +548,11 @@ typedef struct WASMModuleInstMemConsumption {
|
|||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
typedef struct WASINNRegistry {
|
||||
char **model_names;
|
||||
char **encoding;
|
||||
char **target;
|
||||
uint32_t **encoding;
|
||||
uint32_t **target;
|
||||
|
||||
uint32_t n_graphs;
|
||||
uint32_t *loaded;
|
||||
uint32_t **loaded;
|
||||
char **graph_paths;
|
||||
} WASINNRegistry;
|
||||
#endif
|
||||
|
|
@ -805,8 +805,8 @@ wasm_runtime_instantiation_args_set_wasi_nn_registry(
|
|||
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
|
||||
const char **model_names, const char **encoding,
|
||||
const char **target, uint32_t n_graphs,
|
||||
const char **model_names, const uint32_t **encoding,
|
||||
const uint32_t **target, uint32_t n_graphs,
|
||||
const char **graph_paths);
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -211,46 +211,6 @@ wasi_nn_destroy()
|
|||
* - model file format
|
||||
* - on device ML framework
|
||||
*/
|
||||
static graph_encoding
|
||||
str2encoding(char *str_encoding)
|
||||
{
|
||||
if (!str_encoding) {
|
||||
NN_ERR_PRINTF("Got empty string encoding");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_encoding, "openvino"))
|
||||
return openvino;
|
||||
else if (!strcmp(str_encoding, "tensorflowlite"))
|
||||
return tensorflowlite;
|
||||
else if (!strcmp(str_encoding, "ggml"))
|
||||
return ggml;
|
||||
else if (!strcmp(str_encoding, "onnx"))
|
||||
return onnx;
|
||||
else
|
||||
return unknown_backend;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static execution_target
|
||||
str2target(char *str_target)
|
||||
{
|
||||
if (!str_target) {
|
||||
NN_ERR_PRINTF("Got empty string target");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_target, "cpu"))
|
||||
return cpu;
|
||||
else if (!strcmp(str_target, "gpu"))
|
||||
return gpu;
|
||||
else if (!strcmp(str_target, "tpu"))
|
||||
return tpu;
|
||||
else
|
||||
return unsupported_target;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static graph_encoding
|
||||
choose_a_backend()
|
||||
{
|
||||
|
|
@ -630,10 +590,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|||
is_loaded = wasi_nn_registry->loaded[model_idx];
|
||||
char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx];
|
||||
|
||||
graph_encoding encoding =
|
||||
str2encoding(wasi_nn_registry->encoding[model_idx]);
|
||||
execution_target target =
|
||||
str2target(wasi_nn_registry->target[model_idx]);
|
||||
graph_encoding encoding = (graph_encoding)(wasi_nn_registry->encoding[model_idx]);
|
||||
execution_target target = (execution_target)(wasi_nn_registry->target[model_idx]);
|
||||
|
||||
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
||||
res = ensure_backend(instance, encoding, wasi_nn_ctx);
|
||||
|
|
|
|||
|
|
@ -21,20 +21,78 @@ typedef struct {
|
|||
uint32 ns_lookup_pool_size;
|
||||
} libc_wasi_parse_context_t;
|
||||
|
||||
typedef struct {
|
||||
const char *model_names[10];
|
||||
const char *encoding[10];
|
||||
const char *target[10];
|
||||
const char *graph_paths[10];
|
||||
uint32 n_graphs;
|
||||
} wasi_nn_parse_context_t;
|
||||
|
||||
typedef enum {
|
||||
LIBC_WASI_PARSE_RESULT_OK = 0,
|
||||
LIBC_WASI_PARSE_RESULT_NEED_HELP,
|
||||
LIBC_WASI_PARSE_RESULT_BAD_PARAM
|
||||
} libc_wasi_parse_result_t;
|
||||
|
||||
typedef struct {
|
||||
const char *model_names[10];
|
||||
const uint32_t *encoding[10];
|
||||
const uint32_t *target[10];
|
||||
const char *graph_paths[10];
|
||||
uint32 n_graphs;
|
||||
} wasi_nn_parse_context_t;
|
||||
|
||||
typedef enum {
|
||||
wasi_nn_openvino = 0,
|
||||
wasi_nn_onnx,
|
||||
wasi_nn_tensorflow,
|
||||
wasi_nn_pytorch,
|
||||
wasi_nn_tensorflowlite,
|
||||
wasi_nn_ggml,
|
||||
wasi_nn_autodetect,
|
||||
wasi_nn_unknown_backend,
|
||||
} wasi_nn_encoding;
|
||||
|
||||
typedef enum wasi_nn_target {
|
||||
wasi_nn_cpu = 0,
|
||||
wasi_nn_gpu,
|
||||
wasi_nn_tpu,
|
||||
wasi_nn_unsupported_target,
|
||||
} wasi_nn_target;
|
||||
|
||||
static wasi_nn_encoding
|
||||
str2encoding(char *str_encoding)
|
||||
{
|
||||
if (!str_encoding) {
|
||||
printf("Got empty string encoding");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_encoding, "openvino"))
|
||||
return wasi_nn_openvino;
|
||||
else if (!strcmp(str_encoding, "tensorflowlite"))
|
||||
return wasi_nn_tensorflowlite;
|
||||
else if (!strcmp(str_encoding, "ggml"))
|
||||
return wasi_nn_ggml;
|
||||
else if (!strcmp(str_encoding, "onnx"))
|
||||
return wasi_nn_onnx;
|
||||
else
|
||||
return wasi_nn_unknown_backend;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static wasi_nn_target
|
||||
str2target(char *str_target)
|
||||
{
|
||||
if (!str_target) {
|
||||
printf("Got empty string target");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_target, "cpu"))
|
||||
return wasi_nn_cpu;
|
||||
else if (!strcmp(str_target, "gpu"))
|
||||
return wasi_nn_gpu;
|
||||
else if (!strcmp(str_target, "tpu"))
|
||||
return wasi_nn_tpu;
|
||||
else
|
||||
return wasi_nn_unsupported_target;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static void
|
||||
libc_wasi_print_help(void)
|
||||
{
|
||||
|
|
@ -223,10 +281,19 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx)
|
|||
}
|
||||
|
||||
ctx->model_names[ctx->n_graphs] = tokens[0];
|
||||
ctx->encoding[ctx->n_graphs] = tokens[1];
|
||||
ctx->target[ctx->n_graphs] = tokens[2];
|
||||
ctx->graph_paths[ctx->n_graphs++] = tokens[3];
|
||||
ctx->encoding[ctx->n_graphs] = (uint32_t)str2encoding(tokens[1]);
|
||||
ctx->target[ctx->n_graphs] = (uint32_t)str2target(tokens[2]);
|
||||
ctx->graph_paths[ctx->n_graphs] = tokens[3];
|
||||
|
||||
if ((!ctx->model_names[ctx->n_graphs]) ||
|
||||
(ctx->encoding[ctx->n_graphs] == wasi_nn_unknown_backend) ||
|
||||
(ctx->target[ctx->n_graphs] == wasi_nn_unsupported_target)) {
|
||||
ret = LIBC_WASI_PARSE_RESULT_NEED_HELP;
|
||||
printf("Invalid arguments for wasi-nn.\n");
|
||||
goto fail;
|
||||
}
|
||||
|
||||
ctx->n_graphs++;
|
||||
fail:
|
||||
|
||||
return ret;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user