diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 1a8ad03c6..76cdf1b83 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint, return ret; } +static wasi_nn_error +ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, + WASINNContext **wasi_nn_ctx_ptr) +{ + wasi_nn_error res; + + graph_encoding loaded_backend = autodetect; + if (!detect_and_load_backend(encoding, &loaded_backend)) { + res = invalid_encoding; + NN_ERR_PRINTF("load backend failed"); + goto fail; + } + + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); + if (wasi_nn_ctx->is_backend_ctx_initialized) { + if (wasi_nn_ctx->backend != loaded_backend) { + res = unsupported_operation; + goto fail; + } + } + else { + wasi_nn_ctx->backend = loaded_backend; + + /* init() the backend */ + call_wasi_nn_func(wasi_nn_ctx->backend, init, res, + &wasi_nn_ctx->backend_ctx); + if (res != success) + goto fail; + + wasi_nn_ctx->is_backend_ctx_initialized = true; + } + *wasi_nn_ctx_ptr = wasi_nn_ctx; + return success; +fail: + return res; +} + /* WASI-NN implementation */ #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 @@ -410,6 +447,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, graph_encoding encoding, execution_target target, graph *g) #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ { + wasi_nn_error res; + NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding, target); @@ -417,7 +456,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, if (!instance) return runtime_error; - wasi_nn_error res; graph_builder_array builder_native = { 0 }; #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 if (success @@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, goto fail; } - graph_encoding loaded_backend = autodetect; - if (!detect_and_load_backend(encoding, &loaded_backend)) { - res = invalid_encoding; - NN_ERR_PRINTF("load backend failed"); - goto fail; - } - - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_ctx->backend = loaded_backend; - - /* init() the backend */ - call_wasi_nn_func(wasi_nn_ctx->backend, init, res, - &wasi_nn_ctx->backend_ctx); + WASINNContext *wasi_nn_ctx; + res = ensure_backend(instance, encoding, &wasi_nn_ctx); if (res != success) goto fail; @@ -473,6 +500,8 @@ wasi_nn_error wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, graph *g) { + wasi_nn_error res; + wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); if (!instance) { return runtime_error; @@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name); - graph_encoding loaded_backend = autodetect; - if (!detect_and_load_backend(autodetect, &loaded_backend)) { - NN_ERR_PRINTF("load backend failed"); - return invalid_encoding; - } - - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_ctx->backend = loaded_backend; - - wasi_nn_error res; - /* init() the backend */ - call_wasi_nn_func(wasi_nn_ctx->backend, init, res, - &wasi_nn_ctx->backend_ctx); + WASINNContext *wasi_nn_ctx; + res = ensure_backend(instance, autodetect, &wasi_nn_ctx); if (res != success) return res; @@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, int32_t name_len, char *config, int32_t config_len, graph *g) { + wasi_nn_error res; + wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); if (!instance) { return runtime_error; @@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config); - graph_encoding loaded_backend = autodetect; - if (!detect_and_load_backend(autodetect, &loaded_backend)) { - NN_ERR_PRINTF("load backend failed"); - return invalid_encoding; - } - - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_ctx->backend = loaded_backend; - - wasi_nn_error res; - /* init() the backend */ - call_wasi_nn_func(wasi_nn_ctx->backend, init, res, - &wasi_nn_ctx->backend_ctx); + WASINNContext *wasi_nn_ctx; + res = ensure_backend(instance, autodetect, &wasi_nn_ctx); if (res != success) return res; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index bb56f72fb..fcca31023 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -10,6 +10,7 @@ #include "wasm_export.h" typedef struct { + bool is_backend_ctx_initialized; bool is_model_loaded; graph_encoding backend; void *backend_ctx;