wasi_ephemeral_nn.h: prefix identifiers to avoid too generic names (#4358)

This commit is contained in:
YAMAMOTO Takashi 2025-06-17 12:15:01 +09:00 committed by GitHub
parent 965f2452c8
commit 20be1d33fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 133 additions and 82 deletions

View File

@ -4,4 +4,9 @@
*/ */
#define WASM_ENABLE_WASI_EPHEMERAL_NN 1 #define WASM_ENABLE_WASI_EPHEMERAL_NN 1
#define WASI_NN_NAME(name) wasi_ephemeral_nn_##name
#include "wasi_nn.h" #include "wasi_nn.h"
#undef WASM_ENABLE_WASI_EPHEMERAL_NN
#undef WASI_NN_NAME

View File

@ -34,17 +34,22 @@
* @return wasi_nn_error Execution status. * @return wasi_nn_error Execution status.
*/ */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error WASI_NN_ERROR_TYPE
load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding, WASI_NN_NAME(load)
execution_target target, graph *g) WASI_NN_IMPORT("load"); (WASI_NN_NAME(graph_builder) * builder, uint32_t builder_len,
WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
#else #else
wasi_nn_error WASI_NN_ERROR_TYPE
load(graph_builder_array *builder, graph_encoding encoding, WASI_NN_NAME(load)
execution_target target, graph *g) WASI_NN_IMPORT("load"); (WASI_NN_NAME(graph_builder_array) * builder,
WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
#endif #endif
wasi_nn_error WASI_NN_ERROR_TYPE
load_by_name(const char *name, uint32_t name_len, graph *g) WASI_NN_NAME(load_by_name)
(const char *name, uint32_t name_len, WASI_NN_NAME(graph) * g)
WASI_NN_IMPORT("load_by_name"); WASI_NN_IMPORT("load_by_name");
/** /**
@ -59,8 +64,9 @@ load_by_name(const char *name, uint32_t name_len, graph *g)
* @param ctx Execution context. * @param ctx Execution context.
* @return wasi_nn_error Execution status. * @return wasi_nn_error Execution status.
*/ */
wasi_nn_error WASI_NN_ERROR_TYPE
init_execution_context(graph g, graph_execution_context *ctx) WASI_NN_NAME(init_execution_context)
(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx)
WASI_NN_IMPORT("init_execution_context"); WASI_NN_IMPORT("init_execution_context");
/** /**
@ -71,9 +77,10 @@ init_execution_context(graph g, graph_execution_context *ctx)
* @param tensor Input tensor. * @param tensor Input tensor.
* @return wasi_nn_error Execution status. * @return wasi_nn_error Execution status.
*/ */
wasi_nn_error WASI_NN_ERROR_TYPE
set_input(graph_execution_context ctx, uint32_t index, tensor *tensor) WASI_NN_NAME(set_input)
WASI_NN_IMPORT("set_input"); (WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
WASI_NN_NAME(tensor) * tensor) WASI_NN_IMPORT("set_input");
/** /**
* @brief Compute the inference on the given inputs. * @brief Compute the inference on the given inputs.
@ -81,8 +88,9 @@ set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
* @param ctx Execution context. * @param ctx Execution context.
* @return wasi_nn_error Execution status. * @return wasi_nn_error Execution status.
*/ */
wasi_nn_error WASI_NN_ERROR_TYPE
compute(graph_execution_context ctx) WASI_NN_IMPORT("compute"); WASI_NN_NAME(compute)
(WASI_NN_NAME(graph_execution_context) ctx) WASI_NN_IMPORT("compute");
/** /**
* @brief Extract the outputs after inference. * @brief Extract the outputs after inference.
@ -97,14 +105,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
* @return wasi_nn_error Execution status. * @return wasi_nn_error Execution status.
*/ */
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
wasi_nn_error WASI_NN_ERROR_TYPE
get_output(graph_execution_context ctx, uint32_t index, WASI_NN_NAME(get_output)
tensor_data output_tensor, uint32_t output_tensor_max_size, (WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output"); WASI_NN_NAME(tensor_data) output_tensor, uint32_t output_tensor_max_size,
uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
#else #else
wasi_nn_error WASI_NN_ERROR_TYPE
get_output(graph_execution_context ctx, uint32_t index, WASI_NN_NAME(get_output)
tensor_data output_tensor, uint32_t *output_tensor_size) (graph_execution_context ctx, uint32_t index,
WASI_NN_NAME(tensor_data) output_tensor, uint32_t *output_tensor_size)
WASI_NN_IMPORT("get_output"); WASI_NN_IMPORT("get_output");
#endif #endif

View File

@ -13,6 +13,23 @@
extern "C" { extern "C" {
#endif #endif
/*
 * Naming helpers for the wasi-nn identifiers.
 *
 * Our host logic doesn't use any prefix, and neither does the legacy
 * wasi_nn.h. So unless this is a wasm build AND the includer has already
 * supplied WASI_NN_NAME, every helper below expands to the bare name.
 *
 * Otherwise each helper routes through the includer's WASI_NN_NAME, and
 * enum constants additionally get a category tag (error_, type_,
 * encoding_, target_) pasted in front of the name.
 */
#if !defined(__wasm__) || !defined(WASI_NN_NAME)
#define WASI_NN_NAME(name) name
#define WASI_NN_ERROR_NAME(name) name
#define WASI_NN_TYPE_NAME(name) name
#define WASI_NN_ENCODING_NAME(name) name
#define WASI_NN_TARGET_NAME(name) name
#define WASI_NN_ERROR_TYPE wasi_nn_error
#else
#define WASI_NN_ERROR_NAME(name) WASI_NN_NAME(error_##name)
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
/*
 * Must expand to a plain type name with no trailing semicolon: the macro
 * is used as a typedef name and as a function return type, and a stray
 * ';' would end the declaration before the function declarator.
 */
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
#endif
/** /**
* ERRORS * ERRORS
* *
@ -22,22 +39,22 @@ extern "C" {
// https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L5-L17 // https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L5-L17
// Error codes returned by functions in this API. // Error codes returned by functions in this API.
typedef enum { typedef enum {
success = 0, WASI_NN_ERROR_NAME(success) = 0,
invalid_argument, WASI_NN_ERROR_NAME(invalid_argument),
invalid_encoding, WASI_NN_ERROR_NAME(invalid_encoding),
missing_memory, WASI_NN_ERROR_NAME(missing_memory),
busy, WASI_NN_ERROR_NAME(busy),
runtime_error, WASI_NN_ERROR_NAME(runtime_error),
unsupported_operation, WASI_NN_ERROR_NAME(unsupported_operation),
too_large, WASI_NN_ERROR_NAME(too_large),
not_found, WASI_NN_ERROR_NAME(not_found),
// for WasmEdge-wasi-nn // for WasmEdge-wasi-nn
end_of_sequence = 100, // End of Sequence Found. WASI_NN_ERROR_NAME(end_of_sequence) = 100, // End of Sequence Found.
context_full = 101, // Context Full. WASI_NN_ERROR_NAME(context_full) = 101, // Context Full.
prompt_tool_long = 102, // Prompt Too Long. WASI_NN_ERROR_NAME(prompt_tool_long) = 102, // Prompt Too Long.
model_not_found = 103, // Model Not Found. WASI_NN_ERROR_NAME(model_not_found) = 103, // Model Not Found.
} wasi_nn_error; } WASI_NN_ERROR_TYPE;
/** /**
* TENSOR * TENSOR
@ -51,15 +68,27 @@ typedef enum {
typedef struct { typedef struct {
uint32_t *buf; uint32_t *buf;
uint32_t size; uint32_t size;
} tensor_dimensions; } WASI_NN_NAME(tensor_dimensions);
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
// sync up with // sync up with
// https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L19-L28 // https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L19-L28
// The type of the elements in a tensor. // The type of the elements in a tensor.
typedef enum { fp16 = 0, fp32, fp64, u8, i32, i64 } tensor_type; typedef enum {
WASI_NN_TYPE_NAME(fp16) = 0,
WASI_NN_TYPE_NAME(fp32),
WASI_NN_TYPE_NAME(fp64),
WASI_NN_TYPE_NAME(u8),
WASI_NN_TYPE_NAME(i32),
WASI_NN_TYPE_NAME(i64),
} WASI_NN_NAME(tensor_type);
#else #else
typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type; typedef enum {
WASI_NN_TYPE_NAME(fp16) = 0,
WASI_NN_TYPE_NAME(fp32),
WASI_NN_TYPE_NAME(up8),
WASI_NN_TYPE_NAME(ip32),
} WASI_NN_NAME(tensor_type);
#endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */ #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
// The tensor data. // The tensor data.
@ -70,7 +99,7 @@ typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
// 4-byte f32 elements would have a data array of length 16). Naturally, this // 4-byte f32 elements would have a data array of length 16). Naturally, this
// representation requires some knowledge of how to lay out data in // representation requires some knowledge of how to lay out data in
// memory--e.g., using row-major ordering--and could perhaps be improved. // memory--e.g., using row-major ordering--and could perhaps be improved.
typedef uint8_t *tensor_data; typedef uint8_t *WASI_NN_NAME(tensor_data);
// A tensor. // A tensor.
typedef struct { typedef struct {
@ -78,16 +107,16 @@ typedef struct {
// represent a tensor containing a single value, use `[1]` for the tensor // represent a tensor containing a single value, use `[1]` for the tensor
// dimensions. // dimensions.
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__) #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
tensor_dimensions dimensions; WASI_NN_NAME(tensor_dimensions) dimensions;
#else #else
tensor_dimensions *dimensions; WASI_NN_NAME(tensor_dimensions) * dimensions;
#endif #endif
// Describe the type of element in the tensor (e.g., f32). // Describe the type of element in the tensor (e.g., f32).
uint8_t type; uint8_t type;
uint8_t _pad[3]; uint8_t _pad[3];
// Contains the tensor data. // Contains the tensor data.
tensor_data data; WASI_NN_NAME(tensor_data) data;
} tensor; } WASI_NN_NAME(tensor);
/** /**
* GRAPH * GRAPH
@ -102,15 +131,15 @@ typedef struct {
typedef struct { typedef struct {
uint8_t *buf; uint8_t *buf;
uint32_t size; uint32_t size;
} graph_builder; } WASI_NN_NAME(graph_builder);
typedef struct { typedef struct {
graph_builder *buf; WASI_NN_NAME(graph_builder) * buf;
uint32_t size; uint32_t size;
} graph_builder_array; } WASI_NN_NAME(graph_builder_array);
// An execution graph for performing inference (i.e., a model). // An execution graph for performing inference (i.e., a model).
typedef uint32_t graph; typedef uint32_t WASI_NN_NAME(graph);
// sync up with // sync up with
// https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75 // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75
@ -118,21 +147,25 @@ typedef uint32_t graph;
// various backends that encode (i.e., serialize) their graph IR with different // various backends that encode (i.e., serialize) their graph IR with different
// formats. // formats.
typedef enum { typedef enum {
openvino = 0, WASI_NN_ENCODING_NAME(openvino) = 0,
onnx, WASI_NN_ENCODING_NAME(onnx),
tensorflow, WASI_NN_ENCODING_NAME(tensorflow),
pytorch, WASI_NN_ENCODING_NAME(pytorch),
tensorflowlite, WASI_NN_ENCODING_NAME(tensorflowlite),
ggml, WASI_NN_ENCODING_NAME(ggml),
autodetect, WASI_NN_ENCODING_NAME(autodetect),
unknown_backend, WASI_NN_ENCODING_NAME(unknown_backend),
} graph_encoding; } WASI_NN_NAME(graph_encoding);
// Define where the graph should be executed. // Define where the graph should be executed.
typedef enum execution_target { cpu = 0, gpu, tpu } execution_target; typedef enum WASI_NN_NAME(execution_target) {
WASI_NN_TARGET_NAME(cpu) = 0,
WASI_NN_TARGET_NAME(gpu),
WASI_NN_TARGET_NAME(tpu),
} WASI_NN_NAME(execution_target);
// Bind a `graph` to the input and output tensors for an inference. // Bind a `graph` to the input and output tensors for an inference.
typedef uint32_t graph_execution_context; typedef uint32_t WASI_NN_NAME(graph_execution_context);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -93,7 +93,7 @@ print_result(const float *result, size_t sz)
int int
main(int argc, char **argv) main(int argc, char **argv)
{ {
wasi_nn_error nnret; wasi_ephemeral_nn_error nnret;
int ret; int ret;
void *xml; void *xml;
size_t xmlsz; size_t xmlsz;
@ -112,25 +112,27 @@ main(int argc, char **argv)
exit(1); exit(1);
} }
/* note: openvino takes two buffers, namely IR and weights */ /* note: openvino takes two buffers, namely IR and weights */
graph_builder builders[2] = { { wasi_ephemeral_nn_graph_builder builders[2] = { {
.buf = xml, .buf = xml,
.size = xmlsz, .size = xmlsz,
}, },
{ {
.buf = weights, .buf = weights,
.size = weightssz, .size = weightssz,
} }; } };
graph g; wasi_ephemeral_nn_graph g;
nnret = load(builders, 2, openvino, cpu, &g); nnret =
wasi_ephemeral_nn_load(builders, 2, wasi_ephemeral_nn_encoding_openvino,
wasi_ephemeral_nn_target_cpu, &g);
unmap_file(xml, xmlsz); unmap_file(xml, xmlsz);
unmap_file(weights, weightssz); unmap_file(weights, weightssz);
if (nnret != success) { if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "load failed with %d\n", (int)nnret); fprintf(stderr, "load failed with %d\n", (int)nnret);
exit(1); exit(1);
} }
graph_execution_context ctx; wasi_ephemeral_nn_graph_execution_context ctx;
nnret = init_execution_context(g, &ctx); nnret = wasi_ephemeral_nn_init_execution_context(g, &ctx);
if (nnret != success) { if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "init_execution_context failed with %d\n", (int)nnret); fprintf(stderr, "init_execution_context failed with %d\n", (int)nnret);
exit(1); exit(1);
} }
@ -142,26 +144,27 @@ main(int argc, char **argv)
strerror(ret)); strerror(ret));
exit(1); exit(1);
} }
tensor tensor = { wasi_ephemeral_nn_tensor tensor = {
.dimensions = { .buf = (uint32_t[]){1, 3, 224, 224,}, .size = 4, }, .dimensions = { .buf = (uint32_t[]){1, 3, 224, 224,}, .size = 4, },
.type = fp32, .type = wasi_ephemeral_nn_type_fp32,
.data = tensordata, .data = tensordata,
}; };
nnret = set_input(ctx, 0, &tensor); nnret = wasi_ephemeral_nn_set_input(ctx, 0, &tensor);
unmap_file(tensordata, tensordatasz); unmap_file(tensordata, tensordatasz);
if (nnret != success) { if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "set_input failed with %d\n", (int)nnret); fprintf(stderr, "set_input failed with %d\n", (int)nnret);
exit(1); exit(1);
} }
nnret = compute(ctx); nnret = wasi_ephemeral_nn_compute(ctx);
if (nnret != success) { if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "compute failed with %d\n", (int)nnret); fprintf(stderr, "compute failed with %d\n", (int)nnret);
exit(1); exit(1);
} }
float result[1001]; float result[1001];
uint32_t resultsz; uint32_t resultsz;
nnret = get_output(ctx, 0, (void *)result, sizeof(result), &resultsz); nnret = wasi_ephemeral_nn_get_output(ctx, 0, (void *)result, sizeof(result),
if (nnret != success) { &resultsz);
if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "get_output failed with %d\n", (int)nnret); fprintf(stderr, "get_output failed with %d\n", (int)nnret);
exit(1); exit(1);
} }