wasi-nn: fix backend leak on multiple loads (#4366)

cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4340
This commit is contained in:
YAMAMOTO Takashi 2025-06-17 12:01:07 +09:00 committed by GitHub
parent 8e60feb181
commit 0d001c4c38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 40 deletions

View File

@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
return ret;
}
static wasi_nn_error
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
WASINNContext **wasi_nn_ctx_ptr)
{
wasi_nn_error res;
graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(encoding, &loaded_backend)) {
res = invalid_encoding;
NN_ERR_PRINTF("load backend failed");
goto fail;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
if (wasi_nn_ctx->is_backend_ctx_initialized) {
if (wasi_nn_ctx->backend != loaded_backend) {
res = unsupported_operation;
goto fail;
}
}
else {
wasi_nn_ctx->backend = loaded_backend;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
if (res != success)
goto fail;
wasi_nn_ctx->is_backend_ctx_initialized = true;
}
*wasi_nn_ctx_ptr = wasi_nn_ctx;
return success;
fail:
return res;
}
/* WASI-NN implementation */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@ -410,6 +447,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
graph_encoding encoding, execution_target target, graph *g)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{
wasi_nn_error res;
NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
target);
@ -417,7 +456,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
if (!instance)
return runtime_error;
wasi_nn_error res;
graph_builder_array builder_native = { 0 };
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (success
@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
goto fail;
}
graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(encoding, &loaded_backend)) {
res = invalid_encoding;
NN_ERR_PRINTF("load backend failed");
goto fail;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
if (res != success)
goto fail;
@ -473,6 +500,8 @@ wasi_nn_error
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
graph *g)
{
wasi_nn_error res;
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
if (!instance) {
return runtime_error;
@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
NN_ERR_PRINTF("load backend failed");
return invalid_encoding;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;
wasi_nn_error res;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
if (res != success)
return res;
@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
int32_t name_len, char *config,
int32_t config_len, graph *g)
{
wasi_nn_error res;
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
if (!instance) {
return runtime_error;
@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
graph_encoding loaded_backend = autodetect;
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
NN_ERR_PRINTF("load backend failed");
return invalid_encoding;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_ctx->backend = loaded_backend;
wasi_nn_error res;
/* init() the backend */
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
&wasi_nn_ctx->backend_ctx);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
if (res != success)
return res;

View File

@ -10,6 +10,7 @@
#include "wasm_export.h"
typedef struct {
bool is_backend_ctx_initialized;
bool is_model_loaded;
graph_encoding backend;
void *backend_ctx;