From 96cdfa63adef2036af42f450ef70de9a6f56ccf5 Mon Sep 17 00:00:00 2001 From: QiuYuan Han Date: Wed, 10 Dec 2025 13:52:50 +0800 Subject: [PATCH] Add a way to set the target even if we use load_by_name --- core/iwasm/common/wasm_native.c | 38 ++++ core/iwasm/common/wasm_runtime_common.c | 180 ++++++++++++++++++ core/iwasm/common/wasm_runtime_common.h | 74 +++++++ core/iwasm/include/wasm_export.h | 51 +++++ core/iwasm/interpreter/wasm_runtime.c | 12 ++ .../wasi-nn/include/wasi_ephemeral_nn.h | 4 +- .../iwasm/libraries/wasi-nn/include/wasi_nn.h | 2 +- .../libraries/wasi-nn/include/wasi_nn_types.h | 3 +- core/iwasm/libraries/wasi-nn/src/wasi_nn.c | 121 +++++++++++- .../libraries/wasi-nn/src/wasi_nn_backend.h | 3 +- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_onnx.cpp | 3 +- .../libraries/wasi-nn/src/wasi_nn_openvino.c | 3 +- .../libraries/wasi-nn/src/wasi_nn_private.h | 3 +- .../wasi-nn/src/wasi_nn_tensorflowlite.cpp | 6 +- .../libraries/wasi-nn/test/requirements.txt | 2 +- .../libraries/wasi-nn/test/test_tensorflow.c | 66 +++---- .../wasi-nn/test/test_tensorflow_quantized.c | 26 +-- core/iwasm/libraries/wasi-nn/test/utils.c | 102 ++++++---- core/iwasm/libraries/wasi-nn/test/utils.h | 23 +-- product-mini/platforms/posix/main.c | 62 ++++++ 21 files changed, 659 insertions(+), 128 deletions(-) diff --git a/core/iwasm/common/wasm_native.c b/core/iwasm/common/wasm_native.c index 42aa55db2..8938524db 100644 --- a/core/iwasm/common/wasm_native.c +++ b/core/iwasm/common/wasm_native.c @@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL; static void *g_wasi_context_key; #endif /* WASM_ENABLE_LIBC_WASI */ +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +static void *g_wasi_nn_context_key; +#endif + uint32 get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis); @@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) } #endif /* end of WASM_ENABLE_LIBC_WASI */ +#if 
WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASINNGlobalContext * +wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm) +{ + return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key); +} + +void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm, + WASINNGlobalContext *wasi_nn_ctx) +{ + wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx); +} + +static void +wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx) +{ + if (ctx == NULL) { + return; + } + + wasm_runtime_destroy_wasi_nn_global_ctx(inst); +} +#endif + #if WASM_ENABLE_QUICK_AOT_ENTRY != 0 static bool quick_aot_entry_init(void); @@ -582,6 +611,11 @@ wasm_native_init() #endif /* WASM_ENABLE_LIB_RATS */ #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor); + if (g_wasi_nn_context_key == NULL) { + goto fail; + } + if (!wasi_nn_initialize()) goto fail; @@ -648,6 +682,10 @@ wasm_native_destroy() #endif #if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (g_wasi_nn_context_key != NULL) { + wasm_native_destroy_context_key(g_wasi_nn_context_key); + g_wasi_nn_context_key = NULL; + } wasi_nn_destroy(); #endif diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 259816e0b..312c4b9c7 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1696,6 +1696,67 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p) wasm_runtime_free(p); } +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +struct wasi_nn_graph_registry; + +void +wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args) +{ + memset(args, 0, sizeof(*args)); +} + +bool +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, + const char* target, uint32_t n_graphs, + const char** 
graph_paths) +{ + if (!registry || !encoding || !target || !graph_paths) + { + return false; + } + registry->encoding = strdup(encoding); + registry->target = strdup(target); + registry->n_graphs = n_graphs; + registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); + memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + for (uint32_t i = 0; i < registry->n_graphs; i++) + registry->graph_paths[i] = strdup(graph_paths[i]); + + return true; +} + +int +wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp) +{ + struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args)); + if (args == NULL) { + return false; + } + wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args); + *registryp = args; + return 0; +} + +void +wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry) +{ + if (registry) + { + for (uint32_t i = 0; i < registry->n_graphs; i++) + if (registry->graph_paths[i]) + { + // wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]); + free(registry->graph_paths[i]); + } + if (registry->encoding) + free(registry->encoding); + if (registry->target) + free(registry->target); + free(registry); + } +} +#endif + void wasm_runtime_instantiation_args_set_default_stack_size( struct InstantiationArgs2 *p, uint32 v) @@ -1794,6 +1855,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( wasi_args->set_by_user = true; } #endif /* WASM_ENABLE_LIBC_WASI != 0 */ +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +void +wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( + struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry) +{ + p->nn_registry = *registry; +} +#endif WASMModuleInstanceCommon * wasm_runtime_instantiate_ex2(WASMModuleCommon *module, @@ -8080,3 +8149,114 @@ wasm_runtime_check_and_update_last_used_shared_heap( return false; } #endif + +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +bool +wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon 
*module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size) +{ + WASINNGlobalContext *ctx; + bool ret = false; + + ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size); + if (!ctx) + return false; + + ctx->encoding = strdup(encoding); + ctx->target = strdup(target); + ctx->n_graphs = n_graphs; + ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs); + memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs); + + ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs); + memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs); + for (uint32_t i = 0; i < n_graphs; i++) + { + ctx->graph_paths[i] = strdup(graph_paths[i]); + } + + wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx); + + ret = true; + + return ret; +} + +void +wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst) +{ + WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst); + + for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) + { + // All graphs will be unregistered in deinit() + if (wasi_nn_global_ctx->graph_paths[i]) + free(wasi_nn_global_ctx->graph_paths[i]); + } + free(wasi_nn_global_ctx->encoding); + free(wasi_nn_global_ctx->target); + free(wasi_nn_global_ctx->loaded); + free(wasi_nn_global_ctx->graph_paths); + + if (wasi_nn_global_ctx) { + wasm_runtime_free(wasi_nn_global_ctx); + } +} + +uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->n_graphs; + + return -1; +} + +char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->graph_paths[idx]; + + return NULL; +} + +uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext 
*wasi_nn_global_ctx, uint32_t idx) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + return wasi_nn_global_ctx->loaded[idx]; + + return -1; +} + +uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value) +{ + if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs)) + wasi_nn_global_ctx->loaded[idx] = value; + + return 0; +} + +char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->encoding; + + return NULL; +} + +char* +wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx) +{ + if (wasi_nn_global_ctx) + return wasi_nn_global_ctx->target; + + return NULL; +} + +#endif diff --git a/core/iwasm/common/wasm_runtime_common.h b/core/iwasm/common/wasm_runtime_common.h index 88f23485e..8d002bedc 100644 --- a/core/iwasm/common/wasm_runtime_common.h +++ b/core/iwasm/common/wasm_runtime_common.h @@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption { uint32 exports_size; } WASMModuleInstMemConsumption; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +typedef struct WASINNGlobalContext { + char* encoding; + char* target; + + uint32_t n_graphs; + uint32_t *loaded; + char** graph_paths; +} WASINNGlobalContext; +#endif + #if WASM_ENABLE_LIBC_WASI != 0 #if WASM_ENABLE_UVWASI == 0 typedef struct WASIContext { @@ -612,11 +623,30 @@ WASMExecEnv * wasm_runtime_get_exec_env_tls(void); #endif +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +struct wasi_nn_graph_registry { + char* encoding; + char* target; + + char** graph_paths; + uint32_t n_graphs; +}; + +WASM_RUNTIME_API_EXTERN int +wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); + +WASM_RUNTIME_API_EXTERN void +wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); +#endif + struct InstantiationArgs2 { InstantiationArgs v1; #if WASM_ENABLE_LIBC_WASI != 0 WASIArguments wasi; 
#endif +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) + struct wasi_nn_graph_registry nn_registry; +#endif }; void @@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32 ns_lookup_pool_size); +#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0) +WASM_RUNTIME_API_EXTERN void +wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( + struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +WASM_RUNTIME_API_EXTERN bool +wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, + const char* target, uint32_t n_graphs, + const char** graph_paths); +#endif + /* See wasm_export.h for description */ WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon * wasm_runtime_instantiate_ex2(WASMModuleCommon *module, @@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap( uint8 **shared_heap_base_addr_adj_p, bool is_memory64); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst, + WASINNGlobalContext *wasi_ctx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t 
+wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx); +#endif + #ifdef __cplusplus } #endif diff --git a/core/iwasm/include/wasm_export.h b/core/iwasm/include/wasm_export.h index 44a45dedf..50263f182 100644 --- a/core/iwasm/include/wasm_export.h +++ b/core/iwasm/include/wasm_export.h @@ -290,6 +290,8 @@ typedef struct InstantiationArgs { #endif /* INSTANTIATION_ARGS_OPTION_DEFINED */ struct InstantiationArgs2; +struct WASINNGlobalContext; +typedef struct WASINNGlobalContext *wasi_nn_global_context; #ifndef WASM_VALKIND_T_DEFINED #define WASM_VALKIND_T_DEFINED @@ -796,6 +798,55 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool( struct InstantiationArgs2 *p, const char *ns_lookup_pool[], uint32_t ns_lookup_pool_size); +// WASM_RUNTIME_API_EXTERN int +// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp); + +// WASM_RUNTIME_API_EXTERN void +// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN void +// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry( +// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry); + +// WASM_RUNTIME_API_EXTERN bool +// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding, +// const char* target, uint32_t n_graphs, +// const char** graph_paths); + +WASM_RUNTIME_API_EXTERN bool +wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + const char* encoding, const char* target, + const uint32_t n_graphs, char* graph_paths[], + char *error_buf, uint32_t error_buf_size); + +WASM_RUNTIME_API_EXTERN void +wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst); + 
+WASM_RUNTIME_API_EXTERN void +wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst, + wasi_nn_global_context wasi_ctx); + +WASM_RUNTIME_API_EXTERN wasi_nn_global_context +wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char * +wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx); + +WASM_RUNTIME_API_EXTERN uint32_t +wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx); + +WASM_RUNTIME_API_EXTERN char* +wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx); + /** * Instantiate a WASM module, with specified instantiation arguments * diff --git a/core/iwasm/interpreter/wasm_runtime.c b/core/iwasm/interpreter/wasm_runtime.c index a59bc9257..79d4c73c2 100644 --- a/core/iwasm/interpreter/wasm_runtime.c +++ b/core/iwasm/interpreter/wasm_runtime.c @@ -3300,6 +3300,18 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent, } #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + /* Store graphs' path into ctx. 
Graphs will not be loaded until the user app calls load_by_name */ + // Do not consider load() for now + struct wasi_nn_graph_registry *nn_registry = &args->nn_registry; + if (!wasm_runtime_init_wasi_nn_global_ctx( + (WASMModuleInstanceCommon *)module_inst, nn_registry->encoding, + nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths, + error_buf, error_buf_size)) { + goto fail; + } +#endif + #if WASM_ENABLE_DEBUG_INTERP != 0 if (!is_sub_inst) { /* Add module instance into module's instance list */ diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h index f76295a1e..83beba98f 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h @@ -8,5 +8,5 @@ #include "wasi_nn.h" -#undef WASM_ENABLE_WASI_EPHEMERAL_NN -#undef WASI_NN_NAME +// #undef WASM_ENABLE_WASI_EPHEMERAL_NN +// #undef WASI_NN_NAME diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h index cda26324e..d76de3ffc 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h @@ -21,7 +21,7 @@ #else #define WASI_NN_IMPORT(name) \ __attribute__((import_module("wasi_nn"), import_name(name))) -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. 
For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif /** diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h index 952fb65e2..d77fe9a6c 100644 --- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h +++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h @@ -27,7 +27,7 @@ extern "C" { #define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name) #define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name) #define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name) -#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error); +#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error) #endif /** @@ -169,6 +169,7 @@ typedef enum WASI_NN_NAME(execution_target) { WASI_NN_TARGET_NAME(cpu) = 0, WASI_NN_TARGET_NAME(gpu), WASI_NN_TARGET_NAME(tpu), + WASI_NN_TARGET_NAME(unsupported_target), } WASI_NN_NAME(execution_target); // Bind a `graph` to the input and output tensors for an inference. diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index 2282534b0..9e3e741b6 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -21,7 +21,7 @@ #include "wasm_export.h" #if WASM_ENABLE_WASI_EPHEMERAL_NN == 0 -#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) +#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. 
For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.) #endif #define HASHMAP_INITIAL_SIZE 20 @@ -35,6 +35,8 @@ #define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION #define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION +#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances + /* Global variables */ static korp_mutex wasi_nn_lock; /* @@ -208,6 +210,44 @@ wasi_nn_destroy() * - model file format * - on device ML framework */ +static graph_encoding str2encoding(char* str_encoding) +{ + if (!str_encoding) { + NN_ERR_PRINTF("Got empty string encoding"); + return -1; + } + + if (!strcmp(str_encoding, "openvino")) + return openvino; + else if (!strcmp(str_encoding, "tensorflowlite")) + return tensorflowlite; + else if (!strcmp(str_encoding, "ggml")) + return ggml; + else if (!strcmp(str_encoding, "onnx")) + return onnx; + else + return unknown_backend; + // return autodetect; +} + +static execution_target str2target(char* str_target) +{ + if (!str_target) { + NN_ERR_PRINTF("Got empty string target"); + return -1; + } + + if (!strcmp(str_target, "cpu")) + return cpu; + else if (!strcmp(str_target, "gpu")) + return gpu; + else if (!strcmp(str_target, "tpu")) + return tpu; + else + return unsupported_target; + // return autodetect; +} + static graph_encoding choose_a_backend() { @@ -565,17 +605,82 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; } - res = ensure_backend(instance, autodetect, wasi_nn_ctx); + wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance); + if (!wasi_nn_global_ctx) { + NN_ERR_PRINTF("global context is invalid"); + res = not_found; + goto fail; + } + graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx)); + execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx)); + + // res = ensure_backend(instance, 
autodetect, wasi_nn_ctx); + res = ensure_backend(instance, encoding, wasi_nn_ctx); if (res != success) goto fail; + + bool is_loaded = false; + uint32 model_idx = 0; + char *global_model_path_i; + // Assume filename got from user wasm app : max; sum; average; ... + // Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ...... + for (model_idx = 0; model_idx < wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); model_idx++) + { + // Extract filename from file path + global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx); + char *model_file_name; + const char *slash = strrchr(global_model_path_i, '/'); + if (slash != NULL) { + model_file_name = (char*)(slash + 1); + } + else + model_file_name = global_model_path_i; - call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, - g); - if (res != success) + // Extract modelname from filename + char* model_name = NULL; + size_t model_name_len = 0; + char* dot = strrchr(model_file_name, '.'); + if (dot) + { + model_name_len = dot - model_file_name; + model_name = malloc(model_name_len + 1); + strncpy(model_name, model_file_name, model_name_len); + model_name[model_name_len] = '\0'; + } + + if (model_name && strcmp(nul_terminated_name, model_name) == 0) { + is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx); + break; + } + } + + if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST)) + { + NN_DBG_PRINTF("Model is not yet loaded, will add to global context"); + call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, + wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i), + encoding, target, g); + if (res != success) + goto fail; + + wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1); + res = success; + } + else + { + if (is_loaded) + { + NN_DBG_PRINTF("Model is already loaded"); 
+ res = success; + } + else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST) + { + // No enlarge for now + NN_ERR_PRINTF("No enough space for new model"); + res = too_large; + } goto fail; - - res = success; + } fail: if (nul_terminated_name != NULL) { wasm_runtime_free(nul_terminated_name); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h index 8cd03f121..3108f2eef 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_backend.h @@ -17,7 +17,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, graph *g); +load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, + graph_encoding encoding, execution_target target, graph *g); __attribute__((visibility("default"))) wasi_nn_error load_by_name_with_config(void *ctx, const char *name, uint32_t namelen, diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index 2e1e64936..fd09c2be0 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -338,7 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp index 88587f68b..e2283df0f 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp +++ 
b/core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp @@ -334,7 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g) +load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { if (!onnx_ctx) { return runtime_error; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c index 899e06ee3..eec4f8190 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_openvino.c @@ -306,7 +306,8 @@ fail: } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g) +load_by_name(void *ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target, graph *g) { OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx; struct OpenVINOGraph *graph; diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h index 1bff2c514..5dcb173f4 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_private.h @@ -21,7 +21,8 @@ typedef struct { typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding, execution_target, graph *); -typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *); +typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding, + execution_target, graph *); typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *, uint32_t, void *, uint32_t, graph *); diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 9ac54e664..eb56a42f2 100644 --- 
a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding, } __attribute__((visibility("default"))) wasi_nn_error -load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, - graph *g) +load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, + graph_encoding encoding, execution_target target,graph *g) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; @@ -183,7 +183,7 @@ load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len, } // Use CPU as default - tfl_ctx->models[*g].target = cpu; + tfl_ctx->models[*g].target = target; return success; } diff --git a/core/iwasm/libraries/wasi-nn/test/requirements.txt b/core/iwasm/libraries/wasi-nn/test/requirements.txt index 1643b91b0..0c80fd6b1 100644 --- a/core/iwasm/libraries/wasi-nn/test/requirements.txt +++ b/core/iwasm/libraries/wasi-nn/test/requirements.txt @@ -1,2 +1,2 @@ -tensorflow==2.12.1 +tensorflow==2.14.0 numpy==1.24.4 diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c index 6a9e20702..b3d6ba803 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow.c @@ -13,16 +13,16 @@ #include "logger.h" void -test_sum(execution_target target) +test_sum() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/sum.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "sum", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 300.0) < EPSILON); free(input.dim); @@ -31,16 +31,16 @@ test_sum(execution_target target) } void -test_max(execution_target 
target) +test_max() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/max.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "max", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 24.0) < EPSILON); NN_INFO_PRINTF("Result: max is %f", output[0]); @@ -50,16 +50,16 @@ test_max(execution_target target) } void -test_average(execution_target target) +test_average() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/average.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "average", 1); - assert(output_size == 1); + assert((output_size / sizeof(float)) == 1); assert(fabs(output[0] - 12.0) < EPSILON); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -69,16 +69,16 @@ test_average(execution_target target) } void -test_mult_dimensions(execution_target target) +test_mult_dimensions() { int dims[] = { 1, 3, 3, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, "./models/mult_dim.tflite", 1); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_dim", 1); - assert(output_size == 9); + assert((output_size / sizeof(float)) == 9); for (int i = 0; i < 9; i++) assert(fabs(output[i] - i) < EPSILON); @@ -88,16 +88,16 @@ test_mult_dimensions(execution_target target) } void -test_mult_outputs(execution_target target) +test_mult_outputs() { int dims[] = { 1, 4, 4, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; - float *output = run_inference(target, input.input_tensor, input.dim, - &output_size, 
"./models/mult_out.tflite", 2); + float *output = run_inference(input.input_tensor, input.dim, + &output_size, "mult_out", 2); - assert(output_size == 8); + assert((output_size / sizeof(float)) == 8); // first tensor check for (int i = 0; i < 4; i++) assert(fabs(output[i] - (i * 4 + 24)) < EPSILON); @@ -113,30 +113,18 @@ test_mult_outputs(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing sum..."); - test_sum(target); + test_sum(); NN_INFO_PRINTF("################### Testing max..."); - test_max(target); + test_max(); NN_INFO_PRINTF("################### Testing average..."); - test_average(target); + test_average(); NN_INFO_PRINTF("################### Testing multiple dimensions..."); - test_mult_dimensions(target); + test_mult_dimensions(); NN_INFO_PRINTF("################### Testing multiple outputs..."); - test_mult_outputs(target); + test_mult_outputs(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c index 3ed7c751e..0898c7ae2 100644 --- a/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c +++ b/core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c @@ -16,15 +16,15 @@ #define EPSILON 1e-2 void -test_average_quantized(execution_target target) +test_average_quantized() { int dims[] = { 1, 5, 5, 1 }; input_info input = create_input(dims); uint32_t output_size = 0; float *output = - run_inference(target, 
input.input_tensor, input.dim, &output_size, - "./models/quantized_model.tflite", 1); + run_inference(input.input_tensor, input.dim, &output_size, + "quantized_model", 1); NN_INFO_PRINTF("Output size: %d", output_size); NN_INFO_PRINTF("Result: average is %f", output[0]); @@ -39,24 +39,10 @@ test_average_quantized(execution_target target) int main() { - char *env = getenv("TARGET"); - if (env == NULL) { - NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\""); - return 1; - } - execution_target target; - if (strcmp(env, "cpu") == 0) - target = cpu; - else if (strcmp(env, "gpu") == 0) - target = gpu; - else if (strcmp(env, "tpu") == 0) - target = tpu; - else { - NN_ERR_PRINTF("Wrong target!"); - return 1; - } + NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\""); + NN_INFO_PRINTF("################### Testing quantized model..."); - test_average_quantized(target); + test_average_quantized(); NN_INFO_PRINTF("Tests: passed!"); return 0; diff --git a/core/iwasm/libraries/wasi-nn/test/utils.c b/core/iwasm/libraries/wasi-nn/test/utils.c index 690c37f0e..97ed08378 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.c +++ b/core/iwasm/libraries/wasi-nn/test/utils.c @@ -5,17 +5,15 @@ #include "utils.h" #include "logger.h" -#include "wasi_nn.h" - #include #include -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target) +WASI_NN_ERROR_TYPE +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target) { FILE *pFile = fopen(model_name, "r"); if (pFile == NULL) - return invalid_argument; + return WASI_NN_ERROR_NAME(invalid_argument); uint8_t *buffer; size_t result; @@ -24,20 +22,29 @@ wasm_load(char *model_name, graph *g, execution_target target) buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE); if (buffer == NULL) { fclose(pFile); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } result = 
fread(buffer, 1, MAX_MODEL_SIZE, pFile); if (result <= 0) { fclose(pFile); free(buffer); - return too_large; + return WASI_NN_ERROR_NAME(too_large); } - graph_builder_array arr; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + WASI_NN_NAME(graph_builder) arr; + + arr.buf = buffer; + arr.size = result; + + // arr is a single graph_builder, so pass the number of builders (1), not the model size in bytes. + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#else + WASI_NN_NAME(graph_builder_array) arr; arr.size = 1; - arr.buf = (graph_builder *)malloc(sizeof(graph_builder)); + arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder))); if (arr.buf == NULL) { fclose(pFile); free(buffer); @@ -47,7 +54,8 @@ wasm_load(char *model_name, graph *g, execution_target target) arr.buf[0].size = result; arr.buf[0].buf = buffer; - wasi_nn_error res = load(&arr, tensorflowlite, target, g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g); +#endif fclose(pFile); free(buffer); @@ -55,77 +63,97 @@ wasm_load(char *model_name, graph *g, execution_target target) return res; } -wasi_nn_error -wasm_load_by_name(const char *model_name, graph *g) +WASI_NN_ERROR_TYPE +wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g) { - wasi_nn_error res = load_by_name(model_name, strlen(model_name), g); + WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g); return res; } -wasi_nn_error -wasm_init_execution_context(graph g, graph_execution_context *ctx) +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx) { - return init_execution_context(g, ctx); + return WASI_NN_NAME(init_execution_context)(g, ctx); } -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim) +WASI_NN_ERROR_TYPE 
+wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim) { - tensor_dimensions dims; + WASI_NN_NAME(tensor_dimensions) dims; dims.size = INPUT_TENSOR_DIMS; dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t)); if (dims.buf == NULL) - return too_large; + return WASI_NN_ERROR_NAME(too_large); - tensor tensor; + WASI_NN_NAME(tensor) tensor; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + tensor.dimensions = dims; + for (int i = 0; i < tensor.dimensions.size; ++i) + tensor.dimensions.buf[i] = dim[i]; + tensor.type = WASI_NN_TYPE_NAME(fp32); + tensor.data.buf = (uint8_t *)input_tensor; + + uint32_t tmp_size = 1; + if (dim) + for (int i = 0; i < INPUT_TENSOR_DIMS; ++i) + tmp_size *= dim[i]; + + tensor.data.size = (tmp_size * sizeof(float)); +#else tensor.dimensions = &dims; for (int i = 0; i < tensor.dimensions->size; ++i) tensor.dimensions->buf[i] = dim[i]; - tensor.type = fp32; + tensor.type = WASI_NN_TYPE_NAME(fp32); tensor.data = (uint8_t *)input_tensor; - wasi_nn_error err = set_input(ctx, 0, &tensor); +#endif + + WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor); free(dims.buf); return err; } -wasi_nn_error -wasm_compute(graph_execution_context ctx) +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx) { - return compute(ctx); + return WASI_NN_NAME(compute)(ctx); } -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size) { - return get_output(ctx, index, (uint8_t *)out_tensor, out_size); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size); +#else + return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size); +#endif } float * -run_inference(execution_target target, float *input, uint32_t *input_size, 
+run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors) { - graph graph; + WASI_NN_NAME(graph) graph; - if (wasm_load_by_name(model_name, &graph) != success) { + if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when loading model."); exit(1); } - graph_execution_context ctx; - if (wasm_init_execution_context(graph, &ctx) != success) { + WASI_NN_NAME(graph_execution_context) ctx; + if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when initialixing execution context."); exit(1); } - if (wasm_set_input(ctx, input, input_size) != success) { + if (wasm_set_input(ctx, input, input_size) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when setting input tensor."); exit(1); } - if (wasm_compute(ctx) != success) { + if (wasm_compute(ctx) != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when running inference."); exit(1); } @@ -140,7 +168,7 @@ run_inference(execution_target target, float *input, uint32_t *input_size, for (int i = 0; i < num_output_tensors; ++i) { *output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size; if (wasm_get_output(ctx, i, &out_tensor[offset], output_size) - != success) { + != WASI_NN_ERROR_NAME(success)) { NN_ERR_PRINTF("Error when getting index %d.", i); break; } diff --git a/core/iwasm/libraries/wasi-nn/test/utils.h b/core/iwasm/libraries/wasi-nn/test/utils.h index e0d241772..45ba156a0 100644 --- a/core/iwasm/libraries/wasi-nn/test/utils.h +++ b/core/iwasm/libraries/wasi-nn/test/utils.h @@ -8,6 +8,7 @@ #include +#include "wasi_ephemeral_nn.h" #include "wasi_nn_types.h" #define MAX_MODEL_SIZE 85000000 @@ -23,26 +24,26 @@ typedef struct { /* wasi-nn wrappers */ -wasi_nn_error -wasm_load(char *model_name, graph *g, execution_target target); +WASI_NN_ERROR_TYPE +wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target); -wasi_nn_error 
-wasm_init_execution_context(graph g, graph_execution_context *ctx); +WASI_NN_ERROR_TYPE +wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx); -wasi_nn_error -wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim); +WASI_NN_ERROR_TYPE +wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim); -wasi_nn_error -wasm_compute(graph_execution_context ctx); +WASI_NN_ERROR_TYPE +wasm_compute(WASI_NN_NAME(graph_execution_context) ctx); -wasi_nn_error -wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor, +WASI_NN_ERROR_TYPE +wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor, uint32_t *out_size); /* Utils */ float * -run_inference(execution_target target, float *input, uint32_t *input_size, +run_inference(float *input, uint32_t *input_size, uint32_t *output_size, char *model_name, uint32_t num_output_tensors); diff --git a/product-mini/platforms/posix/main.c b/product-mini/platforms/posix/main.c index 2d7d3afeb..ef99f2a84 100644 --- a/product-mini/platforms/posix/main.c +++ b/product-mini/platforms/posix/main.c @@ -18,6 +18,10 @@ #include "../common/libc_wasi.c" #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 +#include "wasi_ephemeral_nn.h" +#endif + #include "../common/wasm_proposal.c" #if BH_HAS_DLFCN @@ -115,6 +119,12 @@ print_help(void) #endif #if WASM_ENABLE_STATIC_PGO != 0 printf(" --gen-prof-file= Generate LLVM PGO (Profile-Guided Optimization) profile file\n"); +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + printf(" --wasi-nn-graph=encoding:target:<model_path1>:<model_path2>:...:<model_pathn>\n"); + printf(" Set encoding, target and model_paths for wasi-nn. 
target can be\n"); + printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n"); + printf(" tensorflow|pytorch|ggml|autodetect\n"); #endif printf(" --version Show version information\n"); return 1; @@ -635,6 +645,13 @@ main(int argc, char *argv[]) int timeout_ms = -1; #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + struct wasi_nn_graph_registry *nn_registry = NULL; + char *encoding = NULL, *target = NULL; + uint32_t n_models = 0; + char **model_paths = NULL; +#endif + #if WASM_ENABLE_LIBC_WASI != 0 memset(&wasi_parse_ctx, 0, sizeof(wasi_parse_ctx)); #endif @@ -825,6 +842,37 @@ main(int argc, char *argv[]) wasm_proposal_print_status(); return 0; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) { + char *token; + char *saveptr = NULL; + int token_count = 0; + char *tokens[12] = {0}; + + // --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:... (encoding: tensorflowlite|openvino|llama, target: cpu|gpu|tpu) + // tokens[] holds encoding, target and at most 10 model paths; a longer list is rejected below. + token = strtok_r(argv[0] + 16, ":", &saveptr); + while (token && token_count < (int)(sizeof(tokens) / sizeof(tokens[0]))) { + tokens[token_count] = token; + token_count++; + token = strtok_r(NULL, ":", &saveptr); + } + + if (token_count < 2 || token != NULL) { /* too few fields, or more fields than tokens[] can hold */ + return print_help(); + } + + n_models = token_count - 2; + encoding = strdup(tokens[0]); + target = strdup(tokens[1]); + model_paths = calloc(n_models + 1, sizeof(*model_paths)); /* +1 so the request is never zero-sized */ + if (!encoding || !target || !model_paths) + return 1; + for (uint32_t i = 0; i < n_models; i++) { + model_paths[i] = strdup(tokens[i + 2]); + } + } +#endif else { #if WASM_ENABLE_LIBC_WASI != 0 libc_wasi_parse_result_t result = @@ -974,6 +1022,11 @@ main(int argc, char *argv[]) libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx); #endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_create(&nn_registry); + wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths); + wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry); +#endif /* instantiate the module */ wasm_module_inst = 
wasm_runtime_instantiate_ex2( wasm_module, inst_args, error_buf, sizeof(error_buf)); @@ -1092,6 +1145,15 @@ fail5: #endif #if WASM_ENABLE_DEBUG_INTERP != 0 fail4: +#endif +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + wasi_nn_graph_registry_destroy(nn_registry); + for (uint32_t i = 0; i < n_models; i++) + if (model_paths[i]) + free(model_paths[i]); + free(model_paths); + free(encoding); + free(target); #endif /* destroy the module instance */ wasm_runtime_deinstantiate(wasm_module_inst);