mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-09-06 09:51:27 +00:00
wasi_nn.h: make this compatible with wasi_ephemeral_nn (#4330)
- wasi_nn.h: make this compatible with wasi_ephemeral_nn cf. https://github.com/bytecodealliance/wasm-micro-runtime/issues/4323 - fix WASM_ENABLE_WASI_EPHEMERAL_NN build this structure is used by host logic as well. ideally definitions for wasm and host should be separated. until it happens, check __wasm__ to avoid the breakage.
This commit is contained in:
parent
99c75b53db
commit
4d6b8dcd5d
|
@ -15,21 +15,33 @@
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#include "wasi_nn_types.h"
|
#include "wasi_nn_types.h"
|
||||||
|
|
||||||
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||||
|
#define WASI_NN_IMPORT(name) \
|
||||||
|
__attribute__((import_module("wasi_ephemeral_nn"), import_name(name)))
|
||||||
|
#else
|
||||||
#define WASI_NN_IMPORT(name) \
|
#define WASI_NN_IMPORT(name) \
|
||||||
__attribute__((import_module("wasi_nn"), import_name(name)))
|
__attribute__((import_module("wasi_nn"), import_name(name)))
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Load an opaque sequence of bytes to use for inference.
|
* @brief Load an opaque sequence of bytes to use for inference.
|
||||||
*
|
*
|
||||||
* @param builder Model builder.
|
* @param builder Model builder.
|
||||||
|
* @param builder_len The size of model builder.
|
||||||
* @param encoding Model encoding.
|
* @param encoding Model encoding.
|
||||||
* @param target Execution target.
|
* @param target Execution target.
|
||||||
* @param g Graph.
|
* @param g Graph.
|
||||||
* @return wasi_nn_error Execution status.
|
* @return wasi_nn_error Execution status.
|
||||||
*/
|
*/
|
||||||
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||||
|
wasi_nn_error
|
||||||
|
load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
|
||||||
|
execution_target target, graph *g) WASI_NN_IMPORT("load");
|
||||||
|
#else
|
||||||
wasi_nn_error
|
wasi_nn_error
|
||||||
load(graph_builder_array *builder, graph_encoding encoding,
|
load(graph_builder_array *builder, graph_encoding encoding,
|
||||||
execution_target target, graph *g) WASI_NN_IMPORT("load");
|
execution_target target, graph *g) WASI_NN_IMPORT("load");
|
||||||
|
#endif
|
||||||
|
|
||||||
wasi_nn_error
|
wasi_nn_error
|
||||||
load_by_name(const char *name, uint32_t name_len, graph *g)
|
load_by_name(const char *name, uint32_t name_len, graph *g)
|
||||||
|
@ -84,9 +96,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
|
||||||
* copied number of bytes.
|
* copied number of bytes.
|
||||||
* @return wasi_nn_error Execution status.
|
* @return wasi_nn_error Execution status.
|
||||||
*/
|
*/
|
||||||
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||||
|
wasi_nn_error
|
||||||
|
get_output(graph_execution_context ctx, uint32_t index,
|
||||||
|
tensor_data output_tensor, uint32_t output_tensor_max_size,
|
||||||
|
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
|
||||||
|
#else
|
||||||
wasi_nn_error
|
wasi_nn_error
|
||||||
get_output(graph_execution_context ctx, uint32_t index,
|
get_output(graph_execution_context ctx, uint32_t index,
|
||||||
tensor_data output_tensor, uint32_t *output_tensor_size)
|
tensor_data output_tensor, uint32_t *output_tensor_size)
|
||||||
WASI_NN_IMPORT("get_output");
|
WASI_NN_IMPORT("get_output");
|
||||||
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -77,7 +77,11 @@ typedef struct {
|
||||||
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
|
// Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To
|
||||||
// represent a tensor containing a single value, use `[1]` for the tensor
|
// represent a tensor containing a single value, use `[1]` for the tensor
|
||||||
// dimensions.
|
// dimensions.
|
||||||
|
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
|
||||||
|
tensor_dimensions dimensions;
|
||||||
|
#else
|
||||||
tensor_dimensions *dimensions;
|
tensor_dimensions *dimensions;
|
||||||
|
#endif
|
||||||
// Describe the type of element in the tensor (e.g., f32).
|
// Describe the type of element in the tensor (e.g., f32).
|
||||||
uint8_t type;
|
uint8_t type;
|
||||||
uint8_t _pad[3];
|
uint8_t _pad[3];
|
||||||
|
|
Loading…
Reference in New Issue
Block a user