diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index b8430520a..2ba4a5778 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -26,7 +26,7 @@ static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -static void *g_wasi_nn_context_key; +static void *g_wasi_nn_registry_key; #endif uint32 @@ -478,17 +478,17 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) #endif /* end of WASM_ENABLE_LIBC_WASI */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) +WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm) { - return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key); + return wasm_native_get_context(module_inst_comm, g_wasi_nn_registry_key); } void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx) +wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, + WASINNRegistry *wasi_nn_ctx) { - wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, + wasm_native_set_context(module_inst_comm, g_wasi_nn_registry_key, wasi_nn_ctx); } @@ -499,7 +499,7 @@ wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) return; } - wasm_runtime_destroy_wasi_nn_global_ctx(inst); + wasm_runtime_wasi_nn_registry_destroy(ctx); } #endif @@ -612,9 +612,9 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - g_wasi_nn_context_key = + g_wasi_nn_registry_key = wasm_native_create_context_key(wasi_nn_context_dtor); - if (g_wasi_nn_context_key == NULL) { + if (g_wasi_nn_registry_key == NULL) { goto fail; } @@ -684,9 +684,9 @@ wasm_native_destroy() #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (g_wasi_nn_context_key != NULL) { - wasm_native_destroy_context_key(g_wasi_nn_context_key); - g_wasi_nn_context_key = NULL; + if (g_wasi_nn_registry_key != NULL) { + wasm_native_destroy_context_key(g_wasi_nn_registry_key); + g_wasi_nn_registry_key = NULL; } wasi_nn_destroy(); #endif diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 536261355..2d011d79b 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1796,48 +1796,48 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( #endif /* WASM_ENABLE_LIBC_WASI != 0 */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNArguments WASINNArguments; - void -wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args) +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNRegistry *args) { memset(args, 0, sizeof(*args)); } bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, - const char **model_names, const char **encoding, - const char **target, uint32_t n_graphs, - const char **graph_paths) +wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, + const char **model_names, const char **encoding, + const char **target, uint32_t n_graphs, + const char **graph_paths) { if (!registry || !model_names || !encoding || !target || !graph_paths) { return false; } registry->n_graphs = n_graphs; - registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); + registry->target = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->encoding = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->loaded = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->model_names = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + registry->graph_paths = (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); + memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs); memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs); memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs); for (uint32_t i = 0; i < registry->n_graphs; i++) { - registry->graph_paths[i] = strdup(graph_paths[i]); - registry->model_names[i] = strdup(model_names[i]); - registry->encoding[i] = strdup(encoding[i]); - registry->target[i] = strdup(target[i]); + registry->graph_paths[i] = bh_strdup(graph_paths[i]); + registry->model_names[i] = bh_strdup(model_names[i]); + registry->encoding[i] = bh_strdup(encoding[i]); + registry->target[i] = bh_strdup(target[i]); } return true; } int -wasi_nn_graph_registry_create(WASINNArguments **registryp) +wasm_runtime_wasi_nn_registry_create(WASINNRegistry **registryp) { - WASINNArguments *args = wasm_runtime_malloc(sizeof(*args)); + WASINNRegistry *args = wasm_runtime_malloc(sizeof(*args)); if (args == NULL) { return -1; } @@ -1847,28 +1847,45 @@ wasi_nn_graph_registry_create(WASINNArguments **registryp) } void -wasi_nn_graph_registry_destroy(WASINNArguments *registry) +wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry) { if (registry) { for (uint32_t i = 0; i < registry->n_graphs; i++) if (registry->graph_paths[i]) { - free(registry->graph_paths[i]); - if (registry->model_names[i]) - free(registry->model_names[i]); - if (registry->encoding[i]) - free(registry->encoding[i]); - if (registry->target[i]) - free(registry->target[i]); + wasm_runtime_free(registry->graph_paths[i]); + if (registry->model_names[i]) + wasm_runtime_free(registry->model_names[i]); + if (registry->encoding[i]) + wasm_runtime_free(registry->encoding[i]); + if (registry->target[i]) + wasm_runtime_free(registry->target[i]); } - free(registry); + if (registry->loaded) + wasm_runtime_free(registry->loaded); + wasm_runtime_free(registry); } } void -wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, WASINNArguments *registry) +wasm_runtime_instantiation_args_set_wasi_nn_registry( + struct InstantiationArgs2 *p, WASINNRegistry *registry) { - p->nn_registry = *registry; + if (!registry) + return; + WASINNRegistry *wasi_nn_registry = &p->nn_registry; + + wasi_nn_registry->n_graphs = registry->n_graphs; + + if (registry->model_names) + wasi_nn_registry->model_names = bh_strdup(registry->model_names); + if (registry->encoding) + wasi_nn_registry->encoding = bh_strdup(registry->encoding); + if (registry->target) + wasi_nn_registry->target = bh_strdup(registry->target); + if (registry->loaded) + wasi_nn_registry->loaded = bh_strdup(registry->loaded); + if (registry->graph_paths) + wasi_nn_registry->graph_paths = bh_strdup(registry->graph_paths); } #endif @@ -8159,142 +8176,73 @@ wasm_runtime_check_and_update_last_used_shared_heap( #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -bool -wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char **model_names, - const char **encoding, const char **target, - const uint32_t n_graphs, - char *graph_paths[], char *error_buf, - uint32_t error_buf_size) -{ - WASINNGlobalContext *ctx; - bool ret = false; - - ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size); - if (!ctx) - return false; - - ctx->n_graphs = n_graphs; - - ctx->encoding = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->encoding, 0, sizeof(uint32_t) * n_graphs); - ctx->target = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->target, 0, sizeof(uint32_t) * n_graphs); - ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs); - memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); - ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs); - ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs); - memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs); - - for (uint32_t i = 0; i < n_graphs; i++) { - ctx->graph_paths[i] = strdup(graph_paths[i]); - ctx->model_names[i] = strdup(model_names[i]); - ctx->target[i] = strdup(target[i]); - ctx->encoding[i] = strdup(encoding[i]); - } - - wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx); - - ret = true; - - return ret; -} - -void -wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) -{ - WASINNGlobalContext *wasi_nn_global_ctx = - wasm_runtime_get_wasi_nn_global_ctx(module_inst); - - for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) { - // All graphs will be unregistered in deinit() - if (wasi_nn_global_ctx->graph_paths[i]) - free(wasi_nn_global_ctx->graph_paths[i]); - if (wasi_nn_global_ctx->model_names[i]) - free(wasi_nn_global_ctx->model_names[i]); - if (wasi_nn_global_ctx->encoding[i]) - free(wasi_nn_global_ctx->encoding[i]); - if (wasi_nn_global_ctx->target[i]) - free(wasi_nn_global_ctx->target[i]); - } - free(wasi_nn_global_ctx->encoding); - free(wasi_nn_global_ctx->target); - free(wasi_nn_global_ctx->loaded); - free(wasi_nn_global_ctx->model_names); - free(wasi_nn_global_ctx->graph_paths); - - if (wasi_nn_global_ctx) { - wasm_runtime_free(wasi_nn_global_ctx); - } -} uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx) +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry) { - if (wasi_nn_global_ctx) - return wasi_nn_global_ctx->n_graphs; + if (wasi_nn_registry) + return wasi_nn_registry->n_graphs; return -1; } char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->model_names[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->model_names[idx]; return NULL; } char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->graph_paths[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->graph_paths[idx]; return NULL; } uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->loaded[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->loaded[idx]; return -1; } uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - wasi_nn_global_ctx->loaded[idx] = value; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + wasi_nn_registry->loaded[idx] = value; return 0; } char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_encoding_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->encoding[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->encoding[idx]; return NULL; } char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx) { - if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) - return wasi_nn_global_ctx->target[idx]; + if (wasi_nn_registry && (idx < wasi_nn_registry->n_graphs)) + return wasi_nn_registry->target[idx]; return NULL; } diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 869ac1eeb..0f7b1fbce 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -546,7 +546,7 @@ typedef struct WASMModuleInstMemConsumption { } WASMModuleInstMemConsumption; #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNGlobalContext { +typedef struct WASINNRegistry { char **model_names; char **encoding; char **target; @@ -554,7 +554,7 @@ typedef struct WASINNGlobalContext { uint32_t n_graphs; uint32_t *loaded; char **graph_paths; -} WASINNGlobalContext; +} WASINNRegistry; #endif #if WASM_ENABLE_LIBC_WASI != 0 @@ -625,20 +625,11 @@ wasm_runtime_get_exec_env_tls(void); #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -typedef struct WASINNArguments { - char **model_names; - char **encoding; - char **target; - - char **graph_paths; - uint32_t n_graphs; -} WASINNArguments; - WASM_RUNTIME_API_EXTERN int -wasi_nn_graph_registry_create(WASINNArguments **registryp); +wasm_runtime_wasi_nn_registry_create(WASINNRegistry **registryp); WASM_RUNTIME_API_EXTERN void -wasi_nn_graph_registry_destroy(WASINNArguments *registry); +wasm_runtime_wasi_nn_registry_destroy(WASINNRegistry *registry); #endif struct InstantiationArgs2 { @@ -647,7 +638,7 @@ struct InstantiationArgs2 { WASIArguments wasi; #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - WASINNArguments nn_registry; + WASINNRegistry nn_registry; #endif }; @@ -809,11 +800,11 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 WASM_RUNTIME_API_EXTERN void -wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( - struct InstantiationArgs2 *p, WASINNArguments *registry); +wasm_runtime_instantiation_args_set_wasi_nn_registry( + struct InstantiationArgs2 *p, WASINNRegistry *registry); WASM_RUNTIME_API_EXTERN bool -wasi_nn_graph_registry_set_args(WASINNArguments *registry, +wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, const char **model_names, const char **encoding, const char **target, uint32_t n_graphs, const char **graph_paths); @@ -1472,55 +1463,44 @@ wasm_runtime_check_and_update_last_used_shared_heap( #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 -WASM_RUNTIME_API_EXTERN bool -wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - const char **model_names, - const char **encoding, const char **target, - const uint32_t n_graphs, - char *graph_paths[], char *error_buf, - uint32_t error_buf_size); +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst, + WASINNRegistry *wasi_ctx); + +WASM_RUNTIME_API_EXTERN WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm); WASM_RUNTIME_API_EXTERN void -wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); - -WASM_RUNTIME_API_EXTERN void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, - WASINNGlobalContext *wasi_ctx); - -WASM_RUNTIME_API_EXTERN WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm); - -WASM_RUNTIME_API_EXTERN void -wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, - WASINNGlobalContext *wasi_nn_ctx); +wasm_runtime_set_wasi_nn_registry(WASMModuleInstanceCommon *module_inst_comm, + WASINNRegistry *wasi_nn_ctx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_encoding_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); #endif #ifdef __cplusplus diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 0c3659528..37cceaef1 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -290,8 +290,7 @@ typedef struct InstantiationArgs { #endif /* INSTANTIATION_ARGS_OPTION_DEFINED */ struct InstantiationArgs2; -struct WASINNGlobalContext; -typedef struct WASINNGlobalContext WASINNGlobalContext; +typedef struct WASINNRegistry WASINNRegistry; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -798,36 +797,36 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); -WASM_RUNTIME_API_EXTERN WASINNGlobalContext * -wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); +WASM_RUNTIME_API_EXTERN WASINNRegistry * +wasm_runtime_get_wasi_nn_registry(const wasm_module_inst_t module_inst); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_ngraphs( - WASINNGlobalContext *wasi_nn_global_ctx); +wasm_runtime_get_wasi_nn_registry_ngraphs( + WASINNRegistry *wasi_nn_registry); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_model_names_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_graph_paths_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN uint32_t -wasm_runtime_set_wasi_nn_global_ctx_loaded_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); +wasm_runtime_set_wasi_nn_registry_loaded_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx, uint32_t value); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_encoding_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); WASM_RUNTIME_API_EXTERN char * -wasm_runtime_get_wasi_nn_global_ctx_target_i( - WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); +wasm_runtime_get_wasi_nn_registry_target_i( + WASINNRegistry *wasi_nn_registry, uint32_t idx); /** * Instantiate a WASM module, with specified instantiation arguments diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index cfe3840d9..a59bc9257 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,18 +3300,6 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - /* Store graphs' path into ctx. Graphs will be loaded until user app calls - * load_by_name */ - WASINNArguments *nn_registry = &args->nn_registry; - if (!wasm_runtime_init_wasi_nn_global_ctx( - (WASMModuleInstanceCommon *)module_inst, nn_registry->model_names, - nn_registry->encoding, nn_registry->target, nn_registry->n_graphs, - nn_registry->graph_paths, error_buf, error_buf_size)) { - goto fail; - } -#endif - #if WASM_ENABLE_DEBUG_INTERP != 0 if (!is_sub_inst) { /* Add module instance into module's instance list */ diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index c4e43ede8..33ef0e090 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -607,9 +607,9 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - WASINNGlobalContext *wasi_nn_global_ctx = - wasm_runtime_get_wasi_nn_global_ctx(instance); - if (!wasi_nn_global_ctx) { + WASINNRegistry *wasi_nn_registry = + wasm_runtime_get_wasi_nn_registry(instance); + if (!wasi_nn_registry) { NN_ERR_PRINTF("global context is invalid"); res = not_found; goto fail; @@ -618,27 +618,27 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, bool is_loaded = false; uint32 model_idx = 0; uint32_t global_n_graphs = - wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); + wasm_runtime_get_wasi_nn_registry_ngraphs(wasi_nn_registry); for (model_idx = 0; model_idx < global_n_graphs; model_idx++) { - char *model_name = wasm_runtime_get_wasi_nn_global_ctx_model_names_i( - wasi_nn_global_ctx, model_idx); + char *model_name = wasm_runtime_get_wasi_nn_registry_model_names_i( + wasi_nn_registry, model_idx); if (model_name && strcmp(nul_terminated_name, model_name) != 0) { continue; } - is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i( - wasi_nn_global_ctx, model_idx); + is_loaded = wasm_runtime_get_wasi_nn_registry_loaded_i( + wasi_nn_registry, model_idx); char *global_model_path_i = - wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i( - wasi_nn_global_ctx, model_idx); + wasm_runtime_get_wasi_nn_registry_graph_paths_i( + wasi_nn_registry, model_idx); graph_encoding encoding = - str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i( - wasi_nn_global_ctx, model_idx)); + str2encoding(wasm_runtime_get_wasi_nn_registry_encoding_i( + wasi_nn_registry, model_idx)); execution_target target = - str2target(wasm_runtime_get_wasi_nn_global_ctx_target_i( - wasi_nn_global_ctx, model_idx)); + str2target(wasm_runtime_get_wasi_nn_registry_target_i( + wasi_nn_registry, model_idx)); // res = ensure_backend(instance, autodetect, wasi_nn_ctx); res = ensure_backend(instance, encoding, wasi_nn_ctx); @@ -655,7 +655,7 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, if (res != success) goto fail; - wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, + wasm_runtime_set_wasi_nn_registry_loaded_i(wasi_nn_registry, model_idx, 1); res = success; break; diff --git a/product-mini/platforms/common/libc_wasi.c b/product-mini/platforms/common/libc_wasi.c index ab22c47e8..bbe475119 100644 --- a/product-mini/platforms/common/libc_wasi.c +++ b/product-mini/platforms/common/libc_wasi.c @@ -236,14 +236,13 @@ fail: static void wasi_nn_set_init_args(struct InstantiationArgs2 *args, - struct WASINNArguments *nn_registry, + struct WASINNRegistry *nn_registry, wasi_nn_parse_context_t *ctx) { - wasi_nn_graph_registry_set_args(nn_registry, ctx->model_names, + wasm_runtime_wasi_nn_registry_set_args(nn_registry, ctx->model_names, ctx->encoding, ctx->target, ctx->n_graphs, ctx->graph_paths); - wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args, - nn_registry); + wasm_runtime_instantiation_args_set_wasi_nn_registry(args, nn_registry); for (uint32_t i = 0; i < ctx->n_graphs; i++) { if (ctx->model_names[i]) diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index d26565f68..58ec567c2 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -650,7 +650,7 @@ main(int argc, char *argv[]) #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 wasi_nn_parse_context_t wasi_nn_parse_ctx; - struct WASINNArguments *nn_registry; + struct WASINNRegistry *nn_registry; memset(&wasi_nn_parse_ctx, 0, sizeof(wasi_nn_parse_ctx)); #endif @@ -1009,19 +1009,21 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_create(&nn_registry); - wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); -#endif /* instantiate the module */ wasm_module_inst = wasm_runtime_instantiate_ex2( wasm_module, inst_args, error_buf, sizeof(error_buf)); - wasm_runtime_instantiation_args_destroy(inst_args); if (!wasm_module_inst) { printf("%s\n", error_buf); goto fail3; } +#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasm_runtime_wasi_nn_registry_create(&nn_registry); + wasi_nn_set_init_args(inst_args, nn_registry, &wasi_nn_parse_ctx); + wasm_runtime_set_wasi_nn_registry(wasm_module_inst, nn_registry); +#endif + wasm_runtime_instantiation_args_destroy(inst_args); + #if WASM_CONFIGURABLE_BOUNDS_CHECKS != 0 if (disable_bounds_checks) { wasm_runtime_set_bounds_checks(wasm_module_inst, false); @@ -1131,9 +1133,6 @@ fail5: #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: -#endif -#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - wasi_nn_graph_registry_destroy(nn_registry); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst);