diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp index 9c458179a..c9064a5ec 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp @@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, return too_large; } - uint32_t model_tensor_size = 1; - for (int i = 0; i < (int)tensor->dims->size; ++i) - model_tensor_size *= (uint32_t)tensor->dims->data[i]; - - if (*output_tensor_size < model_tensor_size) { - NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); - return too_large; - } - if (tensor->quantization.type == kTfLiteNoQuantization) { NN_DBG_PRINTF("No quantization information"); - float *ot = - tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor( - index); - - int size = model_tensor_size * sizeof(float); - bh_memcpy_s(output_tensor, size, ot, size); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (*output_tensor_size < tensor->bytes) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } +#else + /* + * for now, maintain the bug-to-bug compatibility with the old abi, + * where the size here is the number of fp32, not bytes. + */ + if (*output_tensor_size < tensor->bytes / sizeof(float)) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } +#endif + bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data, + tensor->bytes); +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + *output_tensor_size = tensor->bytes; +#else + /* + * for now, maintain the bug-to-bug compatibility with the old abi, + * where the size here is the number of fp32, not bytes. + */ + *output_tensor_size = tensor->bytes / sizeof(float); +#endif } else { // TODO: Assuming uint8 quantized networks. TfLiteAffineQuantization *quant_info = @@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, NN_ERR_PRINTF("Quantization per channel is not supported"); return runtime_error; } + + uint32_t model_tensor_size = 1; + for (int i = 0; i < (int)tensor->dims->size; ++i) + model_tensor_size *= (uint32_t)tensor->dims->data[i]; + +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + if (*output_tensor_size / sizeof(float) < model_tensor_size) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } +#else + /* + * for now, maintain the bug-to-bug compatibility with the old abi, + * where the size here is the number of fp32, not bytes. + */ + if (*output_tensor_size < model_tensor_size) { + NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index); + return too_large; + } +#endif + uint8_t *ot = tfl_ctx->interpreters[ctx] .interpreter->typed_output_tensor(index); @@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index, for (uint32_t i = 0; i < model_tensor_size; ++i) { output_tensor_f[i] = (ot[i] - zero_point) * scale; } + +#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 + *output_tensor_size = model_tensor_size * sizeof(float); +#else + /* + * for now, maintain the bug-to-bug compatibility with the old abi, + * where the size here is the number of fp32, not bytes. + */ + *output_tensor_size = model_tensor_size; +#endif } - *output_tensor_size = model_tensor_size; return success; }