mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-06-18 02:59:21 +00:00
wasi-nn: fix backend leak on multiple loads (#4366)
cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4340
This commit is contained in:
parent
8e60feb181
commit
0d001c4c38
|
@ -397,6 +397,43 @@ detect_and_load_backend(graph_encoding backend_hint,
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static wasi_nn_error
|
||||||
|
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
||||||
|
WASINNContext **wasi_nn_ctx_ptr)
|
||||||
|
{
|
||||||
|
wasi_nn_error res;
|
||||||
|
|
||||||
|
graph_encoding loaded_backend = autodetect;
|
||||||
|
if (!detect_and_load_backend(encoding, &loaded_backend)) {
|
||||||
|
res = invalid_encoding;
|
||||||
|
NN_ERR_PRINTF("load backend failed");
|
||||||
|
goto fail;
|
||||||
|
}
|
||||||
|
|
||||||
|
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||||
|
if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
||||||
|
if (wasi_nn_ctx->backend != loaded_backend) {
|
||||||
|
res = unsupported_operation;
|
||||||
|
goto fail;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
wasi_nn_ctx->backend = loaded_backend;
|
||||||
|
|
||||||
|
/* init() the backend */
|
||||||
|
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
||||||
|
&wasi_nn_ctx->backend_ctx);
|
||||||
|
if (res != success)
|
||||||
|
goto fail;
|
||||||
|
|
||||||
|
wasi_nn_ctx->is_backend_ctx_initialized = true;
|
||||||
|
}
|
||||||
|
*wasi_nn_ctx_ptr = wasi_nn_ctx;
|
||||||
|
return success;
|
||||||
|
fail:
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
/* WASI-NN implementation */
|
/* WASI-NN implementation */
|
||||||
|
|
||||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||||
|
@ -410,6 +447,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
||||||
graph_encoding encoding, execution_target target, graph *g)
|
graph_encoding encoding, execution_target target, graph *g)
|
||||||
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
||||||
{
|
{
|
||||||
|
wasi_nn_error res;
|
||||||
|
|
||||||
NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
|
NN_DBG_PRINTF("[WASI NN] LOAD [encoding=%d, target=%d]...", encoding,
|
||||||
target);
|
target);
|
||||||
|
|
||||||
|
@ -417,7 +456,6 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
||||||
if (!instance)
|
if (!instance)
|
||||||
return runtime_error;
|
return runtime_error;
|
||||||
|
|
||||||
wasi_nn_error res;
|
|
||||||
graph_builder_array builder_native = { 0 };
|
graph_builder_array builder_native = { 0 };
|
||||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||||
if (success
|
if (success
|
||||||
|
@ -438,19 +476,8 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
||||||
goto fail;
|
goto fail;
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_encoding loaded_backend = autodetect;
|
WASINNContext *wasi_nn_ctx;
|
||||||
if (!detect_and_load_backend(encoding, &loaded_backend)) {
|
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
|
||||||
res = invalid_encoding;
|
|
||||||
NN_ERR_PRINTF("load backend failed");
|
|
||||||
goto fail;
|
|
||||||
}
|
|
||||||
|
|
||||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
||||||
wasi_nn_ctx->backend = loaded_backend;
|
|
||||||
|
|
||||||
/* init() the backend */
|
|
||||||
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
||||||
&wasi_nn_ctx->backend_ctx);
|
|
||||||
if (res != success)
|
if (res != success)
|
||||||
goto fail;
|
goto fail;
|
||||||
|
|
||||||
|
@ -473,6 +500,8 @@ wasi_nn_error
|
||||||
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
||||||
graph *g)
|
graph *g)
|
||||||
{
|
{
|
||||||
|
wasi_nn_error res;
|
||||||
|
|
||||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||||
if (!instance) {
|
if (!instance) {
|
||||||
return runtime_error;
|
return runtime_error;
|
||||||
|
@ -496,19 +525,8 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
||||||
|
|
||||||
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
|
||||||
|
|
||||||
graph_encoding loaded_backend = autodetect;
|
WASINNContext *wasi_nn_ctx;
|
||||||
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
|
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
||||||
NN_ERR_PRINTF("load backend failed");
|
|
||||||
return invalid_encoding;
|
|
||||||
}
|
|
||||||
|
|
||||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
||||||
wasi_nn_ctx->backend = loaded_backend;
|
|
||||||
|
|
||||||
wasi_nn_error res;
|
|
||||||
/* init() the backend */
|
|
||||||
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
||||||
&wasi_nn_ctx->backend_ctx);
|
|
||||||
if (res != success)
|
if (res != success)
|
||||||
return res;
|
return res;
|
||||||
|
|
||||||
|
@ -526,6 +544,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
||||||
int32_t name_len, char *config,
|
int32_t name_len, char *config,
|
||||||
int32_t config_len, graph *g)
|
int32_t config_len, graph *g)
|
||||||
{
|
{
|
||||||
|
wasi_nn_error res;
|
||||||
|
|
||||||
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
|
||||||
if (!instance) {
|
if (!instance) {
|
||||||
return runtime_error;
|
return runtime_error;
|
||||||
|
@ -554,19 +574,8 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
||||||
|
|
||||||
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
|
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
|
||||||
|
|
||||||
graph_encoding loaded_backend = autodetect;
|
WASINNContext *wasi_nn_ctx;
|
||||||
if (!detect_and_load_backend(autodetect, &loaded_backend)) {
|
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
||||||
NN_ERR_PRINTF("load backend failed");
|
|
||||||
return invalid_encoding;
|
|
||||||
}
|
|
||||||
|
|
||||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
|
||||||
wasi_nn_ctx->backend = loaded_backend;
|
|
||||||
|
|
||||||
wasi_nn_error res;
|
|
||||||
/* init() the backend */
|
|
||||||
call_wasi_nn_func(wasi_nn_ctx->backend, init, res,
|
|
||||||
&wasi_nn_ctx->backend_ctx);
|
|
||||||
if (res != success)
|
if (res != success)
|
||||||
return res;
|
return res;
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include "wasm_export.h"
|
#include "wasm_export.h"
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
bool is_backend_ctx_initialized;
|
||||||
bool is_model_loaded;
|
bool is_model_loaded;
|
||||||
graph_encoding backend;
|
graph_encoding backend;
|
||||||
void *backend_ctx;
|
void *backend_ctx;
|
||||||
|
|
Loading…
Reference in New Issue
Block a user