wasi-nn: make the host use the wasi_ephemeral_nn version of tensor_data (#4411)

the motivations:

* make the actual input size available to the backends.
  (currently the backends have to make a guess from shape/type.)

* make the host logic look a bit similar to wasi_ephemeral_nn.

this is a backend api/abi change.
This commit is contained in:
YAMAMOTO Takashi 2025-06-27 08:41:42 +09:00 committed by GitHub
parent 23799a2cb6
commit 2372a472aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 43 additions and 37 deletions

View File

@ -99,7 +99,7 @@ typedef enum {
// 4-byte f32 elements would have a data array of length 16). Naturally, this
// representation requires some knowledge of how to lay out data in
// memory--e.g., using row-major ordering--and could perhaps be improved.
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
#if !defined(__wasm__) || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct {
uint8_t *buf;
uint32_t size;

View File

@ -99,7 +99,8 @@ graph_builder_array_app_native(wasm_module_inst_t instance,
static wasi_nn_error
tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
tensor_wasm *input_tensor_wasm, tensor_data *data)
tensor_wasm *input_tensor_wasm, void **data,
uint32_t *size)
{
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#define data_size input_tensor_wasm->data_size
@ -113,8 +114,9 @@ tensor_data_app_native(wasm_module_inst_t instance, uint32_t total_elements,
NN_ERR_PRINTF("input_tensor_wasm->data_offset is invalid");
return invalid_argument;
}
*data = (tensor_data)wasm_runtime_addr_app_to_native(
*data = wasm_runtime_addr_app_to_native(
instance, (uint64)input_tensor_wasm->data_offset);
*size = data_size;
return success;
#undef data_size
}
@ -188,16 +190,19 @@ tensor_app_native(wasm_module_inst_t instance, tensor_wasm *input_tensor_wasm,
NN_DBG_PRINTF("Tensor type: %d", input_tensor_wasm->type);
NN_DBG_PRINTF("Total number of elements: %d", total_elements);
tensor_data data = NULL;
void *data = NULL;
uint32_t datasize;
if (success
!= (res = tensor_data_app_native(instance, total_elements,
input_tensor_wasm, &data))) {
!= (res =
tensor_data_app_native(instance, total_elements,
input_tensor_wasm, &data, &datasize))) {
wasm_runtime_free(dimensions);
return res;
}
input_tensor->type = input_tensor_wasm->type;
input_tensor->dimensions = dimensions;
input_tensor->data = data;
input_tensor->data.buf = data;
input_tensor->data.size = datasize;
return success;
}

View File

@ -720,12 +720,12 @@ fail:
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t index, void *output_tensor,
uint32_t output_tensor_len, uint32_t *output_tensor_size)
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
wasi_nn_error
wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
uint32_t index, tensor_data output_tensor,
uint32_t index, void *output_tensor,
uint32_t *output_tensor_size)
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
{
@ -753,16 +753,17 @@ wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context ctx,
goto fail;
}
tensor_data tensor = {
.buf = output_tensor,
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
.size = output_tensor_len,
#else
.size = *output_tensor_size,
#endif
};
call_wasi_nn_func(wasi_nn_ctx->backend, get_output, res,
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
&output_tensor_len);
*output_tensor_size = output_tensor_len;
#else /* WASM_ENABLE_WASI_EPHEMERAL_NN == 0 */
call_wasi_nn_func(wasi_nn_ctx->backend, get_output, res,
wasi_nn_ctx->backend_ctx, ctx, index, output_tensor,
wasi_nn_ctx->backend_ctx, ctx, index, &tensor,
output_tensor_size);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
fail:
unlock_ctx(wasi_nn_ctx);
return res;

View File

@ -385,7 +385,7 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
{
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
// tensor->data is the prompt string. ends with \0
char *prompt_text = (char *)wasi_nn_tensor->data;
char *prompt_text = (char *)wasi_nn_tensor->data.buf;
#ifndef NDEBUG
NN_DBG_PRINTF("--------------------------------------------------");
@ -552,7 +552,7 @@ fail:
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
tensor_data *output_tensor, uint32_t *output_tensor_size)
{
struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
@ -571,7 +571,7 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
printf("%s\n", output_metadata);
}
memcpy(output_tensor, output_metadata, strlen(output_metadata));
memcpy(output_tensor->buf, output_metadata, strlen(output_metadata));
*output_tensor_size = strlen(output_metadata);
return success;
}
@ -591,7 +591,7 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
printf("%s", buf);
}
memcpy(output_tensor + end_pos, buf, strlen(buf));
memcpy(output_tensor->buf + end_pos, buf, strlen(buf));
end_pos += strlen(buf);
}

View File

