Make wasi_nn_load_by_name and wasi_nn_load_by_name_with_config share a common logic.

This commit is contained in:
zhanheng1 2026-02-06 14:48:36 +08:00
parent 4747d61912
commit 9e89828a34

View File

@ -535,6 +535,89 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len,
return success;
}
static wasi_nn_error
load_by_name_with_optional_config(WASINNContext *wasi_nn_ctx,
wasm_module_inst_t instance, bool use_config,
graph *g, const char *model_name,
const char *config, int32_t config_len)
{
wasi_nn_error res = success;
WASINNRegistry *wasi_nn_registry =
wasm_runtime_get_wasi_nn_registry(instance);
if (!wasi_nn_registry) {
NN_ERR_PRINTF("global context is invalid");
res = not_found;
goto fail;
}
bool is_loaded = false;
uint32 model_idx = 0;
uint32_t global_n_graphs = wasi_nn_registry->n_graphs;
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
char *model_name_i = wasi_nn_registry->model_names[model_idx];
if (strcmp(model_name, model_name_i) != 0) {
continue;
}
is_loaded = wasi_nn_registry->loaded[model_idx];
char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx];
graph_encoding encoding =
(graph_encoding)(wasi_nn_registry->encoding[model_idx]);
execution_target target =
(execution_target)(wasi_nn_registry->target[model_idx]);
res = ensure_backend(instance, encoding, wasi_nn_ctx);
if (res != success)
goto fail;
if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)
&& (model_idx < global_n_graphs)) {
NN_DBG_PRINTF(
"Model is not yet loaded, will add to global context");
if (use_config && config && config_len > 0) {
call_wasi_nn_func(
wasi_nn_ctx->backend, load_by_name_with_config, res,
wasi_nn_ctx->backend_ctx, global_model_path_i,
strlen(global_model_path_i), config, config_len, g);
}
else {
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
wasi_nn_ctx->backend_ctx, global_model_path_i,
strlen(global_model_path_i), target, g);
}
if (res != success)
goto fail;
wasi_nn_registry->loaded[model_idx] = 1;
res = success;
break;
}
}
if (is_loaded) {
NN_DBG_PRINTF("Model is already loaded");
res = success;
}
else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) {
// No enlarge for now
NN_ERR_PRINTF("No enough space for new model");
res = too_large;
}
else if (model_idx >= global_n_graphs) {
NN_ERR_PRINTF("Model %s is not loaded, you should pass its path "
"through --wasi-nn-graph",
model_name);
res = not_found;
}
fail:
return res;
}
wasi_nn_error
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
graph *g)
@ -568,68 +651,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
goto fail;
}
WASINNRegistry *wasi_nn_registry =
wasm_runtime_get_wasi_nn_registry(instance);
if (!wasi_nn_registry) {
NN_ERR_PRINTF("global context is invalid");
res = not_found;
goto fail;
}
res = load_by_name_with_optional_config(wasi_nn_ctx, instance, false, g,
nul_terminated_name, NULL, 0);
bool is_loaded = false;
uint32 model_idx = 0;
uint32_t global_n_graphs = wasi_nn_registry->n_graphs;
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
char *model_name = wasi_nn_registry->model_names[model_idx];
if (strcmp(nul_terminated_name, model_name) != 0) {
continue;
}
is_loaded = wasi_nn_registry->loaded[model_idx];
char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx];
graph_encoding encoding =
(graph_encoding)(wasi_nn_registry->encoding[model_idx]);
execution_target target =
(execution_target)(wasi_nn_registry->target[model_idx]);
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
res = ensure_backend(instance, encoding, wasi_nn_ctx);
if (res != success)
goto fail;
if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)
&& (model_idx < global_n_graphs)) {
NN_DBG_PRINTF(
"Model is not yet loaded, will add to global context");
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
wasi_nn_ctx->backend_ctx, global_model_path_i,
strlen(global_model_path_i), target, g);
if (res != success)
goto fail;
wasi_nn_registry->loaded[model_idx] = 1;
res = success;
break;
}
}
if (is_loaded) {
NN_DBG_PRINTF("Model is already loaded");
res = success;
}
else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) {
// No enlarge for now
NN_ERR_PRINTF("No enough space for new model");
res = too_large;
}
else if (model_idx >= global_n_graphs) {
NN_ERR_PRINTF("Model %s is not loaded, you should pass its path "
"through --wasi-nn-graph",
nul_terminated_name);
res = not_found;
}
fail:
if (nul_terminated_name != NULL) {
wasm_runtime_free(nul_terminated_name);
@ -686,9 +710,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
goto fail;
;
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
nul_terminated_config, config_len, g);
res = load_by_name_with_optional_config(wasi_nn_ctx, instance, true, g,
nul_terminated_name,
nul_terminated_config, config_len);
if (res != success)
goto fail;