diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 819bd52af..9ac54e664 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -9,6 +9,7 @@ #include "wasi_nn_backend.h" #include "wasm_export.h" +#include #include #include #include @@ -279,29 +280,53 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index, tensor *input_tensor) { TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx; + TfLiteType tfl_type; - if (input_tensor->type != fp32) { - NN_ERR_PRINTF("unsupported input tensor type %u", input_tensor->type); - return runtime_error; + switch (input_tensor->type) { + case fp32: + tfl_type = TfLiteType::kTfLiteFloat32; + break; +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + case u8: + tfl_type = TfLiteType::kTfLiteUInt8; + break; +#endif + default: + NN_ERR_PRINTF("unsupported input tensor type %u", + input_tensor->type); + return runtime_error; } wasi_nn_error res; if (success != (res = is_valid_graph_execution_context(tfl_ctx, ctx))) return res; - uint32_t num_tensors = - tfl_ctx->interpreters[ctx].interpreter->inputs().size(); + auto interpreter = tfl_ctx->interpreters[ctx].interpreter.get(); + + uint32_t num_tensors = interpreter->inputs().size(); NN_DBG_PRINTF("Number of tensors (%d)", num_tensors); if (index + 1 > num_tensors) { return runtime_error; } - auto tensor = tfl_ctx->interpreters[ctx].interpreter->input_tensor(index); + auto tensor = interpreter->input_tensor(index); if (tensor == NULL) { NN_ERR_PRINTF("Missing memory"); return too_large; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (TfLiteTensorType(tensor) != tfl_type) { + NN_ERR_PRINTF("Type mismatch"); + return runtime_error; + } + + if (TfLiteTensorCopyFromBuffer(tensor, input_tensor->data.buf, + input_tensor->data.size) + != kTfLiteOk) { + return runtime_error; + } +#else uint32_t model_tensor_size = 1; for (int i = 0; i < tensor->dims->size; ++i) model_tensor_size *= (uint32_t)tensor->dims->data[i]; @@ -346,6 +371,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index, it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point); } } +#endif return success; } @@ -388,14 +414,19 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, return too_large; } +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + size_t sz = TfLiteTensorByteSize(tensor); + if (output_tensor->size < sz) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } + if (TfLiteTensorCopyToBuffer(tensor, output_tensor->buf, sz) != kTfLiteOk) { + return runtime_error; + } + *output_tensor_size = sz; +#else if (tensor->quantization.type == kTfLiteNoQuantization) { NN_DBG_PRINTF("No quantization information"); -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (output_tensor->size < tensor->bytes) { - NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return too_large; - } -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. @@ -404,18 +435,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); return too_large; } -#endif bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data, tensor->bytes); -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - *output_tensor_size = tensor->bytes; -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. */ *output_tensor_size = tensor->bytes / sizeof(float); -#endif } else { // TODO: Assuming uint8 quantized networks. TfLiteAffineQuantization *quant_info = @@ -429,12 +455,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, for (int i = 0; i < (int)tensor->dims->size; ++i) model_tensor_size *= (uint32_t)tensor->dims->data[i]; -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - if (output_tensor->size / sizeof(float) < model_tensor_size) { - NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return too_large; - } -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. @@ -443,7 +463,6 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); return too_large; } -#endif uint8_t *ot = tfl_ctx->interpreters[ctx] .interpreter->typed_output_tensor(index); @@ -458,16 +477,13 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, output_tensor_f[i] = (ot[i] - zero_point) * scale; } -#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 - *output_tensor_size = model_tensor_size * sizeof(float); -#else /* * for now, maintain the bug-to-bug compatibility with the old abi, * where the size here is the number of fp32, not bytes. */ *output_tensor_size = model_tensor_size; -#endif } +#endif return success; }