@ -402,7 +402,7 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
shape_info);
CHECK_OV_STATUS(ov_tensor_create_from_host_ptr(input_type, input_shape,
wasi_nn_tensor->data,
wasi_nn_tensor->data.buf,
&input_tensor),
ret);
}
@ -441,7 +441,7 @@ fail:
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
tensor_data *output_tensor, uint32_t *output_tensor_size)
{
OpenVINOContext *ov_ctx = (OpenVINOContext *)ctx;
struct OpenVINOExecutionContext *exec;
@ -460,14 +460,14 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
CHECK_OV_STATUS(ov_tensor_get_byte_size(ov_tensor, &byte_size), ret);
if (byte_size > *output_tensor_size) {
if (byte_size > output_tensor->size) {
ret = too_large;
goto fail;
}
CHECK_OV_STATUS(ov_tensor_data(ov_tensor, &data), ret);
memcpy(output_tensor, data, byte_size);
memcpy(output_tensor->buf, data, byte_size);
*output_tensor_size = (uint32_t)byte_size;

View File

@ -24,7 +24,7 @@ compute(void *ctx, graph_execution_context exec_ctx);
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size);
tensor_data *output_tensor, uint32_t *output_tensor_size);
__attribute__((visibility("default"))) wasi_nn_error
init_backend(void **ctx);
@ -32,4 +32,4 @@ init_backend(void **ctx);
__attribute__((visibility("default"))) wasi_nn_error
deinit_backend(void *ctx);
#endif /* WASI_NN_OPENVINO_HPP */
#endif /* WASI_NN_OPENVINO_HPP */

View File

@ -32,7 +32,7 @@ typedef wasi_nn_error (*SET_INPUT)(void *, graph_execution_context, uint32_t,
tensor *);
typedef wasi_nn_error (*COMPUTE)(void *, graph_execution_context);
typedef wasi_nn_error (*GET_OUTPUT)(void *, graph_execution_context, uint32_t,
tensor_data, uint32_t *);
tensor_data *, uint32_t *);
/* wasi-nn general APIs */
typedef wasi_nn_error (*BACKEND_INITIALIZE)(void **);
typedef wasi_nn_error (*BACKEND_DEINITIALIZE)(void *);

View File

@ -324,7 +324,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
index);
int size = model_tensor_size * sizeof(float);
bh_memcpy_s(it, size, input_tensor->data, size);
bh_memcpy_s(it, size, input_tensor->data.buf, size);
}
else { // TODO: Assuming uint8 quantized networks.
TfLiteAffineQuantization *quant_info =
@ -342,7 +342,7 @@ set_input(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
zero_point);
float *input_tensor_f = (float *)input_tensor->data;
float *input_tensor_f = (float *)input_tensor->data.buf;
for (uint32_t i = 0; i < model_tensor_size; ++i) {
it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
}
@ -366,7 +366,7 @@ compute(void *tflite_ctx, graph_execution_context ctx)
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
tensor_data *output_tensor, uint32_t *output_tensor_size)
{
TFLiteContext *tfl_ctx = (TFLiteContext *)tflite_ctx;
@ -392,7 +392,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information");
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (*output_tensor_size < tensor->bytes) {
if (output_tensor->size < tensor->bytes) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
@ -401,12 +401,12 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
if (*output_tensor_size < tensor->bytes / sizeof(float)) {
if (output_tensor->size < tensor->bytes / sizeof(float)) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
#endif
bh_memcpy_s(output_tensor, *output_tensor_size, tensor->data.data,
bh_memcpy_s(output_tensor->buf, output_tensor->size, tensor->data.data,
tensor->bytes);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
*output_tensor_size = tensor->bytes;
@ -431,7 +431,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
model_tensor_size *= (uint32_t)tensor->dims->data[i];
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (*output_tensor_size / sizeof(float) < model_tensor_size) {
if (output_tensor->size / sizeof(float) < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
@ -440,7 +440,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
* for now, maintain the bug-to-bug compatibility with the old abi,
* where the size here is the number of fp32, not bytes.
*/
if (*output_tensor_size < model_tensor_size) {
if (output_tensor->size < model_tensor_size) {
NN_ERR_PRINTF("Insufficient memory to copy tensor %d", index);
return too_large;
}
@ -454,7 +454,7 @@ get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
zero_point);
float *output_tensor_f = (float *)output_tensor;
float *output_tensor_f = (float *)output_tensor->buf;
for (uint32_t i = 0; i < model_tensor_size; ++i) {
output_tensor_f[i] = (ot[i] - zero_point) * scale;
}

View File

@ -32,7 +32,7 @@ compute(void *tflite_ctx, graph_execution_context ctx);
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *tflite_ctx, graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size);
tensor_data *output_tensor, uint32_t *output_tensor_size);
__attribute__((visibility("default"))) wasi_nn_error
init_backend(void **tflite_ctx);