wasi_nn.h: make this compatible with wasi_ephemeral_nn (#4330)

- wasi_nn.h: make this compatible with wasi_ephemeral_nn
cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4323

- fix WASM_ENABLE_WASI_EPHEMERAL_NN build
this structure is used by host logic as well.
ideally definitions for wasm and host should be separated.
until it happens, check __wasm__ to avoid the breakage.
This commit is contained in:
YAMAMOTO Takashi 2025-06-09 12:36:05 +09:00 committed by GitHub
parent 99c75b53db
commit 4d6b8dcd5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 0 deletions

View File

@ -15,21 +15,33 @@
#include <stdint.h>
#include "wasi_nn_types.h"
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
#define WASI_NN_IMPORT(name) \
__attribute__((import_module("wasi_ephemeral_nn"), import_name(name)))
#else
#define WASI_NN_IMPORT(name) \
__attribute__((import_module("wasi_nn"), import_name(name)))
#endif
/**
* @brief Load an opaque sequence of bytes to use for inference.
*
* @param builder Model builder.
* @param builder_len The size of model builder.
* @param encoding Model encoding.
* @param target Execution target.
* @param g Graph.
* @return wasi_nn_error Execution status.
*/
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
execution_target target, graph *g) WASI_NN_IMPORT("load");
#else
wasi_nn_error
load(graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g) WASI_NN_IMPORT("load");
#endif
wasi_nn_error
load_by_name(const char *name, uint32_t name_len, graph *g)
@ -84,9 +96,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
* copied number of bytes.
* @return wasi_nn_error Execution status.
*/
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t output_tensor_max_size,
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
#else
wasi_nn_error
get_output(graph_execution_context ctx, uint32_t index,
tensor_data output_tensor, uint32_t *output_tensor_size)
WASI_NN_IMPORT("get_output");
#endif
#endif

View File

@ -77,7 +77,11 @@ typedef struct {
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
// represent a tensor containing a single value, use `[1]` for the tensor
// dimensions.
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
tensor_dimensions dimensions;
#else
tensor_dimensions *dimensions;
#endif
// Describe the type of element in the tensor (e.g., f32).
uint8_t type;
uint8_t _pad[3];