wasi-nn: add minimum serialization on WASINNContext (#4387)

currently this is not necessary because the context (WASINNContext) is
local to an instance (wasm_module_inst_t).

i plan to make the context shared among instances in a cluster when
fixing https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313.
this is a preparation for that direction.

an obvious alternative is to tweak the module instance context APIs
to allow declaring some contexts as instance-local. but i feel that,
in this particular case, it's more natural to make "wasi-nn handles"
shared among threads within a "process".

note that, spec-wise, how wasi-nn behaves with respect to threads is
not defined at all, because wasi officially doesn't have threads yet.
i suppose that, at this point, how wasi-nn interacts with wasi-threads
is something we need to define ourselves, especially since we are using
an outdated wasi-nn version.

with this change, if a thread attempts to access a context while
another thread is using it, the operation simply fails with the
"busy" error. this is intended as the minimum serialization needed to
avoid problems like crashes and leaks; it is not intended to enable
parallelism or anything like that.
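
the idea is a non-blocking claim on a per-context busy flag guarded by
a mutex (lock_ctx/unlock_ctx in the diff below). the following is a
minimal standalone sketch of the same pattern, not the actual code: it
uses pthreads and a trimmed-down stand-in struct instead of WAMR's
os_mutex_* wrappers and WASINNContext, and try_claim/release are
illustrative names only.

    #include <pthread.h>
    #include <stdbool.h>
    #include <stdio.h>

    /* trimmed-down stand-in for WASINNContext */
    typedef struct {
        pthread_mutex_t lock;
        bool busy;
    } nn_ctx;

    /* try to claim the context; fail immediately (the "busy" error)
     * instead of blocking when another thread already owns it */
    static bool
    try_claim(nn_ctx *ctx)
    {
        bool ok = false;
        pthread_mutex_lock(&ctx->lock);
        if (!ctx->busy) {
            ctx->busy = true;
            ok = true;
        }
        pthread_mutex_unlock(&ctx->lock);
        return ok;
    }

    static void
    release(nn_ctx *ctx)
    {
        pthread_mutex_lock(&ctx->lock);
        ctx->busy = false;
        pthread_mutex_unlock(&ctx->lock);
    }

    int
    main(void)
    {
        nn_ctx ctx = { .busy = false };
        pthread_mutex_init(&ctx.lock, NULL);

        if (!try_claim(&ctx)) {
            /* a real wasi-nn host function would return the "busy" error here */
            printf("busy\n");
            return 1;
        }
        /* ... call into the backend while exclusively holding the context ... */
        release(&ctx);

        pthread_mutex_destroy(&ctx.lock);
        return 0;
    }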

no functional changes are intended at this point.

cf.
https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313
https://github.com/bytecodealliance/wasm-micro-runtime/issues/2430
YAMAMOTO Takashi, 2025-06-20 10:48:55 +09:00 (committed by GitHub)
parent 71c07f3e4e
commit ea408ab6c0
2 changed files with 120 additions and 33 deletions


@@ -102,6 +102,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
bh_assert(!wasi_nn_ctx->busy);
/* deinit() the backend */
if (wasi_nn_ctx->is_backend_ctx_initialized) {
wasi_nn_error res;
@@ -109,6 +111,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
wasi_nn_ctx->backend_ctx);
}
os_mutex_destroy(&wasi_nn_ctx->lock);
wasm_runtime_free(wasi_nn_ctx);
}
@@ -154,6 +157,11 @@ wasi_nn_initialize_context()
}
memset(wasi_nn_ctx, 0, sizeof(WASINNContext));
if (os_mutex_init(&wasi_nn_ctx->lock)) {
NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context");
wasm_runtime_free(wasi_nn_ctx);
return NULL;
}
return wasi_nn_ctx;
}
@@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
return wasi_nn_ctx;
}
static WASINNContext *
lock_ctx(wasm_module_inst_t instance)
{
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
if (wasi_nn_ctx == NULL) {
return NULL;
}
os_mutex_lock(&wasi_nn_ctx->lock);
if (wasi_nn_ctx->busy) {
os_mutex_unlock(&wasi_nn_ctx->lock);
return NULL;
}
wasi_nn_ctx->busy = true;
os_mutex_unlock(&wasi_nn_ctx->lock);
return wasi_nn_ctx;
}
static void
unlock_ctx(WASINNContext *wasi_nn_ctx)
{
if (wasi_nn_ctx == NULL) {
return;
}
os_mutex_lock(&wasi_nn_ctx->lock);
bh_assert(wasi_nn_ctx->busy);
wasi_nn_ctx->busy = false;
os_mutex_unlock(&wasi_nn_ctx->lock);
}
void
wasi_nn_destroy()
{
@@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
static wasi_nn_error
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
WASINNContext **wasi_nn_ctx_ptr)
WASINNContext *wasi_nn_ctx)
{
wasi_nn_error res;
@@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
goto fail;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
if (wasi_nn_ctx->is_backend_ctx_initialized) {
if (wasi_nn_ctx->backend != loaded_backend) {
res = unsupported_operation;
@@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
wasi_nn_ctx->is_backend_ctx_initialized = true;
}
*wasi_nn_ctx_ptr = wasi_nn_ctx;
return success;
fail:
return res;
@@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
if (!instance)
return runtime_error;
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
graph_builder_array builder_native = { 0 };
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (success
!= (res = graph_builder_array_app_native(
instance, builder, builder_wasm_size, &builder_native)))
return res;
goto fail;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
if (success
!= (res = graph_builder_array_app_native(instance, builder,
&builder_native)))
return res;
goto fail;
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
if (!wasm_runtime_validate_native_addr(instance, g,
@@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
goto fail;
}
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
res = ensure_backend(instance, encoding, wasi_nn_ctx);
if (res != success)
goto fail;
@@ -498,6 +538,7 @@ fail:
// XXX: Free intermediate structure pointers
if (builder_native.buf)
wasm_runtime_free(builder_native.buf);
unlock_ctx(wasi_nn_ctx);
return res;
}
@@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
if (res != success)
return res;
goto fail;
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
wasi_nn_ctx->backend_ctx, name, name_len, g);
if (res != success)
return res;
goto fail;
wasi_nn_ctx->is_model_loaded = true;
return success;
res = success;
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}
wasi_nn_error
@@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
WASINNContext *wasi_nn_ctx;
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
if (res != success)
return res;
goto fail;
;
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
wasi_nn_ctx->backend_ctx, name, name_len, config,
config_len, g);
if (res != success)
return res;
goto fail;
wasi_nn_ctx->is_model_loaded = true;
return success;
res = success;
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}
wasi_nn_error
@@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
goto fail;
if (!wasm_runtime_validate_native_addr(
instance, ctx, (uint64)sizeof(graph_execution_context))) {
NN_ERR_PRINTF("ctx is invalid");
return invalid_argument;
res = invalid_argument;
goto fail;
}
call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res,
wasi_nn_ctx->backend_ctx, g, ctx);
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}
@@ -634,17 +699,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
goto fail;
tensor input_tensor_native = { 0 };
if (success
!= (res = tensor_app_native(instance, input_tensor,
&input_tensor_native)))
return res;
goto fail;
call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res,
wasi_nn_ctx->backend_ctx, ctx, index,
@@ -652,7 +721,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
// XXX: Free intermediate structure pointers
if (input_tensor_native.dimensions)
wasm_runtime_free(input_tensor_native.dimensions);
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}
@@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
goto fail;
call_wasi_nn_func(wasi_nn_ctx->backend, compute, res,
wasi_nn_ctx->backend_ctx, ctx);
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}
@@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
return runtime_error;
}
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
wasi_nn_error res;
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
if (wasi_nn_ctx == NULL) {
res = busy;
goto fail;
}
if (success != (res = is_model_initialized(wasi_nn_ctx)))
return res;
goto fail;
if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
(uint64)sizeof(uint32_t))) {
NN_ERR_PRINTF("output_tensor_size is invalid");
return invalid_argument;
res = invalid_argument;
goto fail;
}
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
@@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
output_tensor_size);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
fail:
unlock_ctx(wasi_nn_ctx);
return res;
}


@@ -9,7 +9,11 @@
#include "wasi_nn_types.h"
#include "wasm_export.h"
#include "bh_platform.h"
typedef struct {
korp_mutex lock;
bool busy;
bool is_backend_ctx_initialized;
bool is_model_loaded;
graph_encoding backend;