mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2026-04-18 18:18:44 +00:00
Add model_name option for --wasi-nn-graph to make it more flexible and simpler
This commit is contained in:
parent
c024c94150
commit
736f357a58
|
|
@ -1806,23 +1806,27 @@ wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
|
|||
|
||||
bool
|
||||
wasi_nn_graph_registry_set_args(WASINNArguments *registry,
|
||||
const char **encoding, const char **target,
|
||||
uint32_t n_graphs, const char **graph_paths)
|
||||
const char **model_names, const char **encoding,
|
||||
const char **target, uint32_t n_graphs,
|
||||
const char **graph_paths)
|
||||
{
|
||||
if (!registry || !encoding || !target || !graph_paths) {
|
||||
if (!registry || !model_names || !encoding || !target || !graph_paths) {
|
||||
return false;
|
||||
}
|
||||
|
||||
registry->n_graphs = n_graphs;
|
||||
registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
memset(registry->target, 0, sizeof(uint32_t *) * n_graphs);
|
||||
memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs);
|
||||
memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs);
|
||||
memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
|
||||
|
||||
for (uint32_t i = 0; i < registry->n_graphs; i++) {
|
||||
registry->graph_paths[i] = strdup(graph_paths[i]);
|
||||
registry->model_names[i] = strdup(model_names[i]);
|
||||
registry->encoding[i] = strdup(encoding[i]);
|
||||
registry->target[i] = strdup(target[i]);
|
||||
}
|
||||
|
|
@ -1849,6 +1853,8 @@ wasi_nn_graph_registry_destroy(WASINNArguments *registry)
|
|||
for (uint32_t i = 0; i < registry->n_graphs; i++)
|
||||
if (registry->graph_paths[i]) {
|
||||
free(registry->graph_paths[i]);
|
||||
if (registry->model_names[i])
|
||||
free(registry->model_names[i]);
|
||||
if (registry->encoding[i])
|
||||
free(registry->encoding[i]);
|
||||
if (registry->target[i])
|
||||
|
|
@ -8155,6 +8161,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
|
|||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
bool
|
||||
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
||||
const char **model_names,
|
||||
const char **encoding, const char **target,
|
||||
const uint32_t n_graphs,
|
||||
char *graph_paths[], char *error_buf,
|
||||
|
|
@ -8175,11 +8182,14 @@ wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
|||
memset(ctx->target, 0, sizeof(uint32_t) * n_graphs);
|
||||
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
|
||||
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
|
||||
ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs);
|
||||
ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
|
||||
memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs);
|
||||
|
||||
for (uint32_t i = 0; i < n_graphs; i++) {
|
||||
ctx->graph_paths[i] = strdup(graph_paths[i]);
|
||||
ctx->model_names[i] = strdup(model_names[i]);
|
||||
ctx->target[i] = strdup(target[i]);
|
||||
ctx->encoding[i] = strdup(encoding[i]);
|
||||
}
|
||||
|
|
@ -8201,14 +8211,17 @@ wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
|
|||
// All graphs will be unregistered in deinit()
|
||||
if (wasi_nn_global_ctx->graph_paths[i])
|
||||
free(wasi_nn_global_ctx->graph_paths[i]);
|
||||
if (wasi_nn_global_ctx->model_names[i])
|
||||
free(wasi_nn_global_ctx->model_names[i]);
|
||||
if (wasi_nn_global_ctx->encoding[i])
|
||||
free(wasi_nn_global_ctx->encoding[i]);
|
||||
if (wasi_nn_global_ctx->encoding[i])
|
||||
if (wasi_nn_global_ctx->target[i])
|
||||
free(wasi_nn_global_ctx->target[i]);
|
||||
}
|
||||
free(wasi_nn_global_ctx->encoding);
|
||||
free(wasi_nn_global_ctx->target);
|
||||
free(wasi_nn_global_ctx->loaded);
|
||||
free(wasi_nn_global_ctx->model_names);
|
||||
free(wasi_nn_global_ctx->graph_paths);
|
||||
|
||||
if (wasi_nn_global_ctx) {
|
||||
|
|
@ -8226,6 +8239,16 @@ wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
|
|||
return -1;
|
||||
}
|
||||
|
||||
char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
|
||||
{
|
||||
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
|
||||
return wasi_nn_global_ctx->model_names[idx];
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
|
||||
|
|
|
|||
|
|
@ -547,6 +547,7 @@ typedef struct WASMModuleInstMemConsumption {
|
|||
|
||||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
typedef struct WASINNGlobalContext {
|
||||
char **model_names;
|
||||
char **encoding;
|
||||
char **target;
|
||||
|
||||
|
|
@ -625,6 +626,7 @@ wasm_runtime_get_exec_env_tls(void);
|
|||
|
||||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
typedef struct WASINNArguments {
|
||||
char **model_names;
|
||||
char **encoding;
|
||||
char **target;
|
||||
|
||||
|
|
@ -812,8 +814,9 @@ wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
|
|||
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasi_nn_graph_registry_set_args(WASINNArguments *registry,
|
||||
const char **encoding, const char **target,
|
||||
uint32_t n_graphs, const char **graph_paths);
|
||||
const char **model_names, const char **encoding,
|
||||
const char **target, uint32_t n_graphs,
|
||||
const char **graph_paths);
|
||||
#endif
|
||||
|
||||
/* See wasm_export.h for description */
|
||||
|
|
@ -1471,6 +1474,7 @@ wasm_runtime_check_and_update_last_used_shared_heap(
|
|||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
WASM_RUNTIME_API_EXTERN bool
|
||||
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
|
||||
const char **model_names,
|
||||
const char **encoding, const char **target,
|
||||
const uint32_t n_graphs,
|
||||
char *graph_paths[], char *error_buf,
|
||||
|
|
@ -1494,6 +1498,10 @@ WASM_RUNTIME_API_EXTERN uint32_t
|
|||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx);
|
||||
|
||||
/*
 * Get the model name stored at index idx of the given WASI-NN global
 * context. Yields NULL if the context is NULL or idx is not below the
 * context's graph count.
 */
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
|
|
|||
|
|
@ -805,6 +805,10 @@ WASM_RUNTIME_API_EXTERN uint32_t
|
|||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx);
|
||||
|
||||
/**
 * Get the name of the model registered at a given index of a WASI-NN
 * global context.
 *
 * @param wasi_nn_global_ctx the WASI-NN global context to query
 * @param idx index of the graph entry whose model name is requested
 *
 * @return the model name stored in the context for that index, or NULL if
 *         the context is NULL or idx is not below the context's graph count
 */
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
||||
WASM_RUNTIME_API_EXTERN char *
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
|
||||
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx);
|
||||
|
|
|
|||
|
|
@ -3305,8 +3305,8 @@ wasm_instantiate(WASMModule *module, WASMModuleInstance *parent,
|
|||
* load_by_name */
|
||||
WASINNArguments *nn_registry = &args->nn_registry;
|
||||
if (!wasm_runtime_init_wasi_nn_global_ctx(
|
||||
(WASMModuleInstanceCommon *)module_inst, nn_registry->encoding,
|
||||
nn_registry->target, nn_registry->n_graphs,
|
||||
(WASMModuleInstanceCommon *)module_inst, nn_registry->model_names,
|
||||
nn_registry->encoding, nn_registry->target, nn_registry->n_graphs,
|
||||
nn_registry->graph_paths, error_buf, error_buf_size)) {
|
||||
goto fail;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -617,42 +617,21 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len,
|
|||
|
||||
bool is_loaded = false;
|
||||
uint32 model_idx = 0;
|
||||
char *global_model_path_i;
|
||||
uint32_t global_n_graphs =
|
||||
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(wasi_nn_global_ctx);
|
||||
// Model got from user wasm app : modelA; modelB...
|
||||
// Filelist got from user cmd opt: /path1/modelA.tflite;
|
||||
// /path/modelB.tflite; ......
|
||||
for (model_idx = 0; model_idx < global_n_graphs; model_idx++) {
|
||||
// Extract filename from file path
|
||||
global_model_path_i = wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
|
||||
char *model_name = wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
|
||||
wasi_nn_global_ctx, model_idx);
|
||||
char *model_file_name;
|
||||
const char *slash = strrchr(global_model_path_i, '/');
|
||||
if (slash != NULL) {
|
||||
model_file_name = (char *)(slash + 1);
|
||||
}
|
||||
else
|
||||
model_file_name = global_model_path_i;
|
||||
|
||||
// Extract modelname from filename
|
||||
char *model_name = NULL;
|
||||
size_t model_name_len = 0;
|
||||
char *dot = strrchr(model_file_name, '.');
|
||||
if (dot) {
|
||||
model_name_len = dot - model_file_name;
|
||||
model_name = malloc(model_name_len + 1);
|
||||
strncpy(model_name, model_file_name, model_name_len);
|
||||
model_name[model_name_len] = '\0';
|
||||
}
|
||||
|
||||
if (model_name && strcmp(nul_terminated_name, model_name) != 0) {
|
||||
free(model_name);
|
||||
continue;
|
||||
}
|
||||
|
||||
is_loaded = wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
|
||||
wasi_nn_global_ctx, model_idx);
|
||||
free(model_name);
|
||||
char *global_model_path_i =
|
||||
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
|
||||
wasi_nn_global_ctx, model_idx);
|
||||
|
||||
graph_encoding encoding =
|
||||
str2encoding(wasm_runtime_get_wasi_nn_global_ctx_encoding_i(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ typedef struct {
|
|||
} libc_wasi_parse_context_t;
|
||||
|
||||
typedef struct {
|
||||
const char *model_names[10];
|
||||
const char *encoding[10];
|
||||
const char *target[10];
|
||||
const char *graph_paths[10];
|
||||
|
|
@ -208,19 +209,23 @@ wasi_nn_parse(char **argv, wasi_nn_parse_context_t *ctx)
|
|||
// --wasi-nn-graph=encoding2:target2:model_file_path2 ...
|
||||
token = strtok_r(argv[0] + 16, ":", &saveptr);
|
||||
while (token) {
|
||||
tokens[token_count] = token;
|
||||
token_count++;
|
||||
token = strtok_r(NULL, ":", &saveptr);
|
||||
if (strlen(token) > 0) {
|
||||
tokens[token_count] = token;
|
||||
token_count++;
|
||||
token = strtok_r(NULL, ":", &saveptr);
|
||||
}
|
||||
}
|
||||
|
||||
if (token_count != 3) {
|
||||
if (token_count != 4) {
|
||||
ret = LIBC_WASI_PARSE_RESULT_NEED_HELP;
|
||||
printf("4 arguments are needed for wasi-nn.\n");
|
||||
goto fail;
|
||||
}
|
||||
|
||||
ctx->encoding[ctx->n_graphs] = strdup(tokens[0]);
|
||||
ctx->target[ctx->n_graphs] = strdup(tokens[1]);
|
||||
ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[2]);
|
||||
ctx->model_names[ctx->n_graphs] = strdup(tokens[0]);
|
||||
ctx->encoding[ctx->n_graphs] = strdup(tokens[1]);
|
||||
ctx->target[ctx->n_graphs] = strdup(tokens[2]);
|
||||
ctx->graph_paths[ctx->n_graphs++] = strdup(tokens[3]);
|
||||
|
||||
fail:
|
||||
if (token)
|
||||
|
|
@ -234,12 +239,15 @@ wasi_nn_set_init_args(struct InstantiationArgs2 *args,
|
|||
struct WASINNArguments *nn_registry,
|
||||
wasi_nn_parse_context_t *ctx)
|
||||
{
|
||||
wasi_nn_graph_registry_set_args(nn_registry, ctx->encoding, ctx->target,
|
||||
ctx->n_graphs, ctx->graph_paths);
|
||||
wasi_nn_graph_registry_set_args(nn_registry, ctx->model_names,
|
||||
ctx->encoding, ctx->target, ctx->n_graphs,
|
||||
ctx->graph_paths);
|
||||
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(args,
|
||||
nn_registry);
|
||||
|
||||
for (uint32_t i = 0; i < ctx->n_graphs; i++) {
|
||||
if (ctx->model_names[i])
|
||||
free(ctx->model_names[i]);
|
||||
if (ctx->graph_paths[i])
|
||||
free(ctx->graph_paths[i]);
|
||||
if (ctx->encoding[i])
|
||||
|
|
|
|||
|
|
@ -123,8 +123,8 @@ print_help(void)
|
|||
printf(" --gen-prof-file=<path> Generate LLVM PGO (Profile-Guided Optimization) profile file\n");
|
||||
#endif
|
||||
#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
printf(" --wasi-nn-graph=encodingA:targetB:<modelA_path>\n");
|
||||
printf(" --wasi-nn-graph=encodingA:targetB:<modelB_path>...\n");
|
||||
printf(" --wasi-nn-graph=modelA_name:encodingA:targetA:<modelA_path>\n");
|
||||
printf(" --wasi-nn-graph=modelB_name:encodingB:targetB:<modelB_path>...\n");
|
||||
printf(" Set encoding, target and model_paths for wasi-nn. target can be\n");
|
||||
printf(" cpu|gpu|tpu, encoding can be tensorflowlite|openvino|llama|onnx|\n");
|
||||
printf(" tensorflow|pytorch|ggml|autodetect\n");
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user