From ea408ab6c0438acef61eccefa53fdd059ee8f9f7 Mon Sep 17 00:00:00 2001 From: YAMAMOTO Takashi Date: Fri, 20 Jun 2025 10:48:55 +0900 Subject: [PATCH] wasi-nn: add minimum serialization on WASINNContext (#4387) currently this is not necessary because context (WASINNContext) is local to instance. (wasm_module_instance_t) i plan to make a context shared among instances in a cluster when fixing https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313. this is a preparation for that direction. an obvious alternative is to tweak the module instance context APIs to allow declaring some kind of contexts instance-local. but i feel, in this particular case, it's more natural to make "wasi-nn handles" shared among threads within a "process". note that, spec-wise, how wasi-nn behaves wrt threads is not defined at all because wasi officially doesn't have threads yet. i suppose, at this point, that how wasi-nn interacts with wasi-threads is something we need to define by ourselves, especially when we are using an outdated wasi-nn version. with this change, if a thread attempts to access a context while another thread is using it, we simply make the operation fail with the "busy" error. this is intended for the minimum serialization to avoid problems like crashes/leaks/etc. this is not intended to allow parallelism or such. no functional changes are intended at this point yet. cf. 
https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313 https://github.com/bytecodealliance/wasm-micro-runtime/issues/2430 --- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 149 ++++++++++++++---- .../libraries/wasi-nn/src/wasi_nn_private.h | 4 + 2 files changed, 120 insertions(+), 33 deletions(-) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 6fc0d07f8..5c916aa4a 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -102,6 +102,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx) NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded); NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend); + bh_assert(!wasi_nn_ctx->busy); + /* deinit() the backend */ if (wasi_nn_ctx->is_backend_ctx_initialized) { wasi_nn_error res; @@ -109,6 +111,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx) wasi_nn_ctx->backend_ctx); } + os_mutex_destroy(&wasi_nn_ctx->lock); wasm_runtime_free(wasi_nn_ctx); } @@ -154,6 +157,11 @@ wasi_nn_initialize_context() } memset(wasi_nn_ctx, 0, sizeof(WASINNContext)); + if (os_mutex_init(&wasi_nn_ctx->lock)) { + NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context"); + wasm_runtime_free(wasi_nn_ctx); + return NULL; + } return wasi_nn_ctx; } @@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance) return wasi_nn_ctx; } +static WASINNContext * +lock_ctx(wasm_module_inst_t instance) +{ + WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); + if (wasi_nn_ctx == NULL) { + return NULL; + } + os_mutex_lock(&wasi_nn_ctx->lock); + if (wasi_nn_ctx->busy) { + os_mutex_unlock(&wasi_nn_ctx->lock); + return NULL; + } + wasi_nn_ctx->busy = true; + os_mutex_unlock(&wasi_nn_ctx->lock); + return wasi_nn_ctx; +} + +static void +unlock_ctx(WASINNContext *wasi_nn_ctx) +{ + if (wasi_nn_ctx == NULL) { + return; + } + os_mutex_lock(&wasi_nn_ctx->lock); + 
bh_assert(wasi_nn_ctx->busy); + wasi_nn_ctx->busy = false; + os_mutex_unlock(&wasi_nn_ctx->lock); +} + void wasi_nn_destroy() { @@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint, static wasi_nn_error ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, - WASINNContext **wasi_nn_ctx_ptr) + WASINNContext *wasi_nn_ctx) { wasi_nn_error res; @@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, goto fail; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); if (wasi_nn_ctx->is_backend_ctx_initialized) { if (wasi_nn_ctx->backend != loaded_backend) { res = unsupported_operation; @@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding, wasi_nn_ctx->is_backend_ctx_initialized = true; } - *wasi_nn_ctx_ptr = wasi_nn_ctx; return success; fail: return res; @@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, if (!instance) return runtime_error; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + graph_builder_array builder_native = { 0 }; #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 if (success != (res = graph_builder_array_app_native( instance, builder, builder_wasm_size, &builder_native))) - return res; + goto fail; #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ if (success != (res = graph_builder_array_app_native(instance, builder, &builder_native))) - return res; + goto fail; #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ if (!wasm_runtime_validate_native_addr(instance, g, @@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder, goto fail; } - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, encoding, &wasi_nn_ctx); + res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; @@ -498,6 +538,7 @@ fail: // XXX: Free intermediate structure pointers if (builder_native.buf) 
wasm_runtime_free(builder_native.buf); + unlock_ctx(wasi_nn_ctx); return res; } @@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name); - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, autodetect, &wasi_nn_ctx); + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + + res = ensure_backend(instance, autodetect, wasi_nn_ctx); if (res != success) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, wasi_nn_ctx->backend_ctx, name, name_len, g); if (res != success) - return res; + goto fail; wasi_nn_ctx->is_model_loaded = true; - return success; + res = success; +fail: + unlock_ctx(wasi_nn_ctx); + return res; } wasi_nn_error @@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config); - WASINNContext *wasi_nn_ctx; - res = ensure_backend(instance, autodetect, &wasi_nn_ctx); + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + + res = ensure_backend(instance, autodetect, wasi_nn_ctx); if (res != success) - return res; + goto fail; + ; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, wasi_nn_ctx->backend_ctx, name, name_len, config, config_len, g); if (res != success) - return res; + goto fail; wasi_nn_ctx->is_model_loaded = true; - return success; + res = success; +fail: + unlock_ctx(wasi_nn_ctx); + return res; } wasi_nn_error @@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g, return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = 
is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; if (!wasm_runtime_validate_native_addr( instance, ctx, (uint64)sizeof(graph_execution_context))) { NN_ERR_PRINTF("ctx is invalid"); - return invalid_argument; + res = invalid_argument; + goto fail; } call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res, wasi_nn_ctx->backend_ctx, g, ctx); +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -634,17 +699,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; tensor input_tensor_native = { 0 }; if (success != (res = tensor_app_native(instance, input_tensor, &input_tensor_native))) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res, wasi_nn_ctx->backend_ctx, ctx, index, @@ -652,7 +721,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx, // XXX: Free intermediate structure pointers if (input_tensor_native.dimensions) wasm_runtime_free(input_tensor_native.dimensions); - +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx) return runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, compute, res, wasi_nn_ctx->backend_ctx, ctx); +fail: + unlock_ctx(wasi_nn_ctx); return res; } @@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, return 
runtime_error; } - WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance); - wasi_nn_error res; + WASINNContext *wasi_nn_ctx = lock_ctx(instance); + if (wasi_nn_ctx == NULL) { + res = busy; + goto fail; + } + if (success != (res = is_model_initialized(wasi_nn_ctx))) - return res; + goto fail; if (!wasm_runtime_validate_native_addr(instance, output_tensor_size, (uint64)sizeof(uint32_t))) { NN_ERR_PRINTF("output_tensor_size is invalid"); - return invalid_argument; + res = invalid_argument; + goto fail; } #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 @@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx, wasi_nn_ctx->backend_ctx, ctx, index, output_tensor, output_tensor_size); #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ +fail: + unlock_ctx(wasi_nn_ctx); return res; } diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index fcca31023..a20ad1718 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -9,7 +9,11 @@ #include "wasi_nn_types.h" #include "wasm_export.h" +#include "bh_platform.h" + typedef struct { + korp_mutex lock; + bool busy; bool is_backend_ctx_initialized; bool is_model_loaded; graph_encoding backend;