wasi-nn: give the guest-side wasi_ephemeral_nn API a name prefix

wasi_ephemeral_nn.h defines WASI_NN_NAME(name) as wasi_ephemeral_nn_##name
before including wasi_nn.h, and undefines it (together with
WASM_ENABLE_WASI_EPHEMERAL_NN) afterwards. wasi_nn_types.h derives the
per-category helpers (WASI_NN_ERROR_NAME, WASI_NN_TYPE_NAME,
WASI_NN_ENCODING_NAME, WASI_NN_TARGET_NAME, WASI_NN_ERROR_TYPE) from it.
When not compiling for wasm, or when WASI_NN_NAME is not defined, every
helper expands to the bare name, so host code and the legacy wasi_nn.h keep
their unprefixed identifiers. The nn sample is switched to the prefixed names.

Review notes:
- WASI_NN_ERROR_TYPE must not end in ';'. It is used as the return type of
  every prototype in wasi_nn.h, so a trailing semicolon would split each
  prototype into an empty declaration plus a function without a return type.
- The legacy get_output prototype wraps graph_execution_context in
  WASI_NN_NAME() like the other prototypes; in legacy mode this expands to
  the same bare name.

diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h
index 6a4901afc..f76295a1e 100644
--- a/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h
+++ b/core/iwasm/libraries/wasi-nn/include/wasi_ephemeral_nn.h
@@ -4,4 +4,9 @@
  */
 
 #define WASM_ENABLE_WASI_EPHEMERAL_NN 1
+#define WASI_NN_NAME(name) wasi_ephemeral_nn_##name
+
 #include "wasi_nn.h"
