wasi_nn_tensorflowlite.cpp: make this compatible with wasmedge (#4517)

for wasi_ephemeral_nn,

* implement u8 input

* stop dealing with quantization.

  * wasi-nn doesn't have a concept of quantization or pre/post-processing.
    i can't think of any way to make the backend perform zero-point/scale
    processing without risking breaking other applications.

  * there seem to be applications that just use u8 inputs/outputs for
    a quantized model. (see [1] for an example.) for certain kinds of
    inputs/outputs, it usually just works; the sketch below illustrates why.

this commit keeps the legacy wasi_nn logic intact for now.
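
to illustrate the "usually just works" point: tflite's affine dequantization
is f = (q - zero_point) * scale with scale > 0, so it preserves the ordering
of output scores, and top-k over raw u8 scores matches top-k over the
dequantized floats. a minimal standalone sketch (not from this commit; the
quantization parameters are made up):

```cpp
// demonstrates that argmax is invariant under affine dequantization
// with a positive scale, which is why raw u8 classification outputs
// are often usable as-is
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    const float scale = 0.00390625f; // hypothetical, 1/256
    const int32_t zero_point = 3;    // hypothetical
    const std::vector<uint8_t> q = { 12, 198, 3, 57, 101 }; // raw u8 scores

    std::vector<float> f(q.size());
    for (size_t i = 0; i < q.size(); i++)
        f[i] = (q[i] - zero_point) * scale; // tflite affine dequantization

    size_t am_q = std::max_element(q.begin(), q.end()) - q.begin();
    size_t am_f = std::max_element(f.begin(), f.end()) - f.begin();
    printf("argmax(u8)=%zu argmax(f32)=%zu\n", am_q, am_f); // both 1
    return 0;
}
```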

tested with [1] with [2] applied.

WAMR with this patch:
```
Read graph weights, size in bytes: 3561598
[wasi_nn.c:297 WARNING] load_by_name_with_config() not found
[wasi_nn_tensorflowlite.cpp:272 WARNING] Default encoding is CPU.
Loaded graph into wasi-nn with ID: Graph#0
Read input tensor, size in bytes: 150528
   1.) [166](198)Aix galericulata
   2.) [34](1)Gallus gallus domesticus
   3.) [158](1)Coccothraustes coccothraustes
   4.) [778](1)Sitta europaea
   5.) [819](1)Anas platyrhynchos
```

wasmedge:
```
Read graph weights, size in bytes: 3561598
Loaded graph into wasi-nn with ID: Graph#0
Read input tensor, size in bytes: 150528
   1.) [166](198)Aix galericulata
   2.) [34](1)Gallus gallus domesticus
   3.) [158](1)Coccothraustes coccothraustes
   4.) [778](1)Sitta europaea
   5.) [819](1)Anas platyrhynchos
```

and "Aix galericulata" seems like a reasonable classification
of the image to my eyes.

[1] 67f174bab5/tflite-birds_v1-image

[2] https://github.com/second-state/WasmEdge-WASINN-examples/pull/204

Related:
https://github.com/bytecodealliance/wasm-micro-runtime/issues/3555
https://github.com/bytecodealliance/wasm-micro-runtime/issues/2611
Author: YAMAMOTO Takashi
Date:   2025-08-01 15:31:02 +09:00 (committed by GitHub)
Parent: 272a41dc80
Commit: 29d465b44e

wasi_nn_tensorflowlite.cpp:
```diff
@@ -9,6 +9,7 @@
 #include "wasi_nn_backend.h"
 #include "wasm_export.h"
 
+#include <tensorflow/lite/c/c_api.h>
 #include <tensorflow/lite/interpreter.h>
 #include <tensorflow/lite/kernels/register.h>
 #include <tensorflow/lite/model.h>
```
```diff
@@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor *input_tensor)
 {
     TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+    TfLiteType tfl_type;
 
-    if (input_tensor->type != fp32) {
-        NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type);
-        return runtime_error;
+    switch (input_tensor->type) {
+        case fp32:
+            tfl_type = TfLiteType::kTfLiteFloat32;
+            break;
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        case u8:
+            tfl_type = TfLiteType::kTfLiteUInt8;
+            break;
+#endif
+        default:
+            NN_ERR_PRINTF("unsupported input tensor type %u",
+                          input_tensor->type);
+            return runtime_error;
     }
 
     wasi_nn_error res;
     if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
         return res;
 
-    uint32_t num_tensors =
-        tfl_ctx->interpreters[ctx].interpreter->inputs().size();
+    auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get();
+    uint32_t num_tensors = interpreter->inputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
     if (index + 1 > num_tensors) {
         return runtime_error;
     }
 
-    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
+    auto tensor = interpreter->input_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return too_large;
     }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    if (TfLiteTensorType(tensor) != tfl_type) {
+        NN_ERR_PRINTF("Type mismatch");
+        return runtime_error;
+    }
+    if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf,
+                                   input_tensor->data.size)
+        != kTfLiteOk) {
+        return runtime_error;
+    }
+#else
     uint32_t model_tensor_size = 1;
     for (int i = 0; i < tensor->dims->size; ++i)
         model_tensor_size *= (uint32_t)tensor->dims->data[i];
@@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
         }
     }
+#endif
 
     return success;
 }
```
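
for reference, the new wasi_ephemeral_nn input path above validates only the
element type and then lets TfLiteTensorCopyFromBuffer enforce the size, since
that call fails unless the buffer length equals TfLiteTensorByteSize(tensor).
a minimal sketch of the same pattern against the pure tflite c api (the commit
itself mixes the c++ interpreter with the c tensor helpers; error logging
trimmed):

```cpp
#include <cstddef>
#include <cstdint>
#include <tensorflow/lite/c/c_api.h>

// copy a caller-provided u8 buffer into input tensor `index`,
// mirroring the checks the new set_input path performs
static bool
set_u8_input(TfLiteInterpreter *interpreter, int32_t index,
             const uint8_t *buf, size_t len)
{
    TfLiteTensor *tensor = TfLiteInterpreterGetInputTensor(interpreter, index);
    if (tensor == NULL)
        return false;
    // reject non-u8 tensors instead of silently reinterpreting bytes
    if (TfLiteTensorType(tensor) != kTfLiteUInt8)
        return false;
    // TfLiteTensorCopyFromBuffer fails unless len matches the tensor's
    // byte size, so no manual dims-based size math is needed
    return TfLiteTensorCopyFromBuffer(tensor, buf, len) == kTfLiteOk;
}
```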
```diff
@@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         return too_large;
     }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    size_t sz = TfLiteTensorByteSize(tensor);
+    if (output_tensor->size < sz) {
+        NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+        return too_large;
+    }
+    if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) {
+        return runtime_error;
+    }
+    *output_tensor_size = sz;
+#else
     if (tensor->quantization.type == kTfLiteNoQuantization) {
         NN_DBG_PRINTF("No quantization information");
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        if (output_tensor->size < tensor->bytes) {
-            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-            return too_large;
-        }
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
@@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
             return too_large;
         }
-#endif
         bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
                     tensor->bytes);
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        *output_tensor_size = tensor->bytes;
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
          */
         *output_tensor_size = tensor->bytes / sizeof(float);
-#endif
     }
     else { // TODO: Assuming uint8 quantized networks.
         TfLiteAffineQuantization *quant_info =
@@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         uint32_t model_tensor_size = 1;
         for (int i = 0; i < (int)tensor->dims->size; ++i)
             model_tensor_size *= (uint32_t)tensor->dims->data[i];
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        if (output_tensor->size / sizeof(float) < model_tensor_size) {
-            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-            return too_large;
-        }
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
@@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
             return too_large;
         }
-#endif
 
         uint8_t *ot = tfl_ctx->interpreters[ctx]
                           .interpreter->typed_output_tensor<uint8_t>(index);
@@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             output_tensor_f[i] = (ot[i] - zero_point) * scale;
         }
 
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        *output_tensor_size = model_tensor_size * sizeof(float);
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
          */
         *output_tensor_size = model_tensor_size;
-#endif
     }
+#endif
 
     return success;
 }
```
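
the new output path above reports sizes in bytes (TfLiteTensorByteSize) rather
than the legacy abi's fp32 element count. a minimal sketch of the equivalent
logic against the pure tflite c api (assumes the interpreter has already been
invoked; error logging trimmed):

```cpp
#include <cstddef>
#include <tensorflow/lite/c/c_api.h>

// copy output tensor `index` into a caller buffer and report its size in
// bytes, matching the wasi_ephemeral_nn semantics (the legacy wasi_nn abi
// reported the number of fp32 elements instead)
static bool
get_raw_output(const TfLiteInterpreter *interpreter, int32_t index, void *buf,
               size_t buf_len, size_t *out_len)
{
    const TfLiteTensor *tensor =
        TfLiteInterpreterGetOutputTensor(interpreter, index);
    if (tensor == NULL)
        return false;
    size_t sz = TfLiteTensorByteSize(tensor);
    if (buf_len < sz) // caller buffer too small
        return false;
    if (TfLiteTensorCopyToBuffer(tensor, buf, sz) != kTfLiteOk)
        return false;
    *out_len = sz; // byte count, whatever the element type (u8, fp32, ...)
    return true;
}
```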