Compare commits

...

2 Commits

Author | SHA1 | Message | Date
dongsheng28849455 | ffb70b7d17 | Merge 2273302ca6 into 2a952b371a | 2025-07-14 13:44:36 +08:00
dongsheng.yan | 2273302ca6 | Add onnxruntime as wasi-nn backend | 2025-07-14 09:54:46 +08:00
8 changed files with 959 additions and 5 deletions

View File

@@ -510,7 +510,8 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
# Variant backends
if (NOT WAMR_BUILD_WASI_NN_TFLITE EQUAL 1 AND
NOT WAMR_BUILD_WASI_NN_OPENVINO EQUAL 1 AND
NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1)
NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1 AND
NOT WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
message (FATAL_ERROR " Need to select a backend for WASI-NN")
endif ()
@@ -526,6 +527,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
message (" WASI-NN: backend llamacpp enabled")
add_definitions (-DWASM_ENABLE_WASI_NN_LLAMACPP)
endif ()
if (WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
message (" WASI-NN: backend onnx enabled")
add_definitions (-DWASM_ENABLE_WASI_NN_ONNX)
endif ()
# Variant devices
if (WAMR_BUILD_WASI_NN_ENABLE_GPU EQUAL 1)
message (" WASI-NN: GPU enabled")

View File

@@ -26,6 +26,7 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 <other options> ...
- `WAMR_BUILD_WASI_NN_TFLITE`. This option designates TensorFlow Lite as the backend.
- `WAMR_BUILD_WASI_NN_OPENVINO`. This option designates OpenVINO as the backend.
- `WAMR_BUILD_WASI_NN_LLAMACPP`. This option designates Llama.cpp as the backend.
- `WAMR_BUILD_WASI_NN_ONNX`. This option designates ONNX Runtime as the backend.
### Wasm
@@ -151,7 +152,7 @@ docker run \
Supported:
- Graph encoding: `tensorflowlite`, `openvino` and `ggml`
- Graph encoding: `tensorflowlite`, `openvino`, `ggml` and `onnx`
- Execution target: `cpu` for all. `gpu` and `tpu` for `tensorflowlite`.
- Tensor type: `fp32`.
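
Once the runtime is built with `-DWAMR_BUILD_WASI_NN=1 -DWAMR_BUILD_WASI_NN_ONNX=1`, a guest module selects the ONNX Runtime backend simply by passing the `onnx` graph encoding to `load`. Below is a minimal, hypothetical guest-side sketch against the legacy `wasi_nn.h` ABI; the model path, tensor shape, struct field names and the exact `get_output` signature are assumptions and may differ between WAMR versions and the `wasi_ephemeral_nn` variant.

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include "wasi_nn.h"

int
main(void)
{
    /* Read the .onnx model into memory (path is illustrative). */
    FILE *f = fopen("model.onnx", "rb");
    if (!f)
        return 1;
    fseek(f, 0, SEEK_END);
    long size = ftell(f);
    fseek(f, 0, SEEK_SET);
    uint8_t *model = (uint8_t *)malloc(size);
    fread(model, 1, size, f);
    fclose(f);

    /* The `onnx` encoding routes the request to libwasi_nn_onnx. */
    graph_builder builder = { .buf = model, .size = (uint32_t)size };
    graph_builder_array builder_array = { .buf = &builder, .size = 1 };
    graph g;
    if (load(&builder_array, onnx, cpu, &g) != success)
        return 1;

    graph_execution_context ctx;
    if (init_execution_context(g, &ctx) != success)
        return 1;

    /* Example input: a 1x3x224x224 fp32 tensor filled with zeros. */
    uint32_t dims[4] = { 1, 3, 224, 224 };
    tensor_dimensions tensor_dims = { .buf = dims, .size = 4 };
    static float input_data[1 * 3 * 224 * 224];
    tensor input = { .dimensions = &tensor_dims, .type = fp32,
                     .data = (uint8_t *)input_data };
    if (set_input(ctx, 0, &input) != success || compute(ctx) != success)
        return 1;

    /* Legacy-style retrieval; the buffer argument type differs under
       WASM_ENABLE_WASI_EPHEMERAL_NN. */
    float output[1000];
    uint32_t output_size = sizeof(output);
    get_output(ctx, 0, (uint8_t *)output, &output_size);
    printf("output[0] = %f\n", output[0]);
    return 0;
}
```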

View File

@@ -0,0 +1,77 @@
# Copyright 2025 Sony Semiconductor Solutions Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Find ONNX Runtime library
#
# This module defines the following variables:
#
# ::
#
# onnxruntime_FOUND - True if onnxruntime is found
# onnxruntime_INCLUDE_DIRS - Include directories for onnxruntime
# onnxruntime_LIBRARIES - List of libraries for onnxruntime
# onnxruntime_VERSION - Version of onnxruntime
#
# ::
#
# Example usage:
#
# find_package(onnxruntime)
# if(onnxruntime_FOUND)
# target_link_libraries(app onnxruntime)
# endif()
# First try to find ONNX Runtime using the CMake config file
# If not found via CMake config, try to find manually
find_path(onnxruntime_INCLUDE_DIR
NAMES onnxruntime_c_api.h
PATHS
/usr/include
/usr/local/include
/opt/onnxruntime/include
$ENV{ONNXRUNTIME_ROOT}/include
${CMAKE_CURRENT_LIST_DIR}/../../../../..
)
find_library(onnxruntime_LIBRARY
NAMES onnxruntime
PATHS
/usr/lib
/usr/local/lib
/opt/onnxruntime/lib
$ENV{ONNXRUNTIME_ROOT}/lib
${CMAKE_CURRENT_LIST_DIR}/../../../../..
)
# Try to determine version from header file
if(onnxruntime_INCLUDE_DIR)
file(STRINGS "${onnxruntime_INCLUDE_DIR}/onnxruntime_c_api.h" onnxruntime_version_str
REGEX "^#define[\t ]+ORT_API_VERSION[\t ]+[0-9]+")
if(onnxruntime_version_str)
string(REGEX REPLACE "^#define[\t ]+ORT_API_VERSION[\t ]+([0-9]+)" "\\1"
onnxruntime_VERSION "${onnxruntime_version_str}")
endif()
endif()
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(onnxruntime
REQUIRED_VARS onnxruntime_LIBRARY onnxruntime_INCLUDE_DIR
VERSION_VAR onnxruntime_VERSION
)
if(onnxruntime_FOUND)
set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY})
set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR})
if(NOT TARGET onnxruntime)
add_library(onnxruntime UNKNOWN IMPORTED)
set_target_properties(onnxruntime PROPERTIES
IMPORTED_LOCATION "${onnxruntime_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}"
)
endif()
endif()
mark_as_advanced(onnxruntime_INCLUDE_DIR onnxruntime_LIBRARY)

View File

@@ -109,3 +109,31 @@ if(WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1)
install(TARGETS wasi_nn_llamacpp DESTINATION lib)
endif()
# - onnx
if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
find_package(onnxruntime REQUIRED)
enable_language(CXX)
add_library(
wasi_nn_onnx
SHARED
${WASI_NN_ROOT}/src/wasi_nn_onnx.cpp
)
target_include_directories(
wasi_nn_onnx
PUBLIC
${onnxruntime_INCLUDE_DIR}/onnx
${onnxruntime_INCLUDE_DIR}
)
target_link_libraries(
wasi_nn_onnx
PUBLIC
vmlib
onnxruntime
)
install(TARGETS wasi_nn_onnx DESTINATION lib)
endif()

View File

@@ -21,7 +21,7 @@
#else
#define WASI_NN_IMPORT(name) \
__attribute__((import_module("wasi_nn"), import_name(name)))
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#warning "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
#endif
/**

View File

@@ -27,7 +27,7 @@ extern "C" {
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
#endif
/**

View File

@@ -21,7 +21,8 @@
#include "wasm_export.h"
#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
#warning \
"You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
#endif
#define HASHMAP_INITIAL_SIZE 20
@@ -33,6 +34,7 @@
#define TFLITE_BACKEND_LIB "libwasi_nn_tflite" LIB_EXTENTION
#define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
/* Global variables */
static korp_mutex wasi_nn_lock;
@@ -240,6 +242,17 @@ choose_a_backend()
return openvino;
}
#ifndef NDEBUG
NN_WARN_PRINTF("%s", dlerror());
#endif
handle = dlopen(ONNX_BACKEND_LIB, RTLD_LAZY);
if (handle) {
NN_INFO_PRINTF("Using onnx backend");
dlclose(handle);
return onnx;
}
#ifndef NDEBUG
NN_WARN_PRINTF("%s", dlerror());
#endif
@@ -363,6 +376,8 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding)
return TFLITE_BACKEND_LIB;
case ggml:
return LLAMACPP_BACKEND_LIB;
case onnx:
return ONNX_BACKEND_LIB;
default:
return NULL;
}
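
The dispatcher in `wasi_nn.c` discovers backends at run time by `dlopen()`ing their shared libraries, so `libwasi_nn_onnx` only needs to be on the loader search path. A reduced, hypothetical sketch of that probing step (the real logic is `choose_a_backend()` and `graph_encoding_to_backend_lib_name()` above, and the library suffix actually comes from `LIB_EXTENTION`):

```c
#include <dlfcn.h>
#include <stdio.h>

/* Assumes wasi_nn_error is 0 on success, as in wasi_nn_types.h. */
typedef int (*init_backend_fn)(void **ctx);

static void *
probe_onnx_backend(void **ctx_out)
{
    void *handle = dlopen("libwasi_nn_onnx.so", RTLD_LAZY);
    if (!handle) {
        fprintf(stderr, "onnx backend unavailable: %s\n", dlerror());
        return NULL;
    }
    init_backend_fn init = (init_backend_fn)dlsym(handle, "init_backend");
    if (!init || init(ctx_out) != 0) {
        dlclose(handle);
        return NULL;
    }
    return handle; /* keep the handle open while the backend is in use */
}
```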

View File

@@ -0,0 +1,828 @@
/*
* Copyright 2025 Sony Semiconductor Solutions Corporation.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
#include <dlfcn.h>
#include <stdlib.h>
#include <string.h>
#include <mutex>
#include <vector>
#include <unordered_map>
#include "bh_platform.h"
#include "wasi_nn_backend.h"
#include "utils/logger.h"
#include "onnxruntime_c_api.h"
/* Maximum number of graphs and execution contexts */
#define MAX_GRAPHS 10
#define MAX_CONTEXTS 10
/* ONNX Runtime context structure */
typedef struct {
OrtEnv *env;
OrtSessionOptions *session_options;
OrtAllocator *allocator;
const OrtApi *ort_api;
std::mutex mutex;
bool is_initialized;
} OnnxRuntimeContext;
/* Graph structure */
typedef struct {
OrtSession *session;
bool is_initialized;
} OnnxRuntimeGraph;
/* Execution context structure */
typedef struct {
OrtMemoryInfo *memory_info;
std::vector<const char *> input_names;
std::vector<const char *> output_names;
std::unordered_map<uint32_t, OrtValue *> inputs;
std::unordered_map<uint32_t, OrtValue *> outputs;
OnnxRuntimeGraph *graph;
bool is_initialized;
} OnnxRuntimeExecCtx;
/* Global variables */
static OnnxRuntimeContext g_ort_ctx;
static OnnxRuntimeGraph g_graphs[MAX_GRAPHS];
static OnnxRuntimeExecCtx g_exec_ctxs[MAX_CONTEXTS];
/* Helper functions */
static void
check_status_and_log(OrtStatus *status)
{
if (status != nullptr) {
const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
g_ort_ctx.ort_api->ReleaseStatus(status);
}
}
static wasi_nn_error
convert_ort_error_to_wasi_nn_error(OrtStatus *status)
{
if (status == nullptr) {
return success;
}
wasi_nn_error err;
OrtErrorCode code = g_ort_ctx.ort_api->GetErrorCode(status);
const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
switch (code) {
case ORT_INVALID_ARGUMENT:
err = invalid_argument;
break;
case ORT_RUNTIME_EXCEPTION:
err = runtime_error;
break;
case ORT_NOT_IMPLEMENTED:
err = unsupported_operation;
break;
case ORT_INVALID_PROTOBUF:
err = invalid_encoding;
break;
case ORT_MODEL_LOADED:
err = too_large;
break;
case ORT_INVALID_GRAPH:
err = invalid_encoding;
break;
default:
err = runtime_error;
break;
}
g_ort_ctx.ort_api->ReleaseStatus(status);
return err;
}
static tensor_type
convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type)
{
switch (ort_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
return fp32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
return fp16;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
return fp64;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return u8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return i32;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
return i64;
#else
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
return up8;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
return ip32;
#endif
default:
NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type);
return fp32; // Default to fp32
}
}
static ONNXTensorElementDataType
convert_wasi_nn_type_to_ort_type(tensor_type type)
{
switch (type) {
case fp32:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
case fp16:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
case fp64:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
case u8:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case i32:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
case i64:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
#else
case up8:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
case ip32:
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
#endif
default:
NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float
}
}
static size_t
get_tensor_element_size(tensor_type type)
{
switch (type) {
case fp32:
return 4;
case fp16:
return 2;
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
case fp64:
return 8;
case u8:
return 1;
case i32:
return 4;
case i64:
return 8;
#else
case up8:
return 1;
case ip32:
return 4;
#endif
default:
NN_WARN_PRINTF("Unsupported tensor type: %d", type);
return 4; // Default to 4 bytes (float)
}
}
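/*
 * Note (hypothetical sketch, not part of this patch): the element sizes above
 * combine with the tensor dimensions to give a tensor's byte size, which is
 * what set_input() computes before calling CreateTensorWithDataAsOrtValue:
 *
 *   size_t elems = 1;
 *   for (uint32_t i = 0; i < t->dimensions->size; i++)
 *       elems *= t->dimensions->buf[i];
 *   size_t bytes = elems * get_tensor_element_size((tensor_type)t->type);
 */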
/* Backend API implementation */
extern "C" {
__attribute__((visibility("default"))) wasi_nn_error
init_backend(void **onnx_ctx)
{
std::lock_guard<std::mutex> lock(g_ort_ctx.mutex);
if (g_ort_ctx.is_initialized) {
*onnx_ctx = &g_ort_ctx;
return success;
}
g_ort_ctx.ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
if (!g_ort_ctx.ort_api) {
NN_ERR_PRINTF("Failed to get ONNX Runtime API");
return runtime_error;
}
NN_INFO_PRINTF("Creating ONNX Runtime environment...");
OrtStatus *status = g_ort_ctx.ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE,
"wasi-nn", &g_ort_ctx.env);
if (status != nullptr) {
/* convert_ort_error_to_wasi_nn_error() logs the message and releases the
status, so it must not be read or released again afterwards. */
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to create ONNX Runtime environment");
return err;
}
NN_INFO_PRINTF("ONNX Runtime environment created successfully");
status =
g_ort_ctx.ort_api->CreateSessionOptions(&g_ort_ctx.session_options);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
NN_ERR_PRINTF("Failed to create ONNX Runtime session options");
return err;
}
status = g_ort_ctx.ort_api->SetSessionGraphOptimizationLevel(
g_ort_ctx.session_options, ORT_ENABLE_BASIC);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options);
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
NN_ERR_PRINTF("Failed to set graph optimization level");
return err;
}
status =
g_ort_ctx.ort_api->GetAllocatorWithDefaultOptions(&g_ort_ctx.allocator);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options);
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
NN_ERR_PRINTF("Failed to get default allocator");
return err;
}
for (int i = 0; i < MAX_GRAPHS; i++) {
g_graphs[i].is_initialized = false;
g_graphs[i].session = nullptr;
}
for (int i = 0; i < MAX_CONTEXTS; i++) {
g_exec_ctxs[i].is_initialized = false;
g_exec_ctxs[i].memory_info = nullptr;
g_exec_ctxs[i].graph = nullptr;
g_exec_ctxs[i].input_names.clear();
g_exec_ctxs[i].output_names.clear();
g_exec_ctxs[i].inputs.clear();
g_exec_ctxs[i].outputs.clear();
}
g_ort_ctx.is_initialized = true;
*onnx_ctx = &g_ort_ctx;
NN_INFO_PRINTF("ONNX Runtime backend initialized");
return success;
}
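/* deinit_backend: release every loaded session, all cached input/output
OrtValues and memory infos, then the global session options and environment. */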
__attribute__((visibility("default"))) wasi_nn_error
deinit_backend(void *onnx_ctx)
{
OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ctx->mutex);
if (!ctx->is_initialized) {
return success;
}
for (int i = 0; i < MAX_GRAPHS; i++) {
if (g_graphs[i].is_initialized) {
ctx->ort_api->ReleaseSession(g_graphs[i].session);
g_graphs[i].is_initialized = false;
}
}
for (int i = 0; i < MAX_CONTEXTS; i++) {
if (g_exec_ctxs[i].is_initialized) {
for (auto &input : g_exec_ctxs[i].inputs) {
ctx->ort_api->ReleaseValue(input.second);
}
for (auto &output : g_exec_ctxs[i].outputs) {
ctx->ort_api->ReleaseValue(output.second);
}
ctx->ort_api->ReleaseMemoryInfo(g_exec_ctxs[i].memory_info);
g_exec_ctxs[i].is_initialized = false;
}
}
ctx->ort_api->ReleaseSessionOptions(ctx->session_options);
ctx->ort_api->ReleaseEnv(ctx->env);
ctx->is_initialized = false;
NN_INFO_PRINTF("ONNX Runtime backend deinitialized");
return success;
}
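/* load: build an OrtSession directly from the model bytes in builder->buf[0];
only the onnx encoding and the cpu target are accepted, and at most MAX_GRAPHS
graphs can be live at once. */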
__attribute__((visibility("default"))) wasi_nn_error
load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
execution_target target, graph *g)
{
if (encoding != onnx) {
NN_ERR_PRINTF("Unsupported encoding: %d", encoding);
return invalid_encoding;
}
if (target != cpu) {
NN_ERR_PRINTF("Only CPU target is supported");
return unsupported_operation;
}
OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ctx->mutex);
int graph_index = -1;
for (int i = 0; i < MAX_GRAPHS; i++) {
if (!g_graphs[i].is_initialized) {
graph_index = i;
break;
}
}
if (graph_index == -1) {
NN_ERR_PRINTF("Maximum number of graphs reached");
return runtime_error;
}
if (builder->size == 0 || builder->buf == NULL) {
NN_ERR_PRINTF("No model data provided");
return invalid_argument;
}
NN_INFO_PRINTF("[ONNX Runtime] Loading model of size %zu bytes...",
builder->buf[0].size);
if (builder->buf[0].size > 16) {
NN_INFO_PRINTF(
"Model header bytes: %02x %02x %02x %02x %02x %02x %02x %02x",
((uint8_t *)builder->buf[0].buf)[0],
((uint8_t *)builder->buf[0].buf)[1],
((uint8_t *)builder->buf[0].buf)[2],
((uint8_t *)builder->buf[0].buf)[3],
((uint8_t *)builder->buf[0].buf)[4],
((uint8_t *)builder->buf[0].buf)[5],
((uint8_t *)builder->buf[0].buf)[6],
((uint8_t *)builder->buf[0].buf)[7]);
}
OrtStatus *status = ctx->ort_api->CreateSessionFromArray(
ctx->env, builder->buf[0].buf, builder->buf[0].size,
ctx->session_options, &g_graphs[graph_index].session);
if (status != nullptr) {
/* convert_ort_error_to_wasi_nn_error() logs and releases the status;
reading or releasing it again here would be a use-after-free. */
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to create ONNX Runtime session");
return err;
}
NN_INFO_PRINTF("ONNX Runtime session created successfully");
g_graphs[graph_index].is_initialized = true;
*g = graph_index;
NN_INFO_PRINTF("ONNX model loaded as graph %d", graph_index);
return success;
}
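/* load_by_name: build an OrtSession from a model file on the host filesystem;
name is passed to CreateSession as the model path. */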
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
{
OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ctx->mutex);
int graph_index = -1;
for (int i = 0; i < MAX_GRAPHS; i++) {
if (!g_graphs[i].is_initialized) {
graph_index = i;
break;
}
}
if (graph_index == -1) {
NN_ERR_PRINTF("Maximum number of graphs reached");
return runtime_error;
}
OrtStatus *status = ctx->ort_api->CreateSession(
ctx->env, name, ctx->session_options, &g_graphs[graph_index].session);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s",
name);
return err;
}
g_graphs[graph_index].is_initialized = true;
*g = graph_index;
NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name,
graph_index);
return success;
}
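/* init_execution_context: claim a free context slot, create CPU memory info
and cache the session's input/output names for later Run() calls. */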
__attribute__((visibility("default"))) wasi_nn_error
init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
{
if (g >= MAX_GRAPHS || !g_graphs[g].is_initialized) {
NN_ERR_PRINTF("Invalid graph handle: %d", g);
return invalid_argument;
}
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
int ctx_index = -1;
for (int i = 0; i < MAX_CONTEXTS; i++) {
if (!g_exec_ctxs[i].is_initialized) {
ctx_index = i;
break;
}
}
if (ctx_index == -1) {
NN_ERR_PRINTF("Maximum number of execution contexts reached");
return runtime_error;
}
OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx_index];
exec_ctx->graph = &g_graphs[g];
OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo(
OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to create CPU memory info");
return err;
}
size_t num_input_nodes;
status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session,
&num_input_nodes);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
NN_ERR_PRINTF("Failed to get input count");
return err;
}
for (size_t i = 0; i < num_input_nodes; i++) {
char *input_name;
status = ort_ctx->ort_api->SessionGetInputName(
exec_ctx->graph->session, i, ort_ctx->allocator, &input_name);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
NN_ERR_PRINTF("Failed to get input name");
return err;
}
exec_ctx->input_names.push_back(input_name);
}
size_t num_output_nodes;
status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session,
&num_output_nodes);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
for (const char *name : exec_ctx->input_names) {
ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
}
NN_ERR_PRINTF("Failed to get output count");
return err;
}
for (size_t i = 0; i < num_output_nodes; i++) {
char *output_name;
status = ort_ctx->ort_api->SessionGetOutputName(
exec_ctx->graph->session, i, ort_ctx->allocator, &output_name);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
for (const char *name : exec_ctx->input_names) {
ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
}
NN_ERR_PRINTF("Failed to get output name");
return err;
}
exec_ctx->output_names.push_back(output_name);
}
exec_ctx->is_initialized = true;
*ctx = ctx_index;
NN_INFO_PRINTF("Execution context %d initialized for graph %d", ctx_index,
g);
return success;
}
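/* set_input: wrap the caller's buffer in an OrtValue without copying, using
the wasi-nn tensor type and dimensions, and store it under the given index. */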
__attribute__((visibility("default"))) wasi_nn_error
set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
tensor *input_tensor)
{
if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
return invalid_argument;
}
if (index >= g_exec_ctxs[ctx].input_names.size()) {
NN_ERR_PRINTF("Invalid input index: %d (max: %zu)", index,
g_exec_ctxs[ctx].input_names.size() - 1);
return invalid_argument;
}
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
OrtTypeInfo *type_info = nullptr;
OrtStatus *status = ort_ctx->ort_api->SessionGetInputTypeInfo(
exec_ctx->graph->session, index, &type_info);
if (status != nullptr) {
ort_ctx->ort_api->ReleaseTypeInfo(type_info);
return runtime_error;
}
const OrtTensorTypeAndShapeInfo *tensor_info;
status =
ort_ctx->ort_api->CastTypeInfoToTensorInfo(type_info, &tensor_info);
if (status != nullptr) {
ort_ctx->ort_api->ReleaseTypeInfo(type_info);
return runtime_error;
}
size_t num_model_dims;
status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_model_dims);
std::vector<int64_t> model_dims(num_model_dims);
status = ort_ctx->ort_api->GetDimensions(tensor_info, model_dims.data(),
num_model_dims);
size_t model_tensor_size = 1;
for (size_t i = 0; i < num_model_dims; ++i)
model_tensor_size *= model_dims[i];
size_t input_tensor_size = 1;
for (size_t i = 0; i < input_tensor->dimensions->size; ++i)
input_tensor_size *= input_tensor->dimensions->buf[i];
void *input_tensor_data = input_tensor->data.buf;
void *input_tensor_scaled_data = NULL;
ort_ctx->ort_api->ReleaseTypeInfo(type_info);
size_t num_dims = input_tensor->dimensions->size;
int64_t *ort_dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
if (!ort_dims) {
NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
return runtime_error;
}
for (size_t i = 0; i < num_dims; i++) {
ort_dims[i] = input_tensor->dimensions->buf[i];
}
ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type(
static_cast<tensor_type>(input_tensor->type));
OrtValue *input_value = nullptr;
size_t total_elements = 1;
for (size_t i = 0; i < num_dims; i++) {
total_elements *= input_tensor->dimensions->buf[i];
}
status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
exec_ctx->memory_info, input_tensor->data.buf,
get_tensor_element_size(static_cast<tensor_type>(input_tensor->type))
* total_elements,
ort_dims, num_dims, ort_type, &input_value);
free(ort_dims);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to create input tensor");
return err;
}
if (exec_ctx->inputs.count(index) > 0) {
ort_ctx->ort_api->ReleaseValue(exec_ctx->inputs[index]);
}
exec_ctx->inputs[index] = input_value;
NN_INFO_PRINTF("Input tensor set for context %d, index %d", ctx, index);
return success;
}
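/* compute: require every model input to be set, run the session once, and
keep the resulting output OrtValues for get_output to read. */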
__attribute__((visibility("default"))) wasi_nn_error
compute(void *onnx_ctx, graph_execution_context ctx)
{
if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
return invalid_argument;
}
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
std::vector<OrtValue *> input_values;
std::vector<const char *> input_names;
for (size_t i = 0; i < exec_ctx->input_names.size(); i++) {
if (exec_ctx->inputs.count(i) == 0) {
NN_ERR_PRINTF("Input tensor not set for index %zu", i);
return invalid_argument;
}
input_values.push_back(exec_ctx->inputs[i]);
input_names.push_back(exec_ctx->input_names[i]);
}
for (auto &output : exec_ctx->outputs) {
ort_ctx->ort_api->ReleaseValue(output.second);
}
exec_ctx->outputs.clear();
std::vector<OrtValue *> output_values(exec_ctx->output_names.size());
OrtStatus *status = ort_ctx->ort_api->Run(
exec_ctx->graph->session, nullptr, input_names.data(),
input_values.data(), input_values.size(), exec_ctx->output_names.data(),
exec_ctx->output_names.size(), output_values.data());
for (size_t i = 0; i < output_values.size(); i++) {
exec_ctx->outputs[i] = output_values[i];
}
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to run inference");
return err;
}
NN_INFO_PRINTF("Inference computed for context %d", ctx);
return success;
}
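/* get_output: copy the raw bytes of the requested output OrtValue into the
caller's buffer and report the size; fails with invalid_argument if the
buffer is too small. */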
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
tensor_data *out_buffer, uint32_t *out_buffer_size)
{
if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
return invalid_argument;
}
if (index >= g_exec_ctxs[ctx].output_names.size()) {
NN_ERR_PRINTF("Invalid output index: %d (max: %zu)", index,
g_exec_ctxs[ctx].output_names.size() - 1);
return invalid_argument;
}
OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
std::lock_guard<std::mutex> lock(ort_ctx->mutex);
OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];
OrtValue *output_value = exec_ctx->outputs[index];
if (!output_value) {
NN_ERR_PRINTF("Output tensor not available for index %d", index);
return runtime_error;
}
OrtTensorTypeAndShapeInfo *tensor_info;
OrtStatus *status =
ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to get tensor type and shape");
return err;
}
ONNXTensorElementDataType element_type;
status = ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
NN_ERR_PRINTF("Failed to get tensor element type");
return err;
}
size_t num_dims;
status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
NN_ERR_PRINTF("Failed to get tensor dimensions count");
return err;
}
int64_t *dims = (int64_t *)malloc(num_dims * sizeof(int64_t));
if (!dims) {
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
NN_ERR_PRINTF("Failed to allocate memory for tensor dimensions");
return runtime_error;
}
status = ort_ctx->ort_api->GetDimensions(tensor_info, dims, num_dims);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
free(dims);
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
NN_ERR_PRINTF("Failed to get tensor dimensions");
return err;
}
size_t tensor_size;
status =
ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
free(dims);
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
NN_ERR_PRINTF("Failed to get tensor element count");
return err;
}
NN_INFO_PRINTF("Output tensor dimensions: ");
for (size_t i = 0; i < num_dims; i++) {
NN_INFO_PRINTF(" dim[%zu] = %lld", i, dims[i]);
}
NN_INFO_PRINTF("Total elements: %zu", tensor_size);
ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
free(dims);
if (tensor_size == 0) {
NN_ERR_PRINTF("Tensor is empty (zero elements)");
return runtime_error;
}
void *tensor_data = nullptr;
status = ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data);
if (status != nullptr) {
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
NN_ERR_PRINTF("Failed to get tensor data");
return err;
}
if (tensor_data == nullptr) {
NN_ERR_PRINTF("Tensor data pointer is null");
return runtime_error;
}
size_t element_size;
switch (element_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
element_size = sizeof(float);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
element_size = sizeof(uint16_t);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
element_size = sizeof(double);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
element_size = sizeof(int32_t);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
element_size = sizeof(int64_t);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
element_size = sizeof(uint8_t);
break;
default:
NN_ERR_PRINTF("Unsupported tensor element type: %d", element_type);
return unsupported_operation;
}
size_t output_size_bytes = tensor_size * element_size;
NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, "
"total: %zu bytes",
tensor_size, element_size, output_size_bytes);
if (*out_buffer_size < output_size_bytes) {
NN_ERR_PRINTF(
"Output buffer too small: %u bytes provided, %zu bytes needed",
*out_buffer_size, output_size_bytes);
*out_buffer_size = output_size_bytes;
return invalid_argument;
}
if (tensor_data == nullptr) {
NN_ERR_PRINTF("Tensor data is null");
return runtime_error;
}
if (out_buffer->buf == nullptr) {
NN_ERR_PRINTF("Output buffer is null");
return invalid_argument;
}
memcpy(out_buffer->buf, tensor_data, output_size_bytes);
*out_buffer_size = output_size_bytes;
NN_INFO_PRINTF(
"Output tensor retrieved for context %d, index %d, size %zu bytes", ctx,
index, output_size_bytes);
return success;
}
} /* End of extern "C" */