wasi_nn_tensorflowlite.cpp: make this compatible with wasmedge (#4517)
for wasi_ephemeral_nn,

* implement u8 input
* stop dealing with quantization (see the sketch below)
  * wasi-nn doesn't have a concept of quantization or pre/post-processing.
    i can't think of any way to make the backend perform zero-point/scale
    processing without risking breaking other applications.
  * there seem to be applications that just use u8 inputs/outputs for
    a quantized model. (see [1] for an example.) for certain kinds of
    inputs/outputs, it usually just works.
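
for reference, the zero-point/scale processing in question is the standard
affine u8 mapping. a minimal sketch below, not the backend code; the
scale/zero_point values are made up for illustration:

```
// a minimal sketch of affine u8 quantization; illustrative values only
#include <cstdint>
#include <cstdio>

static uint8_t
quantize_u8(float x, float scale, int32_t zero_point)
{
    // the legacy set_input path does effectively this per element:
    //   it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
    return (uint8_t)(x / scale + zero_point);
}

static float
dequantize_u8(uint8_t q, float scale, int32_t zero_point)
{
    // the legacy get_output path does the inverse:
    //   output_tensor_f[i] = (ot[i] - zero_point) * scale;
    return (q - zero_point) * scale;
}

int
main()
{
    const float scale = 0.0078125f; /* made-up example parameters */
    const int32_t zero_point = 128;
    uint8_t q = quantize_u8(0.5f, scale, zero_point);
    printf("0.5 -> %u -> %g\n", q, dequantize_u8(q, scale, zero_point));
    return 0;
}
```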
this commit keeps the legacy wasi_nn logic intact for now.
tested with [1] with [2] applied.
WAMR with this patch:
```
Read graph weights, size in bytes: 3561598
[wasi_nn.c:297 WARNING] load_by_name_with_config() not found
[wasi_nn_tensorflowlite.cpp:272 WARNING] Default encoding is CPU.
Loaded graph into wasi-nn with ID: Graph#0
Read input tensor, size in bytes: 150528
1.) [166](198)Aix galericulata
2.) [34](1)Gallus gallus domesticus
3.) [158](1)Coccothraustes coccothraustes
4.) [778](1)Sitta europaea
5.) [819](1)Anas platyrhynchos
```
wasmedge:
```
Read graph weights, size in bytes: 3561598
Loaded graph into wasi-nn with ID: Graph#0
Read input tensor, size in bytes: 150528
1.) [166](198)Aix galericulata
2.) [34](1)Gallus gallus domesticus
3.) [158](1)Coccothraustes coccothraustes
4.) [778](1)Sitta europaea
5.) [819](1)Anas platyrhynchos
```
and "Aix galericulata" seems like a reasonable classification
of the image to my eyes.
[1] 67f174bab5/tflite-birds_v1-image
[2] https://github.com/second-state/WasmEdge-WASINN-examples/pull/204
Related:
https://github.com/bytecodealliance/wasm-micro-runtime/issues/3555
https://github.com/bytecodealliance/wasm-micro-runtime/issues/2611
parent: 272a41dc80
commit: 29d465b44e
```
@@ -9,6 +9,7 @@
 #include "wasi_nn_backend.h"
 #include "wasm_export.h"
 
+#include <tensorflow/lite/c/c_api.h>
 #include <tensorflow/lite/interpreter.h>
 #include <tensorflow/lite/kernels/register.h>
 #include <tensorflow/lite/model.h>
@@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
           tensor *input_tensor)
 {
     TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
+    TfLiteType tfl_type;
 
-    if (input_tensor->type != fp32) {
-        NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type);
-        return runtime_error;
+    switch (input_tensor->type) {
+        case fp32:
+            tfl_type = TfLiteType::kTfLiteFloat32;
+            break;
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+        case u8:
+            tfl_type = TfLiteType::kTfLiteUInt8;
+            break;
+#endif
+        default:
+            NN_ERR_PRINTF("unsupported input tensor type %u",
+                          input_tensor->type);
+            return runtime_error;
     }
 
     wasi_nn_error res;
     if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx)))
         return res;
 
-    uint32_t num_tensors =
-        tfl_ctx->interpreters[ctx].interpreter->inputs().size();
+    auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get();
+
+    uint32_t num_tensors = interpreter->inputs().size();
     NN_DBG_PRINTF("Number of tensors (%d)", num_tensors);
     if (index + 1 > num_tensors) {
         return runtime_error;
     }
 
-    auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index);
+    auto tensor = interpreter->input_tensor(index);
     if (tensor == NULL) {
         NN_ERR_PRINTF("Missing memory");
         return too_large;
     }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    if (TfLiteTensorType(tensor) != tfl_type) {
+        NN_ERR_PRINTF("Type mismatch");
+        return runtime_error;
+    }
+
+    if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf,
+                                   input_tensor->data.size)
+        != kTfLiteOk) {
+        return runtime_error;
+    }
+#else
     uint32_t model_tensor_size = 1;
     for (int i = 0; i < tensor->dims->size; ++i)
         model_tensor_size *= (uint32_t)tensor->dims->data[i];
@@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
         }
     }
+#endif
 
     return success;
 }
@@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         return too_large;
     }
 
+#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
+    size_t sz = TfLiteTensorByteSize(tensor);
+    if (output_tensor->size < sz) {
+        NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
+        return too_large;
+    }
+    if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) {
+        return runtime_error;
+    }
+    *output_tensor_size = sz;
+#else
     if (tensor->quantization.type == kTfLiteNoQuantization) {
         NN_DBG_PRINTF("No quantization information");
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        if (output_tensor->size < tensor->bytes) {
-            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-            return too_large;
-        }
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
@@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
             return too_large;
         }
-#endif
         bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
                     tensor->bytes);
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        *output_tensor_size = tensor->bytes;
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
         */
        *output_tensor_size = tensor->bytes / sizeof(float);
-#endif
     }
     else { // TODO: Assuming uint8 quantized networks.
         TfLiteAffineQuantization *quant_info =
@@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
         for (int i = 0; i < (int)tensor->dims->size; ++i)
             model_tensor_size *= (uint32_t)tensor->dims->data[i];
 
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        if (output_tensor->size / sizeof(float) < model_tensor_size) {
-            NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
-            return too_large;
-        }
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
@@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
             return too_large;
         }
-#endif
 
         uint8_t *ot = tfl_ctx->interpreters[ctx]
                           .interpreter->typed_output_tensor<uint8_t>(index);
@@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
             output_tensor_f[i] = (ot[i] - zero_point) * scale;
         }
 
-#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-        *output_tensor_size = model_tensor_size * sizeof(float);
-#else
         /*
          * for now, maintain the bug-to-bug compatibility with the old abi,
          * where the size here is the number of fp32, not bytes.
          */
         *output_tensor_size = model_tensor_size;
-#endif
     }
+#endif
 
     return success;
 }
```
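
the wasi_ephemeral_nn branches above lean on the tflite c api for
type-checked, byte-exact copies instead of recomputing sizes from tensor
dims. a minimal sketch of that pattern follows, with hypothetical helper
names and simplified bool error handling (TfLiteTensorType,
TfLiteTensorByteSize, TfLiteTensorCopyFromBuffer, and
TfLiteTensorCopyToBuffer are the real c api calls used by the patch;
everything else is illustrative):

```
// hypothetical helpers, not code from this patch; they mirror the
// checks the new wasi_ephemeral_nn code paths perform.
#include <tensorflow/lite/c/c_api.h>

#include <cstddef>

static bool
copy_input_exact(TfLiteTensor *tensor, TfLiteType expected, const void *buf,
                 size_t len)
{
    // reject e.g. fp32 data fed to a u8 tensor instead of silently
    // reinterpreting the bytes
    if (TfLiteTensorType(tensor) != expected)
        return false;
    // TfLiteTensorCopyFromBuffer fails unless len matches the tensor's
    // byte size exactly, so no dims-based size arithmetic is needed
    return TfLiteTensorCopyFromBuffer(tensor, buf, len) == kTfLiteOk;
}

static bool
copy_output_exact(const TfLiteTensor *tensor, void *buf, size_t cap,
                  size_t *written)
{
    size_t sz = TfLiteTensorByteSize(tensor);
    if (cap < sz)
        return false; // caller's buffer is too small
    if (TfLiteTensorCopyToBuffer(tensor, buf, sz) != kTfLiteOk)
        return false;
    *written = sz;
    return true;
}
```

this is also why the new path reports *output_tensor_size in bytes, while
the legacy path keeps reporting the number of fp32 elements for
bug-to-bug compatibility.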