mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-09-06 01:41:35 +00:00
wasi_nn_tensorflowlite.cpp: fix get_output return size (#4390)
it should be byte size, not the number of (fp32) values. i'm ambivalent about how to deal with the compatibility for the legacy wamr-specific "wasi_nn". for now, i avoided changing it. (so that existing tests using the legacy abi, namely test_tensorflow.c and test_tensorflow_quantized.c, passes as they are.) if we have any users who still want to use the legacy abi, i suppose they consider the compatibility is more important than the consistency with other backends. cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4376
This commit is contained in:
parent
70c39bae77
commit
8289452abb
|
@ -389,23 +389,34 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
|
|||
return too_large;
|
||||
}
|
||||
|
||||
uint32_t model_tensor_size = 1;
|
||||
for (int i = 0; i < (int)tensor->dims->size; ++i)
|
||||
model_tensor_size *= (uint32_t)tensor->dims->data[i];
|
||||
|
||||
if (*output_tensor_size < model_tensor_size) {
|
||||
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
|
||||
return too_large;
|
||||
}
|
||||
|
||||
if (tensor->quantization.type == kTfLiteNoQuantization) {
|
||||
NN_DBG_PRINTF("No quantization information");
|
||||
float *ot =
|
||||
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
|
||||
index);
|
||||
|
||||
int size = model_tensor_size * sizeof(float);
|
||||
bh_memcpy_s(output_tensor, size, ot, size);
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
if (*output_tensor_size < tensor->bytes) {
|
||||
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
|
||||
return too_large;
|
||||
}
|
||||
#else
|
||||
/*
|
||||
* for now, maintain the bug-to-bug compatibility with the old abi,
|
||||
* where the size here is the number of fp32, not bytes.
|
||||
*/
|
||||
if (*output_tensor_size < tensor->bytes / sizeof(float)) {
|
||||
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
|
||||
return too_large;
|
||||
}
|
||||
#endif
|
||||
bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
|
||||
tensor->bytes);
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
*output_tensor_size = tensor->bytes;
|
||||
#else
|
||||
/*
|
||||
* for now, maintain the bug-to-bug compatibility with the old abi,
|
||||
* where the size here is the number of fp32, not bytes.
|
||||
*/
|
||||
*output_tensor_size = tensor->bytes / sizeof(float);
|
||||
#endif
|
||||
}
|
||||
else { // TODO: Assuming uint8 quantized networks.
|
||||
TfLiteAffineQuantization *quant_info =
|
||||
|
@ -414,6 +425,27 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
|
|||
NN_ERR_PRINTF("Quantization per channel is not supported");
|
||||
return runtime_error;
|
||||
}
|
||||
|
||||
uint32_t model_tensor_size = 1;
|
||||
for (int i = 0; i < (int)tensor->dims->size; ++i)
|
||||
model_tensor_size *= (uint32_t)tensor->dims->data[i];
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
if (*output_tensor_size / sizeof(float) < model_tensor_size) {
|
||||
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
|
||||
return too_large;
|
||||
}
|
||||
#else
|
||||
/*
|
||||
* for now, maintain the bug-to-bug compatibility with the old abi,
|
||||
* where the size here is the number of fp32, not bytes.
|
||||
*/
|
||||
if (*output_tensor_size < model_tensor_size) {
|
||||
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
|
||||
return too_large;
|
||||
}
|
||||
#endif
|
||||
|
||||
uint8_t *ot = tfl_ctx->interpreters[ctx]
|
||||
.interpreter->typed_output_tensor<uint8_t>(index);
|
||||
|
||||
|
@ -426,9 +458,18 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
|
|||
for (uint32_t i = 0; i < model_tensor_size; ++i) {
|
||||
output_tensor_f[i] = (ot[i] - zero_point) * scale;
|
||||
}
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
*output_tensor_size = model_tensor_size * sizeof(float);
|
||||
#else
|
||||
/*
|
||||
* for now, maintain the bug-to-bug compatibility with the old abi,
|
||||
* where the size here is the number of fp32, not bytes.
|
||||
*/
|
||||
*output_tensor_size = model_tensor_size;
|
||||
#endif
|
||||
}
|
||||
|
||||
*output_tensor_size = model_tensor_size;
|
||||
return success;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user