+
+#undef WASM_ENABLE_WASI_EPHEMERAL_NN
+#undef WASI_NN_NAME
diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h
index 35b2d9bf0..48ffe1263 100644
--- a/core/iwasm/libraries/wasi-nn/include/wasi_nn.h
+++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn.h
@@ -34,17 +34,22 @@
  * @return wasi_nn_error Execution status.
  */
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-wasi_nn_error
-load(graph_builder *builder, uint32_t builder_len, graph_encoding encoding,
-     execution_target target, graph *g) WASI_NN_IMPORT("load");
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(load)
+(WASI_NN_NAME(graph_builder) * builder, uint32_t builder_len,
+ WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
+ WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
 #else
-wasi_nn_error
-load(graph_builder_array *builder, graph_encoding encoding,
-     execution_target target, graph *g) WASI_NN_IMPORT("load");
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(load)
+(WASI_NN_NAME(graph_builder_array) * builder,
+ WASI_NN_NAME(graph_encoding) encoding, WASI_NN_NAME(execution_target) target,
+ WASI_NN_NAME(graph) * g) WASI_NN_IMPORT("load");
 #endif
 
-wasi_nn_error
-load_by_name(const char *name, uint32_t name_len, graph *g)
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(load_by_name)
+(const char *name, uint32_t name_len, WASI_NN_NAME(graph) * g)
     WASI_NN_IMPORT("load_by_name");
 
 /**
@@ -59,8 +64,9 @@ load_by_name(const char *name, uint32_t name_len, graph *g)
  * @param ctx Execution context.
  * @return wasi_nn_error Execution status.
  */
-wasi_nn_error
-init_execution_context(graph g, graph_execution_context *ctx)
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(init_execution_context)
+(WASI_NN_NAME(graph) g, WASI_NN_NAME(graph_execution_context) * ctx)
     WASI_NN_IMPORT("init_execution_context");
 
 /**
@@ -71,9 +77,10 @@ init_execution_context(graph g, graph_execution_context *ctx)
  * @param tensor Input tensor.
  * @return wasi_nn_error Execution status.
  */
-wasi_nn_error
-set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
-    WASI_NN_IMPORT("set_input");
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(set_input)
+(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
+ WASI_NN_NAME(tensor) * tensor) WASI_NN_IMPORT("set_input");
 
 /**
  * @brief Compute the inference on the given inputs.
@@ -81,8 +88,9 @@ set_input(graph_execution_context ctx, uint32_t index, tensor *tensor)
  * @param ctx Execution context.
  * @return wasi_nn_error Execution status.
  */
-wasi_nn_error
-compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(compute)
+(WASI_NN_NAME(graph_execution_context) ctx) WASI_NN_IMPORT("compute");
 
 /**
  * @brief Extract the outputs after inference.
@@ -97,14 +105,16 @@ compute(graph_execution_context ctx) WASI_NN_IMPORT("compute");
  * @return wasi_nn_error Execution status.
  */
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
-wasi_nn_error
-get_output(graph_execution_context ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t output_tensor_max_size,
-           uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(get_output)
+(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
+ WASI_NN_NAME(tensor_data) output_tensor, uint32_t output_tensor_max_size,
+ uint32_t *output_tensor_size) WASI_NN_IMPORT("get_output");
 #else
-wasi_nn_error
-get_output(graph_execution_context ctx, uint32_t index,
-           tensor_data output_tensor, uint32_t *output_tensor_size)
+WASI_NN_ERROR_TYPE
+WASI_NN_NAME(get_output)
+(WASI_NN_NAME(graph_execution_context) ctx, uint32_t index,
+ WASI_NN_NAME(tensor_data) output_tensor, uint32_t *output_tensor_size)
     WASI_NN_IMPORT("get_output");
 #endif
 
diff --git a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h
index c66e781a7..7980197b7 100644
--- a/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h
+++ b/core/iwasm/libraries/wasi-nn/include/wasi_nn_types.h
@@ -13,6 +13,23 @@
 extern "C" {
 #endif
 
+/* our host logic doesn't use any prefix. neither legacy wasi_nn.h does. */
+
+#if !defined(__wasm__) || !defined(WASI_NN_NAME)
+#define WASI_NN_NAME(name) name
+#define WASI_NN_ERROR_NAME(name) name
+#define WASI_NN_TYPE_NAME(name) name
+#define WASI_NN_ENCODING_NAME(name) name
+#define WASI_NN_TARGET_NAME(name) name
+#define WASI_NN_ERROR_TYPE wasi_nn_error
+#else
+#define WASI_NN_ERROR_NAME(name) WASI_NN_NAME(error_##name)
+#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
+#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
+#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
+#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
+#endif
+
 /**
  * ERRORS
  *
@@ -22,22 +39,22 @@ extern "C" {
 // https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L5-L17
 // Error codes returned by functions in this API.
 typedef enum {
-    success = 0,
-    invalid_argument,
-    invalid_encoding,
-    missing_memory,
-    busy,
-    runtime_error,
-    unsupported_operation,
-    too_large,
-    not_found,
+    WASI_NN_ERROR_NAME(success) = 0,
+    WASI_NN_ERROR_NAME(invalid_argument),
+    WASI_NN_ERROR_NAME(invalid_encoding),
+    WASI_NN_ERROR_NAME(missing_memory),
+    WASI_NN_ERROR_NAME(busy),
+    WASI_NN_ERROR_NAME(runtime_error),
+    WASI_NN_ERROR_NAME(unsupported_operation),
+    WASI_NN_ERROR_NAME(too_large),
+    WASI_NN_ERROR_NAME(not_found),
 
     // for WasmEdge-wasi-nn
-    end_of_sequence = 100,  // End of Sequence Found.
-    context_full = 101,     // Context Full.
-    prompt_tool_long = 102, // Prompt Too Long.
-    model_not_found = 103,  // Model Not Found.
-} wasi_nn_error;
+    WASI_NN_ERROR_NAME(end_of_sequence) = 100,  // End of Sequence Found.
+    WASI_NN_ERROR_NAME(context_full) = 101,     // Context Full.
+    WASI_NN_ERROR_NAME(prompt_tool_long) = 102, // Prompt Too Long.
+    WASI_NN_ERROR_NAME(model_not_found) = 103,  // Model Not Found.
+} WASI_NN_ERROR_TYPE;
 
 /**
  * TENSOR
@@ -51,15 +68,27 @@ typedef enum {
 typedef struct {
     uint32_t *buf;
     uint32_t size;
-} tensor_dimensions;
+} WASI_NN_NAME(tensor_dimensions);
 
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
 // sync up with
 // https://github.com/WebAssembly/wasi-nn/blob/71320d95b8c6d43f9af7f44e18b1839db85d89b4/wasi-nn.witx#L19-L28
 // The type of the elements in a tensor.
-typedef enum { fp16 = 0, fp32, fp64, u8, i32, i64 } tensor_type;
+typedef enum {
+    WASI_NN_TYPE_NAME(fp16) = 0,
+    WASI_NN_TYPE_NAME(fp32),
+    WASI_NN_TYPE_NAME(fp64),
+    WASI_NN_TYPE_NAME(u8),
+    WASI_NN_TYPE_NAME(i32),
+    WASI_NN_TYPE_NAME(i64),
+} WASI_NN_NAME(tensor_type);
 #else
-typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
+typedef enum {
+    WASI_NN_TYPE_NAME(fp16) = 0,
+    WASI_NN_TYPE_NAME(fp32),
+    WASI_NN_TYPE_NAME(up8),
+    WASI_NN_TYPE_NAME(ip32),
+} WASI_NN_NAME(tensor_type);
 #endif /* WASM_ENABLE_WASI_EPHEMERAL_NN != 0 */
 
 // The tensor data.
@@ -70,7 +99,7 @@ typedef enum { fp16 = 0, fp32, up8, ip32 } tensor_type;
 // 4-byte f32 elements would have a data array of length 16). Naturally, this
 // representation requires some knowledge of how to lay out data in
 // memory--e.g., using row-major ordering--and could perhaps be improved.
-typedef uint8_t *tensor_data;
+typedef uint8_t *WASI_NN_NAME(tensor_data);
 
 // A tensor.
 typedef struct {
@@ -78,16 +107,16 @@ typedef struct {
     // represent a tensor containing a single value, use `[1]` for the tensor
     // dimensions.
 #if WASM_ENABLE_WASI_EPHEMERAL_NN != 0 && defined(__wasm__)
-    tensor_dimensions dimensions;
+    WASI_NN_NAME(tensor_dimensions) dimensions;
 #else
-    tensor_dimensions *dimensions;
+    WASI_NN_NAME(tensor_dimensions) * dimensions;
 #endif
     // Describe the type of element in the tensor (e.g., f32).
     uint8_t type;
     uint8_t _pad[3];
     // Contains the tensor data.
-    tensor_data data;
-} tensor;
+    WASI_NN_NAME(tensor_data) data;
+} WASI_NN_NAME(tensor);
 
 /**
  * GRAPH
@@ -102,15 +131,15 @@ typedef struct {
 typedef struct {
     uint8_t *buf;
     uint32_t size;
-} graph_builder;
+} WASI_NN_NAME(graph_builder);
 
 typedef struct {
-    graph_builder *buf;
+    WASI_NN_NAME(graph_builder) * buf;
     uint32_t size;
-} graph_builder_array;
+} WASI_NN_NAME(graph_builder_array);
 
 // An execution graph for performing inference (i.e., a model).
-typedef uint32_t graph;
+typedef uint32_t WASI_NN_NAME(graph);
 
 // sync up with
 // https://github.com/WebAssembly/wasi-nn/blob/main/wit/wasi-nn.wit#L75
@@ -118,21 +147,25 @@ typedef uint32_t graph;
 // various backends that encode (i.e., serialize) their graph IR with different
 // formats.
 typedef enum {
-    openvino = 0,
-    onnx,
-    tensorflow,
-    pytorch,
-    tensorflowlite,
-    ggml,
-    autodetect,
-    unknown_backend,
-} graph_encoding;
+    WASI_NN_ENCODING_NAME(openvino) = 0,
+    WASI_NN_ENCODING_NAME(onnx),
+    WASI_NN_ENCODING_NAME(tensorflow),
+    WASI_NN_ENCODING_NAME(pytorch),
+    WASI_NN_ENCODING_NAME(tensorflowlite),
+    WASI_NN_ENCODING_NAME(ggml),
+    WASI_NN_ENCODING_NAME(autodetect),
+    WASI_NN_ENCODING_NAME(unknown_backend),
+} WASI_NN_NAME(graph_encoding);
 
 // Define where the graph should be executed.
-typedef enum execution_target { cpu = 0, gpu, tpu } execution_target;
+typedef enum WASI_NN_NAME(execution_target) {
+    WASI_NN_TARGET_NAME(cpu) = 0,
+    WASI_NN_TARGET_NAME(gpu),
+    WASI_NN_TARGET_NAME(tpu),
+} WASI_NN_NAME(execution_target);
 
 // Bind a `graph` to the input and output tensors for an inference.
-typedef uint32_t graph_execution_context;
+typedef uint32_t WASI_NN_NAME(graph_execution_context);
 
 #ifdef __cplusplus
 }
diff --git a/wamr-wasi-extensions/samples/nn/app.c b/wamr-wasi-extensions/samples/nn/app.c
index 045d1bd4b..a3e49a697 100644
--- a/wamr-wasi-extensions/samples/nn/app.c
+++ b/wamr-wasi-extensions/samples/nn/app.c
@@ -93,7 +93,7 @@ print_result(const float *result, size_t sz)
 int
 main(int argc, char **argv)
 {
-    wasi_nn_error nnret;
+    wasi_ephemeral_nn_error nnret;
     int ret;
     void *xml;
     size_t xmlsz;
@@ -112,25 +112,27 @@ main(int argc, char **argv)
         exit(1);
     }
     /* note: openvino takes two buffers, namely IR and weights */
-    graph_builder builders[2] = { {
-                                      .buf = xml,
-                                      .size = xmlsz,
-                                  },
-                                  {
-                                      .buf = weights,
-                                      .size = weightssz,
-                                  } };
-    graph g;
-    nnret = load(builders, 2, openvino, cpu, &g);
+    wasi_ephemeral_nn_graph_builder builders[2] = { {
+                                                        .buf = xml,
+                                                        .size = xmlsz,
+                                                    },
+                                                    {
+                                                        .buf = weights,
+                                                        .size = weightssz,
+                                                    } };
+    wasi_ephemeral_nn_graph g;
+    nnret =
+        wasi_ephemeral_nn_load(builders, 2, wasi_ephemeral_nn_encoding_openvino,
+                               wasi_ephemeral_nn_target_cpu, &g);
     unmap_file(xml, xmlsz);
     unmap_file(weights, weightssz);
-    if (nnret != success) {
+    if (nnret != wasi_ephemeral_nn_error_success) {
         fprintf(stderr, "load failed with %d\n", (int)nnret);
         exit(1);
     }
-    graph_execution_context ctx;
-    nnret = init_execution_context(g, &ctx);
-    if (nnret != success) {
+    wasi_ephemeral_nn_graph_execution_context ctx;
+    nnret = wasi_ephemeral_nn_init_execution_context(g, &ctx);
+    if (nnret != wasi_ephemeral_nn_error_success) {
         fprintf(stderr, "init_execution_context failed with %d\n", (int)nnret);
         exit(1);
     }
@@ -142,26 +144,27 @@ main(int argc, char **argv)
                 strerror(ret));
         exit(1);
     }
-    tensor tensor = {
+    wasi_ephemeral_nn_tensor tensor = {
         .dimensions = { .buf = (uint32_t[]){1, 3, 224, 224,}, .size = 4, },
-        .type = fp32,
+        .type = wasi_ephemeral_nn_type_fp32,
         .data = tensordata,
     };
-    nnret = set_input(ctx, 0, &tensor);
+    nnret = wasi_ephemeral_nn_set_input(ctx, 0, &tensor);
     unmap_file(tensordata, tensordatasz);
-    if (nnret != success) {
+    if (nnret != wasi_ephemeral_nn_error_success) {
         fprintf(stderr, "set_input failed with %d\n", (int)nnret);
         exit(1);
     }
-    nnret = compute(ctx);
-    if (nnret != success) {
+    nnret = wasi_ephemeral_nn_compute(ctx);
+    if (nnret != wasi_ephemeral_nn_error_success) {
         fprintf(stderr, "compute failed with %d\n", (int)nnret);
         exit(1);
     }
     float result[1001];
     uint32_t resultsz;
-    nnret = get_output(ctx, 0, (void *)result, sizeof(result), &resultsz);
-    if (nnret != success) {
+    nnret = wasi_ephemeral_nn_get_output(ctx, 0, (void *)result, sizeof(result),
+                                         &resultsz);
+    if (nnret != wasi_ephemeral_nn_error_success) {
         fprintf(stderr, "get_output failed with %d\n", (int)nnret);
         exit(1);
     }