mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-06-07 13:49:18 +00:00
add load_by_name in wasi-nn (#4298)
This commit is contained in:
parent
2a303861cc
commit
aa1ff778b9
|
@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
|
||||||
__attribute__((import_module("wasi_nn")));
|
__attribute__((import_module("wasi_nn")));
|
||||||
|
|
||||||
wasi_nn_error
|
wasi_nn_error
|
||||||
load_by_name(const char *name, graph *g)
|
load_by_name(const char *name, uint32_t name_len, graph *g)
|
||||||
__attribute__((import_module("wasi_nn")));
|
__attribute__((import_module("wasi_nn")));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
|
||||||
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
|
REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
|
||||||
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
|
||||||
REG_NATIVE_FUNC(load, "(*ii*)i"),
|
REG_NATIVE_FUNC(load, "(*ii*)i"),
|
||||||
|
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
|
||||||
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
|
REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
|
||||||
REG_NATIVE_FUNC(set_input, "(ii*)i"),
|
REG_NATIVE_FUNC(set_input, "(ii*)i"),
|
||||||
REG_NATIVE_FUNC(compute, "(i)i"),
|
REG_NATIVE_FUNC(compute, "(i)i"),
|
||||||
|
|
|
@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
|
||||||
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
|
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
|
||||||
return runtime_error;
|
return runtime_error;
|
||||||
}
|
}
|
||||||
if (tfl_ctx->models[g].model_pointer == NULL) {
|
|
||||||
NN_ERR_PRINTF("Context (model) non-initialized.");
|
|
||||||
return runtime_error;
|
|
||||||
}
|
|
||||||
if (tfl_ctx->models[g].model == NULL) {
|
if (tfl_ctx->models[g].model == NULL) {
|
||||||
NN_ERR_PRINTF("Context (tflite model) non-initialized.");
|
NN_ERR_PRINTF("Context (model) non-initialized.");
|
||||||
return runtime_error;
|
return runtime_error;
|
||||||
}
|
}
|
||||||
return success;
|
return success;
|
||||||
|
@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx)
|
||||||
NN_DBG_PRINTF("Freeing memory.");
|
NN_DBG_PRINTF("Freeing memory.");
|
||||||
for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
|
for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
|
||||||
tfl_ctx->models[i].model.reset();
|
tfl_ctx->models[i].model.reset();
|
||||||
if (tfl_ctx->models[i].model_pointer) {
|
if (tfl_ctx->delegate) {
|
||||||
if (tfl_ctx->delegate) {
|
switch (tfl_ctx->models[i].target) {
|
||||||
switch (tfl_ctx->models[i].target) {
|
case gpu:
|
||||||
case gpu:
|
{
|
||||||
{
|
|
||||||
#if WASM_ENABLE_WASI_NN_GPU != 0
|
#if WASM_ENABLE_WASI_NN_GPU != 0
|
||||||
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
|
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
|
||||||
#else
|
#else
|
||||||
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
|
NN_ERR_PRINTF("GPU delegate delete but not enabled.");
|
||||||
#endif
|
#endif
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
case tpu:
|
|
||||||
{
|
|
||||||
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
|
|
||||||
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
|
|
||||||
#else
|
|
||||||
NN_ERR_PRINTF(
|
|
||||||
"External delegate delete but not enabled.");
|
|
||||||
#endif
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
case tpu:
|
||||||
|
{
|
||||||
|
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
|
||||||
|
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
|
||||||
|
#else
|
||||||
|
NN_ERR_PRINTF("External delegate delete but not enabled.");
|
||||||
|
#endif
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
if (tfl_ctx->models[i].model_pointer) {
|
||||||
wasm_runtime_free(tfl_ctx->models[i].model_pointer);
|
wasm_runtime_free(tfl_ctx->models[i].model_pointer);
|
||||||
}
|
}
|
||||||
tfl_ctx->models[i].model_pointer = NULL;
|
tfl_ctx->models[i].model_pointer = NULL;
|
||||||
|
|
|
@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
|
||||||
wasi_nn_error
|
wasi_nn_error
|
||||||
wasm_load_by_name(const char *model_name, graph *g)
|
wasm_load_by_name(const char *model_name, graph *g)
|
||||||
{
|
{
|
||||||
wasi_nn_error res = load_by_name(model_name, g);
|
wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
|
||||||
uint32_t num_output_tensors)
|
uint32_t num_output_tensors)
|
||||||
{
|
{
|
||||||
graph graph;
|
graph graph;
|
||||||
if (wasm_load(model_name, &graph, target) != success) {
|
|
||||||
|
if (wasm_load_by_name(model_name, &graph) != success) {
|
||||||
NN_ERR_PRINTF("Error when loading model.");
|
NN_ERR_PRINTF("Error when loading model.");
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user