mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-09-06 01:41:35 +00:00
wasi-nn: add minimum serialization on WASINNContext (#4387)
Currently this is not necessary because the context (WASINNContext) is local to an instance (wasm_module_instance_t). I plan to make a context shared among instances in a cluster when fixing https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313. This is a preparation for that direction. An obvious alternative is to tweak the module instance context APIs to allow declaring some kinds of contexts instance-local, but I feel that, in this particular case, it's more natural to make "wasi-nn handles" shared among threads within a "process". Note that, spec-wise, how wasi-nn behaves with respect to threads is not defined at all because WASI officially doesn't have threads yet. I suppose, at this point, that how wasi-nn interacts with wasi-threads is something we need to define by ourselves, especially when we are using an outdated wasi-nn version. With this change, if a thread attempts to access a context while another thread is using it, we simply make the operation fail with the "busy" error. This is intended as the minimum serialization needed to avoid problems like crashes, leaks, etc. It is not intended to allow parallelism or such. No functional changes are intended at this point yet. cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4313 https://github.com/bytecodealliance/wasm-micro-runtime/issues/2430
This commit is contained in:
parent
71c07f3e4e
commit
ea408ab6c0
|
@ -102,6 +102,8 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|||
NN_DBG_PRINTF("-> is_model_loaded: %d", wasi_nn_ctx->is_model_loaded);
|
||||
NN_DBG_PRINTF("-> current_encoding: %d", wasi_nn_ctx->backend);
|
||||
|
||||
bh_assert(!wasi_nn_ctx->busy);
|
||||
|
||||
/* deinit() the backend */
|
||||
if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
||||
wasi_nn_error res;
|
||||
|
@ -109,6 +111,7 @@ wasi_nn_ctx_destroy(WASINNContext *wasi_nn_ctx)
|
|||
wasi_nn_ctx->backend_ctx);
|
||||
}
|
||||
|
||||
os_mutex_destroy(&wasi_nn_ctx->lock);
|
||||
wasm_runtime_free(wasi_nn_ctx);
|
||||
}
|
||||
|
||||
|
@ -154,6 +157,11 @@ wasi_nn_initialize_context()
|
|||
}
|
||||
|
||||
memset(wasi_nn_ctx, 0, sizeof(WASINNContext));
|
||||
if (os_mutex_init(&wasi_nn_ctx->lock)) {
|
||||
NN_ERR_PRINTF("Error when initializing a lock for WASI-NN context");
|
||||
wasm_runtime_free(wasi_nn_ctx);
|
||||
return NULL;
|
||||
}
|
||||
return wasi_nn_ctx;
|
||||
}
|
||||
|
||||
|
@ -180,6 +188,35 @@ wasm_runtime_get_wasi_nn_ctx(wasm_module_inst_t instance)
|
|||
return wasi_nn_ctx;
|
||||
}
|
||||
|
||||
static WASINNContext *
|
||||
lock_ctx(wasm_module_inst_t instance)
|
||||
{
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
os_mutex_lock(&wasi_nn_ctx->lock);
|
||||
if (wasi_nn_ctx->busy) {
|
||||
os_mutex_unlock(&wasi_nn_ctx->lock);
|
||||
return NULL;
|
||||
}
|
||||
wasi_nn_ctx->busy = true;
|
||||
os_mutex_unlock(&wasi_nn_ctx->lock);
|
||||
return wasi_nn_ctx;
|
||||
}
|
||||
|
||||
static void
|
||||
unlock_ctx(WASINNContext *wasi_nn_ctx)
|
||||
{
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
return;
|
||||
}
|
||||
os_mutex_lock(&wasi_nn_ctx->lock);
|
||||
bh_assert(wasi_nn_ctx->busy);
|
||||
wasi_nn_ctx->busy = false;
|
||||
os_mutex_unlock(&wasi_nn_ctx->lock);
|
||||
}
|
||||
|
||||
void
|
||||
wasi_nn_destroy()
|
||||
{
|
||||
|
@ -405,7 +442,7 @@ detect_and_load_backend(graph_encoding backend_hint,
|
|||
|
||||
static wasi_nn_error
|
||||
ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
||||
WASINNContext **wasi_nn_ctx_ptr)
|
||||
WASINNContext *wasi_nn_ctx)
|
||||
{
|
||||
wasi_nn_error res;
|
||||
|
||||
|
@ -416,7 +453,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|||
goto fail;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
if (wasi_nn_ctx->is_backend_ctx_initialized) {
|
||||
if (wasi_nn_ctx->backend != loaded_backend) {
|
||||
res = unsupported_operation;
|
||||
|
@ -434,7 +470,6 @@ ensure_backend(wasm_module_inst_t instance, graph_encoding encoding,
|
|||
|
||||
wasi_nn_ctx->is_backend_ctx_initialized = true;
|
||||
}
|
||||
*wasi_nn_ctx_ptr = wasi_nn_ctx;
|
||||
return success;
|
||||
fail:
|
||||
return res;
|
||||
|
@ -462,17 +497,23 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|||
if (!instance)
|
||||
return runtime_error;
|
||||
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
graph_builder_array builder_native = { 0 };
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
if (success
|
||||
!= (res = graph_builder_array_app_native(
|
||||
instance, builder, builder_wasm_size, &builder_native)))
|
||||
return res;
|
||||
goto fail;
|
||||
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
||||
if (success
|
||||
!= (res = graph_builder_array_app_native(instance, builder,
|
||||
&builder_native)))
|
||||
return res;
|
||||
goto fail;
|
||||
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(instance, g,
|
||||
|
@ -482,8 +523,7 @@ wasi_nn_load(wasm_exec_env_t exec_env, graph_builder_array_wasm *builder,
|
|||
goto fail;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx;
|
||||
res = ensure_backend(instance, encoding, &wasi_nn_ctx);
|
||||
res = ensure_backend(instance, encoding, wasi_nn_ctx);
|
||||
if (res != success)
|
||||
goto fail;
|
||||
|
||||
|
@ -498,6 +538,7 @@ fail:
|
|||
// XXX: Free intermediate structure pointers
|
||||
if (builder_native.buf)
|
||||
wasm_runtime_free(builder_native.buf);
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
@ -531,18 +572,26 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|||
|
||||
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name);
|
||||
|
||||
WASINNContext *wasi_nn_ctx;
|
||||
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
||||
if (res != success)
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
|
||||
wasi_nn_ctx->backend_ctx, name, name_len, g);
|
||||
if (res != success)
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
wasi_nn_ctx->is_model_loaded = true;
|
||||
return success;
|
||||
res = success;
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
|
@ -580,19 +629,28 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name,
|
|||
|
||||
NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config);
|
||||
|
||||
WASINNContext *wasi_nn_ctx;
|
||||
res = ensure_backend(instance, autodetect, &wasi_nn_ctx);
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
||||
if (res != success)
|
||||
return res;
|
||||
goto fail;
|
||||
;
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res,
|
||||
wasi_nn_ctx->backend_ctx, name, name_len, config,
|
||||
config_len, g);
|
||||
if (res != success)
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
wasi_nn_ctx->is_model_loaded = true;
|
||||
return success;
|
||||
res = success;
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
|
@ -606,20 +664,27 @@ wasi_nn_init_execution_context(wasm_exec_env_t exec_env, graph g,
|
|||
return runtime_error;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
wasi_nn_error res;
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(
|
||||
instance, ctx, (uint64)sizeof(graph_execution_context))) {
|
||||
NN_ERR_PRINTF("ctx is invalid");
|
||||
return invalid_argument;
|
||||
res = invalid_argument;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, init_execution_context, res,
|
||||
wasi_nn_ctx->backend_ctx, g, ctx);
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -634,17 +699,21 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|||
return runtime_error;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
wasi_nn_error res;
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
tensor input_tensor_native = { 0 };
|
||||
if (success
|
||||
!= (res = tensor_app_native(instance, input_tensor,
|
||||
&input_tensor_native)))
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, set_input, res,
|
||||
wasi_nn_ctx->backend_ctx, ctx, index,
|
||||
|
@ -652,7 +721,8 @@ wasi_nn_set_input(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|||
// XXX: Free intermediate structure pointers
|
||||
if (input_tensor_native.dimensions)
|
||||
wasm_runtime_free(input_tensor_native.dimensions);
|
||||
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -666,14 +736,20 @@ wasi_nn_compute(wasm_exec_env_t exec_env, graph_execution_context ctx)
|
|||
return runtime_error;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
wasi_nn_error res;
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, compute, res,
|
||||
wasi_nn_ctx->backend_ctx, ctx);
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -696,16 +772,21 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|||
return runtime_error;
|
||||
}
|
||||
|
||||
WASINNContext *wasi_nn_ctx = wasm_runtime_get_wasi_nn_ctx(instance);
|
||||
|
||||
wasi_nn_error res;
|
||||
WASINNContext *wasi_nn_ctx = lock_ctx(instance);
|
||||
if (wasi_nn_ctx == NULL) {
|
||||
res = busy;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
if (success != (res = is_model_initialized(wasi_nn_ctx)))
|
||||
return res;
|
||||
goto fail;
|
||||
|
||||
if (!wasm_runtime_validate_native_addr(instance, output_tensor_size,
|
||||
(uint64)sizeof(uint32_t))) {
|
||||
NN_ERR_PRINTF("output_tensor_size is invalid");
|
||||
return invalid_argument;
|
||||
res = invalid_argument;
|
||||
goto fail;
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
|
@ -718,6 +799,8 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
|
|||
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
|
||||
output_tensor_size);
|
||||
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
|
||||
fail:
|
||||
unlock_ctx(wasi_nn_ctx);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,11 @@
|
|||
#include "wasi_nn_types.h"
|
||||
#include "wasm_export.h"
|
||||
|
||||
#include "bh_platform.h"
|
||||
|
||||
typedef struct {
|
||||
korp_mutex lock;
|
||||
bool busy;
|
||||
bool is_backend_ctx_initialized;
|
||||
bool is_model_loaded;
|
||||
graph_encoding backend;
|
||||
|
|
Loading…
Reference in New Issue
Block a user