From 9e89828a3437227db317960ee31ce34cffa58037 Mon Sep 17 00:00:00 2001 From: zhanheng1 Date: Fri, 6 Feb 2026 14:48:36 +0800 Subject: [PATCH] Make wasi_nn_load_by_name and wasi_nn_load_by_name_with_config share a common logic. --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 152 ++++++++++++--------- 1 file changed, 88 insertions(+), 64 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 8effa8fd3..4feff102b 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -535,6 +535,89 @@ copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len, return success; } +static wasi_nn_error +load_by_name_with_optional_config(WASINNContext *wasi_nn_ctx, + wasm_module_inst_t instance, bool use_config, + graph *g, const char *model_name, + const char *config, int32_t config_len) +{ + wasi_nn_error res = success; + + WASINNRegistry *wasi_nn_registry = + wasm_runtime_get_wasi_nn_registry(instance); + if (!wasi_nn_registry) { + NN_ERR_PRINTF("global context is invalid"); + res = not_found; + goto fail; + } + + bool is_loaded = false; + uint32 model_idx = 0; + uint32_t global_n_graphs = wasi_nn_registry->n_graphs; + for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { + char *model_name_i = wasi_nn_registry->model_names[model_idx]; + + if (strcmp(model_name, model_name_i) != 0) { + continue; + } + + is_loaded = wasi_nn_registry->loaded[model_idx]; + char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; + + graph_encoding encoding = + (graph_encoding)(wasi_nn_registry->encoding[model_idx]); + execution_target target = + (execution_target)(wasi_nn_registry->target[model_idx]); + + res = ensure_backend(instance, encoding, wasi_nn_ctx); + if (res != success) + goto fail; + + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) + && (model_idx < global_n_graphs)) { + NN_DBG_PRINTF( + "Model is not yet loaded, will add to global context"); + if (use_config && config && config_len > 0) { + call_wasi_nn_func( + wasi_nn_ctx->backend, load_by_name_with_config, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), config, config_len, g); + } + else { + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, + strlen(global_model_path_i), target, g); + } + if (res != success) + goto fail; + + wasi_nn_registry->loaded[model_idx] = 1; + res = success; + break; + } + } + + if (is_loaded) { + NN_DBG_PRINTF("Model is already loaded"); + res = success; + } + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { + // No enlarge for now + NN_ERR_PRINTF("No enough space for new model"); + res = too_large; + } + else if (model_idx >= global_n_graphs) { + NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " + "through --wasi-nn-graph", + model_name); + res = not_found; + } + +fail: + + return res; +} + wasi_nn_error wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, graph *g) @@ -568,68 +651,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - WASINNRegistry *wasi_nn_registry = - wasm_runtime_get_wasi_nn_registry(instance); - if (!wasi_nn_registry) { - NN_ERR_PRINTF("global context is invalid"); - res = not_found; - goto fail; - } + res = load_by_name_with_optional_config(wasi_nn_ctx, instance, false, g, + nul_terminated_name, NULL, 0); - bool is_loaded = false; - uint32 model_idx = 0; - uint32_t global_n_graphs = wasi_nn_registry->n_graphs; - for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - char *model_name = wasi_nn_registry->model_names[model_idx]; - - if (strcmp(nul_terminated_name, model_name) != 0) { - continue; - } - - is_loaded = wasi_nn_registry->loaded[model_idx]; - char *global_model_path_i = wasi_nn_registry->graph_paths[model_idx]; - - graph_encoding encoding = - (graph_encoding)(wasi_nn_registry->encoding[model_idx]); - execution_target target = - (execution_target)(wasi_nn_registry->target[model_idx]); - - // res = ensure_backend(instance, autodetect, wasi_nn_ctx); - res = ensure_backend(instance, encoding, wasi_nn_ctx); - if (res != success) - goto fail; - - if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST) - && (model_idx < global_n_graphs)) { - NN_DBG_PRINTF( - "Model is not yet loaded, will add to global context"); - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, global_model_path_i, - strlen(global_model_path_i), target, g); - if (res != success) - goto fail; - - wasi_nn_registry->loaded[model_idx] = 1; - res = success; - break; - } - } - - if (is_loaded) { - NN_DBG_PRINTF("Model is already loaded"); - res = success; - } - else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) { - // No enlarge for now - NN_ERR_PRINTF("No enough space for new model"); - res = too_large; - } - else if (model_idx >= global_n_graphs) { - NN_ERR_PRINTF("Model %s is not loaded, you should pass its path " - "through --wasi-nn-graph", - nul_terminated_name); - res = not_found; - } fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name); @@ -686,9 +710,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, goto fail; ; - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - nul_terminated_config, config_len, g); + res = load_by_name_with_optional_config(wasi_nn_ctx, instance, true, g, + nul_terminated_name, + nul_terminated_config, config_len); if (res != success) goto fail;