add load_by_name in wasi-nn (#4298)

This commit is contained in:
hongxia 2025-06-03 06:26:58 +08:00 committed by GitHub
parent 2a303861cc
commit aa1ff778b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 29 deletions

View File

@ -30,7 +30,7 @@ load(graph_builder_array *builder, graph_encoding encoding,
__attribute__((import_module("wasi_nn"))); __attribute__((import_module("wasi_nn")));
wasi_nn_error wasi_nn_error
load_by_name(const char *name, graph *g) load_by_name(const char *name, uint32_t name_len, graph *g)
__attribute__((import_module("wasi_nn"))); __attribute__((import_module("wasi_nn")));
/** /**

View File

@ -697,6 +697,7 @@ static NativeSymbol native_symbols_wasi_nn[] = {
REG_NATIVE_FUNC(get_output, "(ii*i*)i"), REG_NATIVE_FUNC(get_output, "(ii*i*)i"),
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */ #else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
REG_NATIVE_FUNC(load, "(*ii*)i"), REG_NATIVE_FUNC(load, "(*ii*)i"),
REG_NATIVE_FUNC(load_by_name, "(*i*)i"),
REG_NATIVE_FUNC(init_execution_context, "(i*)i"), REG_NATIVE_FUNC(init_execution_context, "(i*)i"),
REG_NATIVE_FUNC(set_input, "(ii*)i"), REG_NATIVE_FUNC(set_input, "(ii*)i"),
REG_NATIVE_FUNC(compute, "(i)i"), REG_NATIVE_FUNC(compute, "(i)i"),

View File

@ -85,12 +85,8 @@ is_valid_graph(TFLiteContext *tfl_ctx, graph g)
NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST); NN_ERR_PRINTF("Invalid graph: %d >= %d.", g, MAX_GRAPHS_PER_INST);
return runtime_error; return runtime_error;
} }
if (tfl_ctx->models[g].model_pointer == NULL) {
NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error;
}
if (tfl_ctx->models[g].model == NULL) { if (tfl_ctx->models[g].model == NULL) {
NN_ERR_PRINTF("Context (tflite model) non-initialized."); NN_ERR_PRINTF("Context (model) non-initialized.");
return runtime_error; return runtime_error;
} }
return success; return success;
@ -472,32 +468,31 @@ deinit_backend(void *tflite_ctx)
NN_DBG_PRINTF("Freeing memory."); NN_DBG_PRINTF("Freeing memory.");
for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) { for (int i = 0; i < MAX_GRAPHS_PER_INST; ++i) {
tfl_ctx->models[i].model.reset(); tfl_ctx->models[i].model.reset();
if (tfl_ctx->models[i].model_pointer) { if (tfl_ctx->delegate) {
if (tfl_ctx->delegate) { switch (tfl_ctx->models[i].target) {
switch (tfl_ctx->models[i].target) { case gpu:
case gpu: {
{
#if WASM_ENABLE_WASI_NN_GPU != 0 #if WASM_ENABLE_WASI_NN_GPU != 0
TfLiteGpuDelegateV2Delete(tfl_ctx->delegate); TfLiteGpuDelegateV2Delete(tfl_ctx->delegate);
#else #else
NN_ERR_PRINTF("GPU delegate delete but not enabled."); NN_ERR_PRINTF("GPU delegate delete but not enabled.");
#endif #endif
break; break;
}
case tpu:
{
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
#else
NN_ERR_PRINTF(
"External delegate delete but not enabled.");
#endif
break;
}
default:
break;
} }
case tpu:
{
#if WASM_ENABLE_WASI_NN_EXTERNAL_DELEGATE != 0
TfLiteExternalDelegateDelete(tfl_ctx->delegate);
#else
NN_ERR_PRINTF("External delegate delete but not enabled.");
#endif
break;
}
default:
break;
} }
}
if (tfl_ctx->models[i].model_pointer) {
wasm_runtime_free(tfl_ctx->models[i].model_pointer); wasm_runtime_free(tfl_ctx->models[i].model_pointer);
} }
tfl_ctx->models[i].model_pointer = NULL; tfl_ctx->models[i].model_pointer = NULL;

View File

@ -58,7 +58,7 @@ wasm_load(char *model_name, graph *g, execution_target target)
wasi_nn_error wasi_nn_error
wasm_load_by_name(const char *model_name, graph *g) wasm_load_by_name(const char *model_name, graph *g)
{ {
wasi_nn_error res = load_by_name(model_name, g); wasi_nn_error res = load_by_name(model_name, strlen(model_name), g);
return res; return res;
} }
@ -108,7 +108,8 @@ run_inference(execution_target target, float *input, uint32_t *input_size,
uint32_t num_output_tensors) uint32_t num_output_tensors)
{ {
graph graph; graph graph;
if (wasm_load(model_name, &graph, target) != success) {
if (wasm_load_by_name(model_name, &graph) != success) {
NN_ERR_PRINTF("Error when loading model."); NN_ERR_PRINTF("Error when loading model.");
exit(1); exit(1);
} }