Add a way to set the target even when load_by_name is used
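
load_by_name itself carries no encoding/target arguments, so the encoding, execution target and model paths are now handed to the runtime at instantiation time through a small graph registry, and load_by_name resolves the bare model name against the registered paths. A rough usage sketch with the option and API added here (backend library, model names and paths below are only examples):

    iwasm --native-lib=./libwasi_nn_tflite.so \
        --wasi-nn-graph=tensorflowlite:gpu:./models/sum.tflite:./models/max.tflite \
        test_tensorflow.wasm

    /* embedders that do not go through iwasm can do the same; inst_args is the
       struct InstantiationArgs2 later passed to wasm_runtime_instantiate_ex2() */
    struct wasi_nn_graph_registry *registry;
    const char *paths[] = { "./models/sum.tflite", "./models/max.tflite" };
    wasi_nn_graph_registry_create(&registry);
    wasi_nn_graph_registry_set_args(registry, "tensorflowlite", "gpu", 2, paths);
    wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, registry);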

QiuYuan Han 2025-12-10 13:52:50 +08:00 committed by qinzh
parent 2063ac1688
commit 96cdfa63ad
21 changed files with 659 additions and 128 deletions

View File

@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL;
static void *g_wasi_context_key;
#endif /* WASM_ENABLE_LIBC_WASI */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
static void *g_wasi_nn_context_key;
#endif
uint32
get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis);
@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
}
#endif /* end of WASM_ENABLE_LIBC_WASI */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASINNGlobalContext *
wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
{
return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
}
void
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
WASINNGlobalContext *wasi_nn_ctx)
{
wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx);
}
static void
wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
{
if (ctx == NULL) {
return;
}
wasm_runtime_destroy_wasi_nn_global_ctx(inst);
}
#endif
#if WASM_ENABLE_QUICK_AOT_ENTRY != 0
static bool
quick_aot_entry_init(void);
@ -582,6 +611,11 @@ wasm_native_init()
#endif /* WASM_ENABLE_LIB_RATS */
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor);
if (g_wasi_nn_context_key == NULL) {
goto fail;
}
if (!wasi_nn_initialize())
goto fail;
@ -648,6 +682,10 @@ wasm_native_destroy()
#endif
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (g_wasi_nn_context_key != NULL) {
wasm_native_destroy_context_key(g_wasi_nn_context_key);
g_wasi_nn_context_key = NULL;
}
wasi_nn_destroy();
#endif

View File

@ -1696,6 +1696,67 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p)
wasm_runtime_free(p);
}
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
struct wasi_nn_graph_registry;
void
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args)
{
memset(args, 0, sizeof(*args));
}
bool
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char *encoding,
const char *target, uint32_t n_graphs,
const char **graph_paths)
{
if (!registry || !encoding || !target || !graph_paths) {
return false;
}
registry->encoding = strdup(encoding);
registry->target = strdup(target);
registry->n_graphs = n_graphs;
registry->graph_paths = (char **)malloc(sizeof(char *) * n_graphs);
if (!registry->encoding || !registry->target || !registry->graph_paths) {
/* partially-allocated fields are released by wasi_nn_graph_registry_destroy() */
return false;
}
memset(registry->graph_paths, 0, sizeof(char *) * n_graphs);
for (uint32_t i = 0; i < registry->n_graphs; i++)
registry->graph_paths[i] = strdup(graph_paths[i]);
return true;
}
int
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp)
{
struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args));
if (args == NULL) {
return -1;
}
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
*registryp = args;
return 0;
}
void
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry)
{
if (registry) {
if (registry->graph_paths) {
for (uint32_t i = 0; i < registry->n_graphs; i++) {
if (registry->graph_paths[i])
free(registry->graph_paths[i]);
}
free(registry->graph_paths);
}
if (registry->encoding)
free(registry->encoding);
if (registry->target)
free(registry->target);
/* the registry itself comes from wasm_runtime_malloc() */
wasm_runtime_free(registry);
}
}
#endif
void
wasm_runtime_instantiation_args_set_default_stack_size(
struct InstantiationArgs2 *p, uint32 v)
@ -1794,6 +1855,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
wasi_args->set_by_user = true;
}
#endif /* WASM_ENABLE_LIBC_WASI != 0 */
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
void
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry)
{
p->nn_registry = *registry;
}
#endif
WASMModuleInstanceCommon *
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
@ -8080,3 +8149,114 @@ wasm_runtime_check_and_update_last_used_shared_heap(
return false;
}
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
bool
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
const char *encoding, const char *target,
const uint32_t n_graphs, char *graph_paths[],
char *error_buf, uint32_t error_buf_size)
{
WASINNGlobalContext *ctx;
ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
if (!ctx)
return false;
/* encoding/target may be NULL when no graph registry was provided */
ctx->encoding = encoding ? strdup(encoding) : NULL;
ctx->target = target ? strdup(target) : NULL;
ctx->n_graphs = n_graphs;
if (n_graphs > 0) {
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
ctx->graph_paths = (char **)malloc(sizeof(char *) * n_graphs);
if (!ctx->loaded || !ctx->graph_paths) {
free(ctx->encoding);
free(ctx->target);
free(ctx->loaded);
free(ctx->graph_paths);
wasm_runtime_free(ctx);
return false;
}
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
memset(ctx->graph_paths, 0, sizeof(char *) * n_graphs);
for (uint32_t i = 0; i < n_graphs; i++) {
ctx->graph_paths[i] = graph_paths[i] ? strdup(graph_paths[i]) : NULL;
}
}
wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);
return true;
}
void
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
{
WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst);
if (!wasi_nn_global_ctx) {
return;
}
if (wasi_nn_global_ctx->graph_paths) {
for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) {
/* All graphs will be unregistered in deinit() */
if (wasi_nn_global_ctx->graph_paths[i])
free(wasi_nn_global_ctx->graph_paths[i]);
}
free(wasi_nn_global_ctx->graph_paths);
}
free(wasi_nn_global_ctx->encoding);
free(wasi_nn_global_ctx->target);
free(wasi_nn_global_ctx->loaded);
wasm_runtime_free(wasi_nn_global_ctx);
}
uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->n_graphs;
/* a missing context is treated as zero registered graphs */
return 0;
}
char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->graph_paths[idx];
return NULL;
}
uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->loaded[idx];
return -1;
}
uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
wasi_nn_global_ctx->loaded[idx] = value;
return 0;
}
char*
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->encoding;
return NULL;
}
char*
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->target;
return NULL;
}
#endif

View File

@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption {
uint32 exports_size;
} WASMModuleInstMemConsumption;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct WASINNGlobalContext {
char* encoding;
char* target;
uint32_t n_graphs;
uint32_t *loaded;
char** graph_paths;
} WASINNGlobalContext;
#endif
#if WASM_ENABLE_LIBC_WASI != 0
#if WASM_ENABLE_UVWASI == 0
typedef struct WASIContext {
@ -612,11 +623,30 @@ WASMExecEnv *
wasm_runtime_get_exec_env_tls(void);
#endif
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
struct wasi_nn_graph_registry {
char* encoding;
char* target;
char** graph_paths;
uint32_t n_graphs;
};
WASM_RUNTIME_API_EXTERN int
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
WASM_RUNTIME_API_EXTERN void
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
#endif
struct InstantiationArgs2 {
InstantiationArgs v1;
#if WASM_ENABLE_LIBC_WASI != 0
WASIArguments wasi;
#endif
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
struct wasi_nn_graph_registry nn_registry;
#endif
};
void
@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
uint32 ns_lookup_pool_size);
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
WASM_RUNTIME_API_EXTERN void
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry);
WASM_RUNTIME_API_EXTERN bool
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
const char* target, uint32_t n_graphs,
const char** graph_paths);
#endif
/* See wasm_export.h for description */
WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon *
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap(
uint8 **shared_heap_base_addr_adj_p, bool is_memory64);
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASM_RUNTIME_API_EXTERN bool
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
const char* encoding, const char* target,
const uint32_t n_graphs, char* graph_paths[],
char *error_buf, uint32_t error_buf_size);
WASM_RUNTIME_API_EXTERN void
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst);
WASM_RUNTIME_API_EXTERN void
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
WASINNGlobalContext *wasi_ctx);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx);
WASM_RUNTIME_API_EXTERN char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value);
WASM_RUNTIME_API_EXTERN char*
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx);
WASM_RUNTIME_API_EXTERN char*
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx);
#endif
#ifdef __cplusplus
}
#endif

View File

@ -290,6 +290,8 @@ typedef struct InstantiationArgs {
#endif /* INSTANTIATION_ARGS_OPTION_DEFINED */
struct InstantiationArgs2;
struct WASINNGlobalContext;
typedef struct WASINNGlobalContext *wasi_nn_global_context;
#ifndef WASM_VALKIND_T_DEFINED
#define WASM_VALKIND_T_DEFINED
@ -796,6 +798,55 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
uint32_t ns_lookup_pool_size);
// WASM_RUNTIME_API_EXTERN int
// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
// WASM_RUNTIME_API_EXTERN void
// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
// WASM_RUNTIME_API_EXTERN void
// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry);
// WASM_RUNTIME_API_EXTERN bool
// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
// const char* target, uint32_t n_graphs,
// const char** graph_paths);
WASM_RUNTIME_API_EXTERN bool
wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst,
const char* encoding, const char* target,
const uint32_t n_graphs, char* graph_paths[],
char *error_buf, uint32_t error_buf_size);
WASM_RUNTIME_API_EXTERN void
wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst);
WASM_RUNTIME_API_EXTERN void
wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst,
wasi_nn_global_context wasi_ctx);
WASM_RUNTIME_API_EXTERN wasi_nn_global_context
wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx);
WASM_RUNTIME_API_EXTERN char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx);
WASM_RUNTIME_API_EXTERN uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value);
WASM_RUNTIME_API_EXTERN char*
wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx);
WASM_RUNTIME_API_EXTERN char*
wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx);
/**
* Instantiate a WASM module, with specified instantiation arguments
*

View File

@ -3300,6 +3300,18 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent,
}
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
/* Store the graph paths into the module instance context; graphs are not
   actually loaded until the user app calls load_by_name. load() is not
   handled here for now. */
struct wasi_nn_graph_registry *nn_registry = &args->nn_registry;
if (!wasm_runtime_init_wasi_nn_global_ctx(
(WASMModuleInstanceCommon *)module_inst, nn_registry->encoding,
nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths,
error_buf, error_buf_size)) {
goto fail;
}
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0
if (!is_sub_inst) {
/* Add module instance into module's instance list */

View File

@ -8,5 +8,5 @@
#include "wasi_nn.h"
#undef WASM_ENABLE_WASI_EPHEMERAL_NN
#undef WASI_NN_NAME
// #undef WASM_ENABLE_WASI_EPHEMERAL_NN
// #undef WASI_NN_NAME

View File

@ -21,7 +21,7 @@
#else
#define WASI_NN_IMPORT(name) \
__attribute__((import_module("wasi_nn"), import_name(name)))
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deprecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#endif
/**

View File

@ -27,7 +27,7 @@ extern "C" {
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
#endif
/**
@ -169,6 +169,7 @@ typedef enum WASI_NN_NAME(execution_target) {
WASI_NN_TARGET_NAME(cpu) = 0,
WASI_NN_TARGET_NAME(gpu),
WASI_NN_TARGET_NAME(tpu),
WASI_NN_TARGET_NAME(unsupported_target),
} WASI_NN_NAME(execution_target);
// Bind a `graph` to the input and output tensors for an inference.

View File

@ -21,7 +21,7 @@
#include "wasm_export.h"
#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deprecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#endif
#define HASHMAP_INITIAL_SIZE 20
@ -35,6 +35,8 @@
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instance
/* Global variables */
static korp_mutex wasi_nn_lock;
/*
@ -208,6 +210,44 @@ wasi_nn_destroy()
* - model file format
* - on device ML framework
*/
static graph_encoding
str2encoding(const char *str_encoding)
{
if (!str_encoding) {
NN_ERR_PRINTF("Got a NULL encoding string");
return unknown_backend;
}
if (!strcmp(str_encoding, "openvino"))
return openvino;
else if (!strcmp(str_encoding, "tensorflowlite"))
return tensorflowlite;
else if (!strcmp(str_encoding, "ggml"))
return ggml;
else if (!strcmp(str_encoding, "onnx"))
return onnx;
else
return unknown_backend;
}
static execution_target
str2target(const char *str_target)
{
if (!str_target) {
NN_ERR_PRINTF("Got a NULL target string");
return unsupported_target;
}
if (!strcmp(str_target, "cpu"))
return cpu;
else if (!strcmp(str_target, "gpu"))
return gpu;
else if (!strcmp(str_target, "tpu"))
return tpu;
else
return unsupported_target;
}
static graph_encoding
choose_a_backend()
{
@ -565,17 +605,82 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
goto fail;
}
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance);
if (!wasi_nn_global_ctx) {
NN_ERR_PRINTF("global context is invalid");
res = not_found;
goto fail;
}
graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx));
execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx));
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
res = ensure_backend(instance, encoding, wasi_nn_ctx);
if (res != success)
goto fail;
bool is_loaded = false;
uint32 model_idx = 0;
uint32 n_registered = wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx);
char *global_model_path_i = NULL;
/* The wasm app passes a bare model name, e.g. "max", "sum", "average";
   the registry holds the full paths given on the command line,
   e.g. /your/path1/max.tflite, /your/path2/sum.tflite, ... */
for (model_idx = 0; model_idx < n_registered; model_idx++)
{
// Extract filename from file path
global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx);
char *model_file_name;
const char *slash = strrchr(global_model_path_i, '/');
if (slash != NULL) {
model_file_name = (char*)(slash + 1);
}
else
model_file_name = global_model_path_i;
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
g);
if (res != success)
// Extract the model name from the file name
char* model_name = NULL;
size_t model_name_len = 0;
char* dot = strrchr(model_file_name, '.');
if (dot)
{
model_name_len = dot - model_file_name;
model_name = malloc(model_name_len + 1);
if (!model_name) {
res = runtime_error;
goto fail;
}
strncpy(model_name, model_file_name, model_name_len);
model_name[model_name_len] = '\0';
}
if (model_name && strcmp(nul_terminated_name, model_name) == 0) {
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx);
free(model_name);
break;
}
/* no match: release the temporary model name before the next iteration */
free(model_name);
}
if (!is_loaded && model_idx < n_registered
    && model_idx < MAX_GLOBAL_GRAPHS_PER_INST)
{
NN_DBG_PRINTF("Model is not loaded yet, loading it now");
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i),
encoding, target, g);
if (res != success)
goto fail;
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1);
res = success;
}
else
{
if (is_loaded)
{
NN_DBG_PRINTF("Model is already loaded");
res = success;
}
else if (model_idx >= n_registered)
{
NN_ERR_PRINTF("Model is not in the wasi-nn graph registry");
res = not_found;
}
else
{
/* model_idx >= MAX_GLOBAL_GRAPHS_PER_INST; no enlarging for now */
NN_ERR_PRINTF("Not enough space for the new model");
res = too_large;
}
goto fail;
res = success;
}
fail:
if (nul_terminated_name != NULL) {
wasm_runtime_free(nul_terminated_name);

View File

@ -17,7 +17,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g);
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, graph *g);
load_by_name(void *tflite_ctx, const char *name, uint32_t namelen,
graph_encoding encoding, execution_target target, graph *g);
__attribute__((visibility("default"))) wasi_nn_error
load_by_name_with_config(void *ctx, const char *name, uint32_t namelen,

View File

@ -338,7 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
}
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
load_by_name(void *ctx, const char *filename, uint32_t filename_len,
graph_encoding encoding, execution_target target, graph *g)
{
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;

View File

@ -334,7 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
}
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len,
graph_encoding encoding, execution_target target, graph *g)
{
if (!onnx_ctx) {
return runtime_error;

View File

@ -306,7 +306,8 @@ fail:
}
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
load_by_name(void *ctx, const char *filename, uint32_t filename_len,
graph_encoding encoding, execution_target target, graph *g)
{
OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx;
struct OpenVINOGraph *graph;

View File

@ -21,7 +21,8 @@ typedef struct {
typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding,
execution_target, graph *);
typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *);
typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding,
execution_target, graph *);
typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *,
uint32_t, void *, uint32_t,
graph *);

View File

@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding,
}
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
graph *g)
load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
graph_encoding encoding, execution_target target, graph *g)
{
TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
@ -183,7 +183,7 @@ load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
}
// Use CPU as default
tfl_ctx->models[*g].target = cpu;
tfl_ctx->models[*g].target = target;
return success;
}

View File

@ -1,2 +1,2 @@
tensorflow==2.12.1
tensorflow==2.14.0
numpy==1.24.4

View File

@ -13,16 +13,16 @@
#include "logger.h"
void
test_sum(execution_target target)
test_sum()
{
int dims[] = { 1, 5, 5, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output = run_inference(target, input.input_tensor, input.dim,
&output_size, "./models/sum.tflite", 1);
float *output = run_inference(input.input_tensor, input.dim,
&output_size, "sum", 1);
assert(output_size == 1);
assert((output_size / sizeof(float)) == 1);
assert(fabs(output[0] - 300.0) < EPSILON);
free(input.dim);
@ -31,16 +31,16 @@ test_sum(execution_target target)
}
void
test_max(execution_target target)
test_max()
{
int dims[] = { 1, 5, 5, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output = run_inference(target, input.input_tensor, input.dim,
&output_size, "./models/max.tflite", 1);
float *output = run_inference(input.input_tensor, input.dim,
&output_size, "max", 1);
assert(output_size == 1);
assert((output_size / sizeof(float)) == 1);
assert(fabs(output[0] - 24.0) < EPSILON);
NN_INFO_PRINTF("Result: max is %f", output[0]);
@ -50,16 +50,16 @@ test_max(execution_target target)
}
void
test_average(execution_target target)
test_average()
{
int dims[] = { 1, 5, 5, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output = run_inference(target, input.input_tensor, input.dim,
&output_size, "./models/average.tflite", 1);
float *output = run_inference(input.input_tensor, input.dim,
&output_size, "average", 1);
assert(output_size == 1);
assert((output_size / sizeof(float)) == 1);
assert(fabs(output[0] - 12.0) < EPSILON);
NN_INFO_PRINTF("Result: average is %f", output[0]);
@ -69,16 +69,16 @@ test_average(execution_target target)
}
void
test_mult_dimensions(execution_target target)
test_mult_dimensions()
{
int dims[] = { 1, 3, 3, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output = run_inference(target, input.input_tensor, input.dim,
&output_size, "./models/mult_dim.tflite", 1);
float *output = run_inference(input.input_tensor, input.dim,
&output_size, "mult_dim", 1);
assert(output_size == 9);
assert((output_size / sizeof(float)) == 9);
for (int i = 0; i < 9; i++)
assert(fabs(output[i] - i) < EPSILON);
@ -88,16 +88,16 @@ test_mult_dimensions(execution_target target)
}
void
test_mult_outputs(execution_target target)
test_mult_outputs()
{
int dims[] = { 1, 4, 4, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output = run_inference(target, input.input_tensor, input.dim,
&output_size, "./models/mult_out.tflite", 2);
float *output = run_inference(input.input_tensor, input.dim,
&output_size, "mult_out", 2);
assert(output_size == 8);
assert((output_size / sizeof(float)) == 8);
// first tensor check
for (int i = 0; i < 4; i++)
assert(fabs(output[i] - (i * 4 + 24)) < EPSILON);
@ -113,30 +113,18 @@ test_mult_outputs(execution_target target)
int
main()
{
char *env = getenv("TARGET");
if (env == NULL) {
NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\"");
return 1;
}
execution_target target;
if (strcmp(env, "cpu") == 0)
target = cpu;
else if (strcmp(env, "gpu") == 0)
target = gpu;
else {
NN_ERR_PRINTF("Wrong target!");
return 1;
}
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\"");
NN_INFO_PRINTF("################### Testing sum...");
test_sum(target);
test_sum();
NN_INFO_PRINTF("################### Testing max...");
test_max(target);
test_max();
NN_INFO_PRINTF("################### Testing average...");
test_average(target);
test_average();
NN_INFO_PRINTF("################### Testing multiple dimensions...");
test_mult_dimensions(target);
test_mult_dimensions();
NN_INFO_PRINTF("################### Testing multiple outputs...");
test_mult_outputs(target);
test_mult_outputs();
NN_INFO_PRINTF("Tests: passed!");
return 0;

View File

@ -16,15 +16,15 @@
#define EPSILON 1e-2
void
test_average_quantized(execution_target target)
test_average_quantized()
{
int dims[] = { 1, 5, 5, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output =
run_inference(target, input.input_tensor, input.dim, &output_size,
"./models/quantized_model.tflite", 1);
run_inference(input.input_tensor, input.dim, &output_size,
"quantized_model", 1);
NN_INFO_PRINTF("Output size: %d", output_size);
NN_INFO_PRINTF("Result: average is %f", output[0]);
@ -39,24 +39,10 @@ test_average_quantized(execution_target target)
int
main()
{
char *env = getenv("TARGET");
if (env == NULL) {
NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\"");
return 1;
}
execution_target target;
if (strcmp(env, "cpu") == 0)
target = cpu;
else if (strcmp(env, "gpu") == 0)
target = gpu;
else if (strcmp(env, "tpu") == 0)
target = tpu;
else {
NN_ERR_PRINTF("Wrong target!");
return 1;
}
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\"");
NN_INFO_PRINTF("################### Testing quantized model...");
test_average_quantized(target);
test_average_quantized();
NN_INFO_PRINTF("Tests: passed!");
return 0;

View File

@ -5,17 +5,15 @@
#include "utils.h"
#include "logger.h"
#include "wasi_nn.h"
#include <stdio.h>
#include <stdlib.h>
wasi_nn_error
wasm_load(char *model_name, graph *g, execution_target target)
WASI_NN_ERROR_TYPE
wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target)
{
FILE *pFile = fopen(model_name, "r");
if (pFile == NULL)
return invalid_argument;
return WASI_NN_ERROR_NAME(invalid_argument);
uint8_t *buffer;
size_t result;
@ -24,20 +22,29 @@ wasm_load(char *model_name, graph *g, execution_target target)
buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
if (buffer == NULL) {
fclose(pFile);
return too_large;
return WASI_NN_ERROR_NAME(too_large);
}
result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
if (result <= 0) {
fclose(pFile);
free(buffer);
return too_large;
return WASI_NN_ERROR_NAME(too_large);
}
graph_builder_array arr;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASI_NN_NAME(graph_builder) arr;
arr.buf = buffer;
arr.size = result;
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
// WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
#else
WASI_NN_NAME(graph_builder_array) arr;
arr.size = 1;
arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder)));
if (arr.buf == NULL) {
fclose(pFile);
free(buffer);
@ -47,7 +54,8 @@ wasm_load(char *model_name, graph *g, execution_target target)
arr.buf[0].size = result;
arr.buf[0].buf = buffer;
wasi_nn_error res = load(&arr, tensorflowlite, target, g);
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
#endif
fclose(pFile);
free(buffer);
@ -55,77 +63,97 @@ wasm_load(char *model_name, graph *g, execution_target target)
return res;
}
wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g)
WASI_NN_ERROR_TYPE
wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g)
{
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g);
return res;
}
wasi_nn_error
wasm_init_execution_context(graph g, graph_execution_context *ctx)
WASI_NN_ERROR_TYPE
wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx)
{
return init_execution_context(g, ctx);
return WASI_NN_NAME(init_execution_context)(g, ctx);
}
wasi_nn_error
wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
WASI_NN_ERROR_TYPE
wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim)
{
tensor_dimensions dims;
WASI_NN_NAME(tensor_dimensions) dims;
dims.size = INPUT_TENSOR_DIMS;
dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t));
if (dims.buf == NULL)
return too_large;
return WASI_NN_ERROR_NAME(too_large);
tensor tensor;
WASI_NN_NAME(tensor) tensor;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
tensor.dimensions = dims;
for (int i = 0; i < tensor.dimensions.size; ++i)
tensor.dimensions.buf[i] = dim[i];
tensor.type = WASI_NN_TYPE_NAME(fp32);
tensor.data.buf = (uint8_t *)input_tensor;
uint32_t tmp_size = 1;
if (dim)
for (int i = 0; i < INPUT_TENSOR_DIMS; ++i)
tmp_size *= dim[i];
tensor.data.size = (tmp_size * sizeof(float));
#else
tensor.dimensions = &dims;
for (int i = 0; i < tensor.dimensions->size; ++i)
tensor.dimensions->buf[i] = dim[i];
tensor.type = fp32;
tensor.type = WASI_NN_TYPE_NAME(fp32);
tensor.data = (uint8_t *)input_tensor;
wasi_nn_error err = set_input(ctx, 0, &tensor);
#endif
WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor);
free(dims.buf);
return err;
}
wasi_nn_error
wasm_compute(graph_execution_context ctx)
WASI_NN_ERROR_TYPE
wasm_compute(WASI_NN_NAME(graph_execution_context) ctx)
{
return compute(ctx);
return WASI_NN_NAME(compute)(ctx);
}
wasi_nn_error
wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
WASI_NN_ERROR_TYPE
wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor,
uint32_t *out_size)
{
return get_output(ctx, index, (uint8_t *)out_tensor, out_size);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size);
#else
return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size);
#endif
}
float *
run_inference(execution_target target, float *input, uint32_t *input_size,
run_inference(float *input, uint32_t *input_size,
uint32_t *output_size, char *model_name,
uint32_t num_output_tensors)
{
graph graph;
WASI_NN_NAME(graph) graph;
if (wasm_load_by_name(model_name, &graph) != success) {
if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) {
NN_ERR_PRINTF("Error when loading model.");
exit(1);
}
graph_execution_context ctx;
if (wasm_init_execution_context(graph, &ctx) != success) {
WASI_NN_NAME(graph_execution_context) ctx;
if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) {
NN_ERR_PRINTF("Error when initialixing execution context.");
exit(1);
}
if (wasm_set_input(ctx, input, input_size) != success) {
if (wasm_set_input(ctx, input, input_size) != WASI_NN_ERROR_NAME(success)) {
NN_ERR_PRINTF("Error when setting input tensor.");
exit(1);
}
if (wasm_compute(ctx) != success) {
if (wasm_compute(ctx) != WASI_NN_ERROR_NAME(success)) {
NN_ERR_PRINTF("Error when running inference.");
exit(1);
}
@ -140,7 +168,7 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
for (int i = 0; i < num_output_tensors; ++i) {
*output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size;
if (wasm_get_output(ctx, i, &out_tensor[offset], output_size)
!= success) {
!= WASI_NN_ERROR_NAME(success)) {
NN_ERR_PRINTF("Error when getting index %d.", i);
break;
}

View File

@ -8,6 +8,7 @@
#include <stdint.h>
#include "wasi_ephemeral_nn.h"
#include "wasi_nn_types.h"
#define MAX_MODEL_SIZE 85000000
@ -23,26 +24,26 @@ typedef struct {
/* wasi-nn wrappers */
wasi_nn_error
wasm_load(char *model_name, graph *g, execution_target target);
WASI_NN_ERROR_TYPE
wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target);
wasi_nn_error
wasm_init_execution_context(graph g, graph_execution_context *ctx);
WASI_NN_ERROR_TYPE
wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx);
wasi_nn_error
wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim);
WASI_NN_ERROR_TYPE
wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim);
wasi_nn_error
wasm_compute(graph_execution_context ctx);
WASI_NN_ERROR_TYPE
wasm_compute(WASI_NN_NAME(graph_execution_context) ctx);
wasi_nn_error
wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
WASI_NN_ERROR_TYPE
wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor,
uint32_t *out_size);
/* Utils */
float *
run_inference(execution_target target, float *input, uint32_t *input_size,
run_inference(float *input, uint32_t *input_size,
uint32_t *output_size, char *model_name,
uint32_t num_output_tensors);

View File

@ -18,6 +18,10 @@
#include "../common/libc_wasi.c"
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#include "wasi_ephemeral_nn.h"
#endif
#include "../common/wasm_proposal.c"
#if BH_HAS_DLFCN
@ -115,6 +119,12 @@ print_help(void)
#endif
#if WASM_ENABLE_STATIC_PGO != 0
printf(" --gen-prof-file=<path> Generate LLVM PGO (Profile-Guided Optimization) profile file\n");
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
printf(" --wasi-nn-graph=encoding:target:<model_path1>:<model_path2>:...:<model_pathn>\n");
printf(" Set encoding, target and model_paths for wasi-nn. target can be\n");
printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n");
printf(" tensorflow|pytorch|ggml|autodetect\n");
#endif
printf(" --version Show version information\n");
return 1;
@ -635,6 +645,13 @@ main(int argc, char *argv[])
int timeout_ms = -1;
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
struct wasi_nn_graph_registry *nn_registry = NULL;
char *encoding = NULL, *target = NULL;
uint32_t n_models = 0;
char **model_paths = NULL;
#endif
#if WASM_ENABLE_LIBC_WASI != 0
memset(&wasi_parse_ctx, 0, sizeof(wasi_parse_ctx));
#endif
@ -825,6 +842,37 @@ main(int argc, char *argv[])
wasm_proposal_print_status();
return 0;
}
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) {
char *token;
char *saveptr = NULL;
int token_count = 0;
char *tokens[12] = {0};
/* --wasi-nn-graph=<encoding>:<target>:<model_file_path1>:<model_file_path2>:...
   encoding: tensorflowlite|openvino|ggml|onnx; target: cpu|gpu|tpu */
token = strtok_r(argv[0] + 16, ":", &saveptr);
while (token && token_count < (int)(sizeof(tokens) / sizeof(tokens[0]))) {
tokens[token_count] = token;
token_count++;
token = strtok_r(NULL, ":", &saveptr);
}
if (token_count < 2) {
return print_help();
}
n_models = token_count - 2;
encoding = strdup(tokens[0]);
target = strdup(tokens[1]);
model_paths = malloc(n_models * sizeof(char *));
if (!model_paths) {
return 1;
}
for (uint32_t i = 0; i < n_models; i++) {
model_paths[i] = strdup(tokens[i + 2]);
}
/* tokens point into argv, nothing to free here */
}
#endif
else {
#if WASM_ENABLE_LIBC_WASI != 0
libc_wasi_parse_result_t result =
@ -974,6 +1022,11 @@ main(int argc, char *argv[])
libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx);
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (wasi_nn_graph_registry_create(&nn_registry) == 0) {
wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models,
(const char **)model_paths);
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry);
}
#endif
/* instantiate the module */
wasm_module_inst = wasm_runtime_instantiate_ex2(
wasm_module, inst_args, error_buf, sizeof(error_buf));
@ -1092,6 +1145,15 @@ fail5:
#endif
#if WASM_ENABLE_DEBUG_INTERP != 0
fail4:
#endif
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_graph_registry_destroy(nn_registry);
for (uint32_t i = 0; i < n_models; i++)
if (model_paths[i])
free(model_paths[i]);
free(model_paths);
free(encoding);
free(target);
#endif
/* destroy the module instance */
wasm_runtime_deinstantiate(wasm_module_inst);