mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2026-01-20 08:16:41 +00:00
Add a way to set the target even if we use load_by_name
This commit is contained in:
parent
2063ac1688
commit
96cdfa63ad
|
|
@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL;
|
|||
static void *g_wasi_context_key;
|
||||
#endif /* WASM_ENABLE_LIBC_WASI */
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
static void *g_wasi_nn_context_key;
|
||||
#endif
|
||||
|
||||
uint32
|
||||
get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis);
|
||||
|
||||
|
|
@ -473,6 +477,31 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
|
|||
}
|
||||
#endif /* end of WASM_ENABLE_LIBC_WASI */
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
WASINNGlobalContext *
|
||||
wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
|
||||
{
|
||||
return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
|
||||
}
|
||||
|
||||
void
|
||||
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
|
||||
WASINNGlobalContext *wasi_nn_ctx)
|
||||
{
|
||||
wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key, wasi_nn_ctx);
|
||||
}
|
||||
|
||||
static void
|
||||
wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
|
||||
{
|
||||
if (ctx == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
wasm_runtime_destroy_wasi_nn_global_ctx(inst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_QUICK_AOT_ENTRY != 0
|
||||
static bool
|
||||
quick_aot_entry_init(void);
|
||||
|
|
@ -582,6 +611,11 @@ wasm_native_init()
|
|||
#endif /* WASM_ENABLE_LIB_RATS */
|
||||
|
||||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
g_wasi_nn_context_key = wasm_native_create_context_key(wasi_nn_context_dtor);
|
||||
if (g_wasi_nn_context_key == NULL) {
|
||||
goto fail;
|
||||
}
|
||||
|
||||
if (!wasi_nn_initialize())
|
||||
goto fail;
|
||||
|
||||
|
|
@ -648,6 +682,10 @@ wasm_native_destroy()
|
|||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
if (g_wasi_nn_context_key != NULL) {
|
||||
wasm_native_destroy_context_key(g_wasi_nn_context_key);
|
||||
g_wasi_nn_context_key = NULL;
|
||||
}
|
||||
wasi_nn_destroy();
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -1696,6 +1696,67 @@ wasm_runtime_instantiation_args_destroy(struct InstantiationArgs2 *p)
|
|||
wasm_runtime_free(p);
|
||||
}
|
||||
|
||||
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
|
||||
struct wasi_nn_graph_registry;
|
||||
|
||||
void
|
||||
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(struct wasi_nn_graph_registry *args)
|
||||
{
|
||||
memset(args, 0, sizeof(*args));
|
||||
}
|
||||
|
||||
bool
|
||||
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
|
||||
const char* target, uint32_t n_graphs,
|
||||
const char** graph_paths)
|
||||
{
|
||||
if (!registry || !encoding || !target || !graph_paths)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
registry->encoding = strdup(encoding);
|
||||
registry->target = strdup(target);
|
||||
registry->n_graphs = n_graphs;
|
||||
registry->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs);
|
||||
memset(registry->graph_paths, 0, sizeof(uint32_t*) * n_graphs);
|
||||
for (uint32_t i = 0; i < registry->n_graphs; i++)
|
||||
registry->graph_paths[i] = strdup(graph_paths[i]);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int
|
||||
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp)
|
||||
{
|
||||
struct wasi_nn_graph_registry *args = wasm_runtime_malloc(sizeof(*args));
|
||||
if (args == NULL) {
|
||||
return false;
|
||||
}
|
||||
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
|
||||
*registryp = args;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void
|
||||
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry)
|
||||
{
|
||||
if (registry)
|
||||
{
|
||||
for (uint32_t i = 0; i < registry->n_graphs; i++)
|
||||
if (registry->graph_paths[i])
|
||||
{
|
||||
// wasi_nn_graph_registry_unregister_graph(registry, registry->name[i]);
|
||||
free(registry->graph_paths[i]);
|
||||
}
|
||||
if (registry->encoding)
|
||||
free(registry->encoding);
|
||||
if (registry->target)
|
||||
free(registry->target);
|
||||
free(registry);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void
|
||||
wasm_runtime_instantiation_args_set_default_stack_size(
|
||||
struct InstantiationArgs2 *p, uint32 v)
|
||||
|
|
@ -1794,6 +1855,14 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
|
|||
wasi_args->set_by_user = true;
|
||||
}
|
||||
#endif /* WASM_ENABLE_LIBC_WASI != 0 */
|
||||
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
|
||||
void
|
||||
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
|
||||
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry)
|
||||
{
|
||||
p->nn_registry = *registry;
|
||||
}
|
||||
#endif
|
||||
|
||||
WASMModuleInstanceCommon *
|
||||
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
|
||||
|
|
@ -8080,3 +8149,114 @@ wasm_runtime_check_and_update_last_used_shared_heap(
|
|||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
bool
|
||||
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
||||
const char* encoding, const char* target,
|
||||
const uint32_t n_graphs, char* graph_paths[],
|
||||
char *error_buf, uint32_t error_buf_size)
|
||||
{
|
||||
WASINNGlobalContext *ctx;
|
||||
bool ret = false;
|
||||
|
||||
ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
|
||||
if (!ctx)
|
||||
return false;
|
||||
|
||||
ctx->encoding = strdup(encoding);
|
||||
ctx->target = strdup(target);
|
||||
ctx->n_graphs = n_graphs;
|
||||
ctx->loaded = (uint32_t*)malloc(sizeof(uint32_t) * n_graphs);
|
||||
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
|
||||
|
||||
ctx->graph_paths = (uint32_t**)malloc(sizeof(uint32_t*) * n_graphs);
|
||||
memset(ctx->graph_paths, 0, sizeof(uint32_t*) * n_graphs);
|
||||
for (uint32_t i = 0; i < n_graphs; i++)
|
||||
{
|
||||
ctx->graph_paths[i] = strdup(graph_paths[i]);
|
||||
}
|
||||
|
||||
wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);
|
||||
|
||||
ret = true;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void
|
||||
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
|
||||
{
|
||||
WASINNGlobalContext *wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(module_inst);
|
||||
|
||||
for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++)
|
||||
{
|
||||
// All graphs will be unregistered in deinit()
|
||||
if (wasi_nn_global_ctx->graph_paths[i])
|
||||
free(wasi_nn_global_ctx->graph_paths[i]);
|
||||
}
|
||||
free(wasi_nn_global_ctx->encoding);
|
||||
free(wasi_nn_global_ctx->target);
|
||||
free(wasi_nn_global_ctx->loaded);
|
||||
free(wasi_nn_global_ctx->graph_paths);
|
||||
|
||||
if (wasi_nn_global_ctx) {
|
||||
wasm_runtime_free(wasi_nn_global_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx)
|
||||
{
|
||||
if (wasi_nn_global_ctx)
|
||||
return wasi_nn_global_ctx->n_graphs;
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
|
||||
{
|
||||
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
|
||||
return wasi_nn_global_ctx->graph_paths[idx];
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
|
||||
{
|
||||
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
|
||||
return wasi_nn_global_ctx->loaded[idx];
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
uint32_t
|
||||
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
|
||||
{
|
||||
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
|
||||
wasi_nn_global_ctx->loaded[idx] = value;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx)
|
||||
{
|
||||
if (wasi_nn_global_ctx)
|
||||
return wasi_nn_global_ctx->encoding;
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx)
|
||||
{
|
||||
if (wasi_nn_global_ctx)
|
||||
return wasi_nn_global_ctx->target;
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -545,6 +545,17 @@ typedef struct WASMModuleInstMemConsumption {
|
|||
uint32 exports_size;
|
||||
} WASMModuleInstMemConsumption;
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
/* Per-instance wasi-nn settings recorded at instantiation time */
typedef struct WASINNGlobalContext {
    /* encoding name as a string */
    char *encoding;
    /* execution target name as a string */
    char *target;

    /* number of entries in loaded[] and graph_paths[] */
    uint32_t n_graphs;
    /* per-graph flag, indexed like graph_paths[] */
    uint32_t *loaded;
    /* per-graph model path */
    char **graph_paths;
} WASINNGlobalContext;
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_LIBC_WASI != 0
|
||||
#if WASM_ENABLE_UVWASI == 0
|
||||
typedef struct WASIContext {
|
||||
|
|
@ -612,11 +623,30 @@ WASMExecEnv *
|
|||
wasm_runtime_get_exec_env_tls(void);
|
||||
#endif
|
||||
|
||||
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
|
||||
/* wasi-nn settings supplied by the embedder before instantiation */
struct wasi_nn_graph_registry {
    /* encoding name as a string */
    char *encoding;
    /* execution target name as a string */
    char *target;

    /* model paths, n_graphs entries */
    char **graph_paths;
    /* number of entries in graph_paths[] */
    uint32_t n_graphs;
};
|
||||
|
||||
WASM_RUNTIME_API_EXTERN int
|
||||
wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
|
||||
#endif
|
||||
|
||||
struct InstantiationArgs2 {
|
||||
InstantiationArgs v1;
|
||||
#if WASM_ENABLE_LIBC_WASI != 0
|
||||
WASIArguments wasi;
|
||||
#endif
|
||||
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
|
||||
struct wasi_nn_graph_registry nn_registry;
|
||||
#endif
|
||||
};
|
||||
|
||||
void
|
||||
|
|
@ -775,6 +805,17 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
|
|||
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
|
||||
uint32 ns_lookup_pool_size);
|
||||
|
||||
#if (WASM_ENABLE_WASI_EPHEMERAL_NN != 0)
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
|
||||
struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
|
||||
const char* target, uint32_t n_graphs,
|
||||
const char** graph_paths);
|
||||
#endif
|
||||
|
||||
/* See wasm_export.h for description */
|
||||
WASM_RUNTIME_API_EXTERN WASMModuleInstanceCommon *
|
||||
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
|
||||
|
|
@ -1427,6 +1468,39 @@ wasm_runtime_check_and_update_last_used_shared_heap(
|
|||
uint8 **shared_heap_base_addr_adj_p, bool is_memory64);
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
||||
const char* encoding, const char* target,
|
||||
const uint32_t n_graphs, char* graph_paths[],
|
||||
char *error_buf, uint32_t error_buf_size);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
||||
WASINNGlobalContext *wasi_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(WASINNGlobalContext *wasi_nn_global_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_encoding(WASINNGlobalContext *wasi_nn_global_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_target(WASINNGlobalContext *wasi_nn_global_ctx);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -290,6 +290,8 @@ typedef struct InstantiationArgs {
|
|||
#endif /* INSTANTIATION_ARGS_OPTION_DEFINED */
|
||||
|
||||
struct InstantiationArgs2;
|
||||
struct WASINNGlobalContext;
|
||||
typedef struct WASINNGlobalContext *wasi_nn_global_context;
|
||||
|
||||
#ifndef WASM_VALKIND_T_DEFINED
|
||||
#define WASM_VALKIND_T_DEFINED
|
||||
|
|
@ -796,6 +798,55 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
|
|||
struct InstantiationArgs2 *p, const char *ns_lookup_pool[],
|
||||
uint32_t ns_lookup_pool_size);
|
||||
|
||||
// WASM_RUNTIME_API_EXTERN int
|
||||
// wasi_nn_graph_registry_create(struct wasi_nn_graph_registry **registryp);
|
||||
|
||||
// WASM_RUNTIME_API_EXTERN void
|
||||
// wasi_nn_graph_registry_destroy(struct wasi_nn_graph_registry *registry);
|
||||
|
||||
// WASM_RUNTIME_API_EXTERN void
|
||||
// wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
|
||||
// struct InstantiationArgs2 *p, struct wasi_nn_graph_registry *registry);
|
||||
|
||||
// WASM_RUNTIME_API_EXTERN bool
|
||||
// wasi_nn_graph_registry_set_args(struct wasi_nn_graph_registry *registry, const char* encoding,
|
||||
// const char* target, uint32_t n_graphs,
|
||||
// const char** graph_paths);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasm_runtime_init_wasi_nn_global_ctx(wasm_module_inst_t module_inst,
|
||||
const char* encoding, const char* target,
|
||||
const uint32_t n_graphs, char* graph_paths[],
|
||||
char *error_buf, uint32_t error_buf_size);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasm_runtime_destroy_wasi_nn_global_ctx(wasm_module_inst_t module_inst);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN void
|
||||
wasm_runtime_set_wasi_nn_global_ctx(wasm_module_inst_t module_inst,
|
||||
wasi_nn_global_context wasi_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN wasi_nn_global_context
|
||||
wasm_runtime_get_wasi_nn_global_ctx(const wasm_module_inst_t module_inst);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_context wasi_nn_global_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN uint32_t
|
||||
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_context wasi_nn_global_ctx, uint32_t idx, uint32_t value);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_context wasi_nn_global_ctx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char*
|
||||
wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_context wasi_nn_global_ctx);
|
||||
|
||||
/**
|
||||
* Instantiate a WASM module, with specified instantiation arguments
|
||||
*
|
||||
|
|
|
|||
|
|
@ -3300,6 +3300,18 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent,
|
|||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
/* Store the graphs' paths into ctx. Graphs will not be loaded until the user app calls load_by_name */
|
||||
// Do not consider load() for now
|
||||
struct wasi_nn_graph_registry *nn_registry = &args->nn_registry;
|
||||
if (!wasm_runtime_init_wasi_nn_global_ctx(
|
||||
(WASMModuleInstanceCommon *)module_inst, nn_registry->encoding,
|
||||
nn_registry->target, nn_registry->n_graphs, nn_registry->graph_paths,
|
||||
error_buf, error_buf_size)) {
|
||||
goto fail;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_DEBUG_INTERP != 0
|
||||
if (!is_sub_inst) {
|
||||
/* Add module instance into module's instance list */
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@
|
|||
|
||||
#include "wasi_nn.h"
|
||||
|
||||
#undef WASM_ENABLE_WASI_EPHEMERAL_NN
|
||||
#undef WASI_NN_NAME
|
||||
// #undef WASM_ENABLE_WASI_EPHEMERAL_NN
|
||||
// #undef WASI_NN_NAME
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
#else
|
||||
#define WASI_NN_IMPORT(name) \
|
||||
__attribute__((import_module("wasi_nn"), import_name(name)))
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ extern "C" {
|
|||
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
|
||||
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
|
||||
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
|
||||
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);
|
||||
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
@ -169,6 +169,7 @@ typedef enum WASI_NN_NAME(execution_target) {
|
|||
WASI_NN_TARGET_NAME(cpu) = 0,
|
||||
WASI_NN_TARGET_NAME(gpu),
|
||||
WASI_NN_TARGET_NAME(tpu),
|
||||
WASI_NN_TARGET_NAME(unsupported_target),
|
||||
} WASI_NN_NAME(execution_target);
|
||||
|
||||
// Bind a `graph` to the input and output tensors for an inference.
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@
|
|||
#include "wasm_export.h"
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It is deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#endif
|
||||
|
||||
#define HASHMAP_INITIAL_SIZE 20
|
||||
|
|
@ -35,6 +35,8 @@
|
|||
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
|
||||
#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
|
||||
|
||||
#define MAX_GLOBAL_GRAPHS_PER_INST 4 // ONNX only allows 4 graphs per instances
|
||||
|
||||
/* Global variables */
|
||||
static korp_mutex wasi_nn_lock;
|
||||
/*
|
||||
|
|
@ -208,6 +210,44 @@ wasi_nn_destroy()
|
|||
* - model file format
|
||||
* - on device ML framework
|
||||
*/
|
||||
static graph_encoding str2encoding(char* str_encoding)
|
||||
{
|
||||
if (!str_encoding) {
|
||||
NN_ERR_PRINTF("Got empty string encoding");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_encoding, "openvino"))
|
||||
return openvino;
|
||||
else if (!strcmp(str_encoding, "tensorflowlite"))
|
||||
return tensorflowlite;
|
||||
else if (!strcmp(str_encoding, "ggml"))
|
||||
return ggml;
|
||||
else if (!strcmp(str_encoding, "onnx"))
|
||||
return onnx;
|
||||
else
|
||||
return unknown_backend;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static execution_target str2target(char* str_target)
|
||||
{
|
||||
if (!str_target) {
|
||||
NN_ERR_PRINTF("Got empty string target");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!strcmp(str_target, "cpu"))
|
||||
return cpu;
|
||||
else if (!strcmp(str_target, "gpu"))
|
||||
return gpu;
|
||||
else if (!strcmp(str_target, "tpu"))
|
||||
return tpu;
|
||||
else
|
||||
return unsupported_target;
|
||||
// return autodetect;
|
||||
}
|
||||
|
||||
static graph_encoding
|
||||
choose_a_backend()
|
||||
{
|
||||
|
|
@ -565,17 +605,82 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|||
goto fail;
|
||||
}
|
||||
|
||||
res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
||||
wasi_nn_global_context wasi_nn_global_ctx = wasm_runtime_get_wasi_nn_global_ctx(instance);
|
||||
if (!wasi_nn_global_ctx) {
|
||||
NN_ERR_PRINTF("global context is invalid");
|
||||
res = not_found;
|
||||
goto fail;
|
||||
}
|
||||
graph_encoding encoding = str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding(wasi_nn_global_ctx));
|
||||
execution_target target = str2target(wasm_runtime_get_wasi_nn_global_ctx_target(wasi_nn_global_ctx));
|
||||
|
||||
// res = ensure_backend(instance, autodetect, wasi_nn_ctx);
|
||||
res = ensure_backend(instance, encoding, wasi_nn_ctx);
|
||||
if (res != success)
|
||||
goto fail;
|
||||
|
||||
bool is_loaded = false;
|
||||
uint32 model_idx = 0;
|
||||
char *global_model_path_i;
|
||||
// Assume filename got from user wasm app : max; sum; average; ...
|
||||
// Assume file path got from user cmd opt: /your/path1/max.tflite; /your/path2/sum.tflite; ......
|
||||
for (model_idx = 0; model_idx < wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx); model_idx++)
|
||||
{
|
||||
// Extract filename from file path
|
||||
global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(wasi_nn_global_ctx, model_idx);
|
||||
char *model_file_name;
|
||||
const char *slash = strrchr(global_model_path_i, '/');
|
||||
if (slash != NULL) {
|
||||
model_file_name = (char*)(slash + 1);
|
||||
}
|
||||
else
|
||||
model_file_name = global_model_path_i;
|
||||
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
|
||||
wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len,
|
||||
g);
|
||||
if (res != success)
|
||||
// Extract modelname from filename
|
||||
char* model_name = NULL;
|
||||
size_t model_name_len = 0;
|
||||
char* dot = strrchr(model_file_name, '.');
|
||||
if (dot)
|
||||
{
|
||||
model_name_len = dot - model_file_name;
|
||||
model_name = malloc(model_name_len + 1);
|
||||
strncpy(model_name, model_file_name, model_name_len);
|
||||
model_name[model_name_len] = '\0';
|
||||
}
|
||||
|
||||
if (model_name && strcmp(nul_terminated_name, model_name) == 0) {
|
||||
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!is_loaded && (model_idx < MAX_GLOBAL_GRAPHS_PER_INST))
|
||||
{
|
||||
NN_DBG_PRINTF("Model is not yet loaded, will add to global context");
|
||||
call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res,
|
||||
wasi_nn_ctx->backend_ctx, global_model_path_i, strlen(global_model_path_i),
|
||||
encoding, target, g);
|
||||
if (res != success)
|
||||
goto fail;
|
||||
|
||||
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(wasi_nn_global_ctx, model_idx, 1);
|
||||
res = success;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (is_loaded)
|
||||
{
|
||||
NN_DBG_PRINTF("Model is already loaded");
|
||||
res = success;
|
||||
}
|
||||
else if (model_idx >= MAX_GLOBAL_GRAPHS_PER_INST)
|
||||
{
|
||||
// No enlarge for now
|
||||
NN_ERR_PRINTF("No enough space for new model");
|
||||
res = too_large;
|
||||
}
|
||||
goto fail;
|
||||
|
||||
res = success;
|
||||
}
|
||||
fail:
|
||||
if (nul_terminated_name != NULL) {
|
||||
wasm_runtime_free(nul_terminated_name);
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
|
|||
execution_target target, graph *g);
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name(void *tflite_ctx, const char *name, uint32_t namelen, graph *g);
|
||||
load_by_name(void *tflite_ctx, const char *name, uint32_t namelen,
|
||||
graph_encoding encoding, execution_target target, graph *g);
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name_with_config(void *ctx, const char *name, uint32_t namelen,
|
||||
|
|
|
|||
|
|
@ -338,7 +338,8 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
|
|||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
|
||||
load_by_name(void *ctx, const char *filename, uint32_t filename_len,
|
||||
graph_encoding encoding, execution_target target, graph *g)
|
||||
{
|
||||
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
|
||||
|
||||
|
|
|
|||
|
|
@ -334,7 +334,8 @@ load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
|
|||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
|
||||
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len,
|
||||
graph_encoding encoding, execution_target target, graph *g)
|
||||
{
|
||||
if (!onnx_ctx) {
|
||||
return runtime_error;
|
||||
|
|
|
|||
|
|
@ -306,7 +306,8 @@ fail:
|
|||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name(void *ctx, const char *filename, uint32_t filename_len, graph *g)
|
||||
load_by_name(void *ctx, const char *filename, uint32_t filename_len,
|
||||
graph_encoding encoding, execution_target target, graph *g)
|
||||
{
|
||||
OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx;
|
||||
struct OpenVINOGraph *graph;
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ typedef struct {
|
|||
|
||||
typedef wasi_nn_error (*LOAD)(void *, graph_builder_array *, graph_encoding,
|
||||
execution_target, graph *);
|
||||
typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph *);
|
||||
typedef wasi_nn_error (*LOAD_BY_NAME)(void *, const char *, uint32_t, graph_encoding,
|
||||
execution_target, graph *);
|
||||
typedef wasi_nn_error (*LOAD_BY_NAME_WITH_CONFIG)(void *, const char *,
|
||||
uint32_t, void *, uint32_t,
|
||||
graph *);
|
||||
|
|
|
|||
|
|
@ -164,8 +164,8 @@ load(void *tflite_ctx, graph_builder_array *builder, graph_encoding encoding,
|
|||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
|
||||
graph *g)
|
||||
load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
|
||||
graph_encoding encoding, execution_target target,graph *g)
|
||||
{
|
||||
TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
|
||||
|
||||
|
|
@ -183,7 +183,7 @@ load_by_name(void *tflite_ctx, const char *filename, uint32_t filename_len,
|
|||
}
|
||||
|
||||
// Use CPU as default
|
||||
tfl_ctx->models[*g].target = cpu;
|
||||
tfl_ctx->models[*g].target = target;
|
||||
return success;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
tensorflow==2.12.1
|
||||
tensorflow==2.14.0
|
||||
numpy==1.24.4
|
||||
|
|
|
|||
|
|
@ -13,16 +13,16 @@
|
|||
#include "logger.h"
|
||||
|
||||
void
|
||||
test_sum(execution_target target)
|
||||
test_sum()
|
||||
{
|
||||
int dims[] = { 1, 5, 5, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output = run_inference(target, input.input_tensor, input.dim,
|
||||
&output_size, "./models/sum.tflite", 1);
|
||||
float *output = run_inference(input.input_tensor, input.dim,
|
||||
&output_size, "sum", 1);
|
||||
|
||||
assert(output_size == 1);
|
||||
assert((output_size / sizeof(float)) == 1);
|
||||
assert(fabs(output[0] - 300.0) < EPSILON);
|
||||
|
||||
free(input.dim);
|
||||
|
|
@ -31,16 +31,16 @@ test_sum(execution_target target)
|
|||
}
|
||||
|
||||
void
|
||||
test_max(execution_target target)
|
||||
test_max()
|
||||
{
|
||||
int dims[] = { 1, 5, 5, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output = run_inference(target, input.input_tensor, input.dim,
|
||||
&output_size, "./models/max.tflite", 1);
|
||||
float *output = run_inference(input.input_tensor, input.dim,
|
||||
&output_size, "max", 1);
|
||||
|
||||
assert(output_size == 1);
|
||||
assert((output_size / sizeof(float)) == 1);
|
||||
assert(fabs(output[0] - 24.0) < EPSILON);
|
||||
NN_INFO_PRINTF("Result: max is %f", output[0]);
|
||||
|
||||
|
|
@ -50,16 +50,16 @@ test_max(execution_target target)
|
|||
}
|
||||
|
||||
void
|
||||
test_average(execution_target target)
|
||||
test_average()
|
||||
{
|
||||
int dims[] = { 1, 5, 5, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output = run_inference(target, input.input_tensor, input.dim,
|
||||
&output_size, "./models/average.tflite", 1);
|
||||
float *output = run_inference(input.input_tensor, input.dim,
|
||||
&output_size, "average", 1);
|
||||
|
||||
assert(output_size == 1);
|
||||
assert((output_size / sizeof(float)) == 1);
|
||||
assert(fabs(output[0] - 12.0) < EPSILON);
|
||||
NN_INFO_PRINTF("Result: average is %f", output[0]);
|
||||
|
||||
|
|
@ -69,16 +69,16 @@ test_average(execution_target target)
|
|||
}
|
||||
|
||||
void
|
||||
test_mult_dimensions(execution_target target)
|
||||
test_mult_dimensions()
|
||||
{
|
||||
int dims[] = { 1, 3, 3, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output = run_inference(target, input.input_tensor, input.dim,
|
||||
&output_size, "./models/mult_dim.tflite", 1);
|
||||
float *output = run_inference(input.input_tensor, input.dim,
|
||||
&output_size, "mult_dim", 1);
|
||||
|
||||
assert(output_size == 9);
|
||||
assert((output_size / sizeof(float)) == 9);
|
||||
for (int i = 0; i < 9; i++)
|
||||
assert(fabs(output[i] - i) < EPSILON);
|
||||
|
||||
|
|
@ -88,16 +88,16 @@ test_mult_dimensions(execution_target target)
|
|||
}
|
||||
|
||||
void
|
||||
test_mult_outputs(execution_target target)
|
||||
test_mult_outputs()
|
||||
{
|
||||
int dims[] = { 1, 4, 4, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output = run_inference(target, input.input_tensor, input.dim,
|
||||
&output_size, "./models/mult_out.tflite", 2);
|
||||
float *output = run_inference(input.input_tensor, input.dim,
|
||||
&output_size, "mult_out", 2);
|
||||
|
||||
assert(output_size == 8);
|
||||
assert((output_size / sizeof(float)) == 8);
|
||||
// first tensor check
|
||||
for (int i = 0; i < 4; i++)
|
||||
assert(fabs(output[i] - (i * 4 + 24)) < EPSILON);
|
||||
|
|
@ -113,30 +113,18 @@ test_mult_outputs(execution_target target)
|
|||
int
|
||||
main()
|
||||
{
|
||||
char *env = getenv("TARGET");
|
||||
if (env == NULL) {
|
||||
NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu]\"");
|
||||
return 1;
|
||||
}
|
||||
execution_target target;
|
||||
if (strcmp(env, "cpu") == 0)
|
||||
target = cpu;
|
||||
else if (strcmp(env, "gpu") == 0)
|
||||
target = gpu;
|
||||
else {
|
||||
NN_ERR_PRINTF("Wrong target!");
|
||||
return 1;
|
||||
}
|
||||
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\"");
|
||||
|
||||
NN_INFO_PRINTF("################### Testing sum...");
|
||||
test_sum(target);
|
||||
test_sum();
|
||||
NN_INFO_PRINTF("################### Testing max...");
|
||||
test_max(target);
|
||||
test_max();
|
||||
NN_INFO_PRINTF("################### Testing average...");
|
||||
test_average(target);
|
||||
test_average();
|
||||
NN_INFO_PRINTF("################### Testing multiple dimensions...");
|
||||
test_mult_dimensions(target);
|
||||
test_mult_dimensions();
|
||||
NN_INFO_PRINTF("################### Testing multiple outputs...");
|
||||
test_mult_outputs(target);
|
||||
test_mult_outputs();
|
||||
|
||||
NN_INFO_PRINTF("Tests: passed!");
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -16,15 +16,15 @@
|
|||
#define EPSILON 1e-2
|
||||
|
||||
void
|
||||
test_average_quantized(execution_target target)
|
||||
test_average_quantized()
|
||||
{
|
||||
int dims[] = { 1, 5, 5, 1 };
|
||||
input_info input = create_input(dims);
|
||||
|
||||
uint32_t output_size = 0;
|
||||
float *output =
|
||||
run_inference(target, input.input_tensor, input.dim, &output_size,
|
||||
"./models/quantized_model.tflite", 1);
|
||||
run_inference(input.input_tensor, input.dim, &output_size,
|
||||
"quantized_model", 1);
|
||||
|
||||
NN_INFO_PRINTF("Output size: %d", output_size);
|
||||
NN_INFO_PRINTF("Result: average is %f", output[0]);
|
||||
|
|
@ -39,24 +39,10 @@ test_average_quantized(execution_target target)
|
|||
int
|
||||
main()
|
||||
{
|
||||
char *env = getenv("TARGET");
|
||||
if (env == NULL) {
|
||||
NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\"");
|
||||
return 1;
|
||||
}
|
||||
execution_target target;
|
||||
if (strcmp(env, "cpu") == 0)
|
||||
target = cpu;
|
||||
else if (strcmp(env, "gpu") == 0)
|
||||
target = gpu;
|
||||
else if (strcmp(env, "tpu") == 0)
|
||||
target = tpu;
|
||||
else {
|
||||
NN_ERR_PRINTF("Wrong target!");
|
||||
return 1;
|
||||
}
|
||||
NN_INFO_PRINTF("Usage:\niwasm --native-lib=./libwasi_nn_tflite.so --wasi-nn-graph=encoding:target:model_path1:model_path2:...:model_pathn test_tensorflow.wasm\"");
|
||||
|
||||
NN_INFO_PRINTF("################### Testing quantized model...");
|
||||
test_average_quantized(target);
|
||||
test_average_quantized();
|
||||
|
||||
NN_INFO_PRINTF("Tests: passed!");
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -5,17 +5,15 @@
|
|||
|
||||
#include "utils.h"
|
||||
#include "logger.h"
|
||||
#include "wasi_nn.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
wasi_nn_error
|
||||
wasm_load(char *model_name, graph *g, execution_target target)
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target)
|
||||
{
|
||||
FILE *pFile = fopen(model_name, "r");
|
||||
if (pFile == NULL)
|
||||
return invalid_argument;
|
||||
return WASI_NN_ERROR_NAME(invalid_argument);
|
||||
|
||||
uint8_t *buffer;
|
||||
size_t result;
|
||||
|
|
@ -24,20 +22,29 @@ wasm_load(char *model_name, graph *g, execution_target target)
|
|||
buffer = (uint8_t *)malloc(sizeof(uint8_t) * MAX_MODEL_SIZE);
|
||||
if (buffer == NULL) {
|
||||
fclose(pFile);
|
||||
return too_large;
|
||||
return WASI_NN_ERROR_NAME(too_large);
|
||||
}
|
||||
|
||||
result = fread(buffer, 1, MAX_MODEL_SIZE, pFile);
|
||||
if (result <= 0) {
|
||||
fclose(pFile);
|
||||
free(buffer);
|
||||
return too_large;
|
||||
return WASI_NN_ERROR_NAME(too_large);
|
||||
}
|
||||
|
||||
graph_builder_array arr;
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
WASI_NN_NAME(graph_builder) arr;
|
||||
|
||||
arr.buf = buffer;
|
||||
arr.size = result;
|
||||
|
||||
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, result, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
|
||||
// WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, 1, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
|
||||
#else
|
||||
WASI_NN_NAME(graph_builder_array) arr;
|
||||
|
||||
arr.size = 1;
|
||||
arr.buf = (graph_builder *)malloc(sizeof(graph_builder));
|
||||
arr.buf = (WASI_NN_NAME(graph_builder) *)malloc(sizeof(WASI_NN_NAME(graph_builder)));
|
||||
if (arr.buf == NULL) {
|
||||
fclose(pFile);
|
||||
free(buffer);
|
||||
|
|
@ -47,7 +54,8 @@ wasm_load(char *model_name, graph *g, execution_target target)
|
|||
arr.buf[0].size = result;
|
||||
arr.buf[0].buf = buffer;
|
||||
|
||||
wasi_nn_error res = load(&arr, tensorflowlite, target, g);
|
||||
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load)(&arr, WASI_NN_ENCODING_NAME(tensorflowlite), target, g);
|
||||
#endif
|
||||
|
||||
fclose(pFile);
|
||||
free(buffer);
|
||||
|
|
@ -55,77 +63,97 @@ wasm_load(char *model_name, graph *g, execution_target target)
|
|||
return res;
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
wasm_load_by_name(const char *model_name, graph *g)
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_load_by_name(const char *model_name, WASI_NN_NAME(graph) *g)
|
||||
{
|
||||
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
|
||||
WASI_NN_ERROR_TYPE res = WASI_NN_NAME(load_by_name)(model_name, strlen(model_name), g);
|
||||
return res;
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
wasm_init_execution_context(graph g, graph_execution_context *ctx)
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx)
|
||||
{
|
||||
return init_execution_context(g, ctx);
|
||||
return WASI_NN_NAME(init_execution_context)(g, ctx);
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim)
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim)
|
||||
{
|
||||
tensor_dimensions dims;
|
||||
WASI_NN_NAME(tensor_dimensions) dims;
|
||||
dims.size = INPUT_TENSOR_DIMS;
|
||||
dims.buf = (uint32_t *)malloc(dims.size * sizeof(uint32_t));
|
||||
if (dims.buf == NULL)
|
||||
return too_large;
|
||||
return WASI_NN_ERROR_NAME(too_large);
|
||||
|
||||
tensor tensor;
|
||||
WASI_NN_NAME(tensor) tensor;
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
tensor.dimensions = dims;
|
||||
for (int i = 0; i < tensor.dimensions.size; ++i)
|
||||
tensor.dimensions.buf[i] = dim[i];
|
||||
tensor.type = WASI_NN_TYPE_NAME(fp32);
|
||||
tensor.data.buf = (uint8_t *)input_tensor;
|
||||
|
||||
uint32_t tmp_size = 1;
|
||||
if (dim)
|
||||
for (int i = 0; i < INPUT_TENSOR_DIMS; ++i)
|
||||
tmp_size *= dim[i];
|
||||
|
||||
tensor.data.size = (tmp_size * sizeof(float));
|
||||
#else
|
||||
tensor.dimensions = &dims;
|
||||
for (int i = 0; i < tensor.dimensions->size; ++i)
|
||||
tensor.dimensions->buf[i] = dim[i];
|
||||
tensor.type = fp32;
|
||||
tensor.type = WASI_NN_TYPE_NAME(fp32);
|
||||
tensor.data = (uint8_t *)input_tensor;
|
||||
wasi_nn_error err = set_input(ctx, 0, &tensor);
|
||||
#endif
|
||||
|
||||
WASI_NN_ERROR_TYPE err = WASI_NN_NAME(set_input)(ctx, 0, &tensor);
|
||||
|
||||
free(dims.buf);
|
||||
return err;
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
wasm_compute(graph_execution_context ctx)
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_compute(WASI_NN_NAME(graph_execution_context) ctx)
|
||||
{
|
||||
return compute(ctx);
|
||||
return WASI_NN_NAME(compute)(ctx);
|
||||
}
|
||||
|
||||
wasi_nn_error
|
||||
wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor,
|
||||
uint32_t *out_size)
|
||||
{
|
||||
return get_output(ctx, index, (uint8_t *)out_tensor, out_size);
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, MAX_OUTPUT_TENSOR_SIZE, out_size);
|
||||
#else
|
||||
return WASI_NN_NAME(get_output)(ctx, index, (uint8_t *)out_tensor, out_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
float *
|
||||
run_inference(execution_target target, float *input, uint32_t *input_size,
|
||||
run_inference(float *input, uint32_t *input_size,
|
||||
uint32_t *output_size, char *model_name,
|
||||
uint32_t num_output_tensors)
|
||||
{
|
||||
graph graph;
|
||||
WASI_NN_NAME(graph) graph;
|
||||
|
||||
if (wasm_load_by_name(model_name, &graph) != success) {
|
||||
if (wasm_load_by_name(model_name, &graph) != WASI_NN_ERROR_NAME(success)) {
|
||||
NN_ERR_PRINTF("Error when loading model.");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
graph_execution_context ctx;
|
||||
if (wasm_init_execution_context(graph, &ctx) != success) {
|
||||
WASI_NN_NAME(graph_execution_context) ctx;
|
||||
if (wasm_init_execution_context(graph, &ctx) != WASI_NN_ERROR_NAME(success)) {
|
||||
NN_ERR_PRINTF("Error when initialixing execution context.");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (wasm_set_input(ctx, input, input_size) != success) {
|
||||
if (wasm_set_input(ctx, input, input_size) != WASI_NN_ERROR_NAME(success)) {
|
||||
NN_ERR_PRINTF("Error when setting input tensor.");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (wasm_compute(ctx) != success) {
|
||||
if (wasm_compute(ctx) != WASI_NN_ERROR_NAME(success)) {
|
||||
NN_ERR_PRINTF("Error when running inference.");
|
||||
exit(1);
|
||||
}
|
||||
|
|
@ -140,7 +168,7 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
|
|||
for (int i = 0; i < num_output_tensors; ++i) {
|
||||
*output_size = MAX_OUTPUT_TENSOR_SIZE - *output_size;
|
||||
if (wasm_get_output(ctx, i, &out_tensor[offset], output_size)
|
||||
!= success) {
|
||||
!= WASI_NN_ERROR_NAME(success)) {
|
||||
NN_ERR_PRINTF("Error when getting index %d.", i);
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "wasi_ephemeral_nn.h"
|
||||
#include "wasi_nn_types.h"
|
||||
|
||||
#define MAX_MODEL_SIZE 85000000
|
||||
|
|
@ -23,26 +24,26 @@ typedef struct {
|
|||
|
||||
/* wasi-nn wrappers */
|
||||
|
||||
wasi_nn_error
|
||||
wasm_load(char *model_name, graph *g, execution_target target);
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_load(char *model_name, WASI_NN_NAME(graph) *g, WASI_NN_NAME(execution_target) target);
|
||||
|
||||
wasi_nn_error
|
||||
wasm_init_execution_context(graph g, graph_execution_context *ctx);
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_init_execution_context(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) *ctx);
|
||||
|
||||
wasi_nn_error
|
||||
wasm_set_input(graph_execution_context ctx, float *input_tensor, uint32_t *dim);
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_set_input(WASI_NN_NAME(graph_execution_context) ctx, float *input_tensor, uint32_t *dim);
|
||||
|
||||
wasi_nn_error
|
||||
wasm_compute(graph_execution_context ctx);
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_compute(WASI_NN_NAME(graph_execution_context) ctx);
|
||||
|
||||
wasi_nn_error
|
||||
wasm_get_output(graph_execution_context ctx, uint32_t index, float *out_tensor,
|
||||
WASI_NN_ERROR_TYPE
|
||||
wasm_get_output(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index, float *out_tensor,
|
||||
uint32_t *out_size);
|
||||
|
||||
/* Utils */
|
||||
|
||||
float *
|
||||
run_inference(execution_target target, float *input, uint32_t *input_size,
|
||||
run_inference(float *input, uint32_t *input_size,
|
||||
uint32_t *output_size, char *model_name,
|
||||
uint32_t num_output_tensors);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@
|
|||
#include "../common/libc_wasi.c"
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
#include "wasi_ephemeral_nn.h"
|
||||
#endif
|
||||
|
||||
#include "../common/wasm_proposal.c"
|
||||
|
||||
#if BH_HAS_DLFCN
|
||||
|
|
@ -115,6 +119,12 @@ print_help(void)
|
|||
#endif
|
||||
#if WASM_ENABLE_STATIC_PGO != 0
|
||||
printf(" --gen-prof-file=<path> Generate LLVM PGO (Profile-Guided Optimization) profile file\n");
|
||||
#endif
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
printf(" --wasi-nn-graph=encoding:target:<model_path1>:<model_path2>:...:<model_pathn>\n");
|
||||
printf(" Set encoding, target and model_paths for wasi-nn. target can be\n");
|
||||
printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n");
|
||||
printf(" tensorflow|pytorch|ggml|autodetect\n");
|
||||
#endif
|
||||
printf(" --version Show version information\n");
|
||||
return 1;
|
||||
|
|
@ -635,6 +645,13 @@ main(int argc, char *argv[])
|
|||
int timeout_ms = -1;
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
struct wasi_nn_graph_registry *nn_registry;
|
||||
char *encoding, *target;
|
||||
uint32_t n_models = 0;
|
||||
char **model_paths;
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_LIBC_WASI != 0
|
||||
memset(&wasi_parse_ctx, 0, sizeof(wasi_parse_ctx));
|
||||
#endif
|
||||
|
|
@ -825,6 +842,37 @@ main(int argc, char *argv[])
|
|||
wasm_proposal_print_status();
|
||||
return 0;
|
||||
}
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
else if (!strncmp(argv[0], "--wasi-nn-graph=", 16)) {
|
||||
char *token;
|
||||
char *saveptr = NULL;
|
||||
int token_count = 0;
|
||||
char *tokens[12] = {0};
|
||||
|
||||
// encoding:tensorflowlite|openvino|llama target:cpu|gpu|tpu
|
||||
// --wasi-nn-graph=encoding:target:model_file_path1:model_file_path2:model_file_path3:......
|
||||
token = strtok_r(argv[0] + 16, ":", &saveptr);
|
||||
while (token) {
|
||||
tokens[token_count] = token;
|
||||
token_count++;
|
||||
token = strtok_r(NULL, ":", &saveptr);
|
||||
}
|
||||
|
||||
if (token_count < 2) {
|
||||
return print_help();
|
||||
}
|
||||
|
||||
n_models = token_count - 2;
|
||||
encoding = strdup(tokens[0]);
|
||||
target = strdup(tokens[1]);
|
||||
model_paths = malloc(n_models * sizeof(void*));
|
||||
for (int i = 0; i < n_models; i++) {
|
||||
model_paths[i] = strdup(tokens[i + 2]);
|
||||
}
|
||||
if (token)
|
||||
free(token);
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
#if WASM_ENABLE_LIBC_WASI != 0
|
||||
libc_wasi_parse_result_t result =
|
||||
|
|
@ -974,6 +1022,11 @@ main(int argc, char *argv[])
|
|||
libc_wasi_set_init_args(inst_args, argc, argv, &wasi_parse_ctx);
|
||||
#endif
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
wasi_nn_graph_registry_create(&nn_registry);
|
||||
wasi_nn_graph_registry_set_args(nn_registry, encoding, target, n_models, model_paths);
|
||||
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(inst_args, nn_registry);
|
||||
#endif
|
||||
/* instantiate the module */
|
||||
wasm_module_inst = wasm_runtime_instantiate_ex2(
|
||||
wasm_module, inst_args, error_buf, sizeof(error_buf));
|
||||
|
|
@ -1092,6 +1145,15 @@ fail5:
|
|||
#endif
|
||||
#if WASM_ENABLE_DEBUG_INTERP != 0
|
||||
fail4:
|
||||
#endif
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
wasi_nn_graph_registry_destroy(nn_registry);
|
||||
for (uint32_t i = 0; i < n_models; i++)
|
||||
if (model_paths[i])
|
||||
free(model_paths[i]);
|
||||
free(model_paths);
|
||||
free(encoding);
|
||||
free(target);
|
||||
#endif
|
||||
/* destroy the module instance */
|
||||
wasm_runtime_deinstantiate(wasm_module_inst);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user