mirror of
https://github.com/bytecodealliance/wasm-micro-runtime.git
synced 2025-07-19 18:58:16 +00:00
Compare commits
3 Commits
ffb70b7d17
...
8fd6877aab
Author | SHA1 | Date | |
---|---|---|---|
![]() |
8fd6877aab | ||
![]() |
db942f3aaf | ||
![]() |
2273302ca6 |
|
@ -510,7 +510,8 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
|
|||
# Variant backends
|
||||
if (NOT WAMR_BUILD_WASI_NN_TFLITE EQUAL 1 AND
|
||||
NOT WAMR_BUILD_WASI_NN_OPENVINO EQUAL 1 AND
|
||||
NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1)
|
||||
NOT WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1 AND
|
||||
NOT WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
|
||||
message (FATAL_ERROR " Need to select a backend for WASI-NN")
|
||||
endif ()
|
||||
|
||||
|
@ -526,6 +527,10 @@ if (WAMR_BUILD_WASI_NN EQUAL 1)
|
|||
message (" WASI-NN: backend llamacpp enabled")
|
||||
add_definitions (-DWASM_ENABLE_WASI_NN_LLAMACPP)
|
||||
endif ()
|
||||
if (WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
|
||||
message (" WASI-NN: backend onnx enabled")
|
||||
add_definitions (-DWASM_ENABLE_WASI_NN_ONNX)
|
||||
endif ()
|
||||
# Variant devices
|
||||
if (WAMR_BUILD_WASI_NN_ENABLE_GPU EQUAL 1)
|
||||
message (" WASI-NN: GPU enabled")
|
||||
|
|
|
@ -191,7 +191,7 @@ bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
|
|||
error = __wasi_sock_bind(sockfd, &wasi_addr);
|
||||
HANDLE_ERROR(error)
|
||||
|
||||
return __WASI_ERRNO_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -212,7 +212,7 @@ connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen)
|
|||
error = __wasi_sock_connect(sockfd, &wasi_addr);
|
||||
HANDLE_ERROR(error)
|
||||
|
||||
return __WASI_ERRNO_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -220,7 +220,7 @@ listen(int sockfd, int backlog)
|
|||
{
|
||||
__wasi_errno_t error = __wasi_sock_listen(sockfd, backlog);
|
||||
HANDLE_ERROR(error)
|
||||
return __WASI_ERRNO_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
||||
ssize_t
|
||||
|
@ -375,7 +375,7 @@ socket(int domain, int type, int protocol)
|
|||
af = INET6;
|
||||
}
|
||||
else {
|
||||
return __WASI_ERRNO_NOPROTOOPT;
|
||||
HANDLE_ERROR(__WASI_ERRNO_NOPROTOOPT)
|
||||
}
|
||||
|
||||
if (SOCK_DGRAM == type) {
|
||||
|
@ -385,7 +385,7 @@ socket(int domain, int type, int protocol)
|
|||
socktype = SOCKET_STREAM;
|
||||
}
|
||||
else {
|
||||
return __WASI_ERRNO_NOPROTOOPT;
|
||||
HANDLE_ERROR(__WASI_ERRNO_NOPROTOOPT)
|
||||
}
|
||||
|
||||
error = __wasi_sock_open(poolfd, af, socktype, &sockfd);
|
||||
|
@ -408,7 +408,7 @@ getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
|
|||
error = wasi_addr_to_sockaddr(&wasi_addr, addr, addrlen);
|
||||
HANDLE_ERROR(error)
|
||||
|
||||
return __WASI_ERRNO_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
|
@ -425,7 +425,7 @@ getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
|
|||
error = wasi_addr_to_sockaddr(&wasi_addr, addr, addrlen);
|
||||
HANDLE_ERROR(error)
|
||||
|
||||
return __WASI_ERRNO_SUCCESS;
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct aibuf {
|
||||
|
@ -614,38 +614,38 @@ get_sol_socket_option(int sockfd, int optname, void *__restrict optval,
|
|||
error = __wasi_sock_get_recv_timeout(sockfd, &timeout_us);
|
||||
HANDLE_ERROR(error);
|
||||
*(struct timeval *)optval = time_us_to_timeval(timeout_us);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_SNDTIMEO:
|
||||
assert(*optlen == sizeof(struct timeval));
|
||||
error = __wasi_sock_get_send_timeout(sockfd, &timeout_us);
|
||||
HANDLE_ERROR(error);
|
||||
*(struct timeval *)optval = time_us_to_timeval(timeout_us);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_SNDBUF:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_send_buf_size(sockfd, (size_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_RCVBUF:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_recv_buf_size(sockfd, (size_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_KEEPALIVE:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_keep_alive(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_REUSEADDR:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_reuse_addr(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_REUSEPORT:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_reuse_port(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_LINGER:
|
||||
assert(*optlen == sizeof(struct linger));
|
||||
error =
|
||||
|
@ -653,12 +653,12 @@ get_sol_socket_option(int sockfd, int optname, void *__restrict optval,
|
|||
HANDLE_ERROR(error);
|
||||
((struct linger *)optval)->l_onoff = (int)is_linger_enabled;
|
||||
((struct linger *)optval)->l_linger = linger_s;
|
||||
return error;
|
||||
return 0;
|
||||
case SO_BROADCAST:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_broadcast(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case SO_TYPE:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_fd_fdstat_get(sockfd, &sb);
|
||||
|
@ -678,7 +678,7 @@ get_sol_socket_option(int sockfd, int optname, void *__restrict optval,
|
|||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -692,32 +692,32 @@ get_ipproto_tcp_option(int sockfd, int optname, void *__restrict optval,
|
|||
assert(*optlen == sizeof(uint32_t));
|
||||
error = __wasi_sock_get_tcp_keep_idle(sockfd, (uint32_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_KEEPINTVL:
|
||||
assert(*optlen == sizeof(uint32_t));
|
||||
error = __wasi_sock_get_tcp_keep_intvl(sockfd, (uint32_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_FASTOPEN_CONNECT:
|
||||
assert(*optlen == sizeof(int));
|
||||
error =
|
||||
__wasi_sock_get_tcp_fastopen_connect(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_NODELAY:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_tcp_no_delay(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_QUICKACK:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_tcp_quick_ack(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -733,21 +733,21 @@ get_ipproto_ip_option(int sockfd, int optname, void *__restrict optval,
|
|||
error = __wasi_sock_get_ip_multicast_loop(sockfd, false,
|
||||
(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_TTL:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_ip_ttl(sockfd, (uint8_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_MULTICAST_TTL:
|
||||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_ip_multicast_ttl(sockfd, (uint8_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -762,17 +762,17 @@ get_ipproto_ipv6_option(int sockfd, int optname, void *__restrict optval,
|
|||
assert(*optlen == sizeof(int));
|
||||
error = __wasi_sock_get_ipv6_only(sockfd, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IPV6_MULTICAST_LOOP:
|
||||
assert(*optlen == sizeof(int));
|
||||
error =
|
||||
__wasi_sock_get_ip_multicast_loop(sockfd, true, (bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -794,7 +794,7 @@ getsockopt(int sockfd, int level, int optname, void *__restrict optval,
|
|||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -812,7 +812,7 @@ set_sol_socket_option(int sockfd, int optname, const void *optval,
|
|||
timeout_us = timeval_to_time_us(*(struct timeval *)optval);
|
||||
error = __wasi_sock_set_recv_timeout(sockfd, timeout_us);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_SNDTIMEO:
|
||||
{
|
||||
|
@ -820,42 +820,42 @@ set_sol_socket_option(int sockfd, int optname, const void *optval,
|
|||
timeout_us = timeval_to_time_us(*(struct timeval *)optval);
|
||||
error = __wasi_sock_set_send_timeout(sockfd, timeout_us);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_SNDBUF:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_send_buf_size(sockfd, *(size_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_RCVBUF:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_recv_buf_size(sockfd, *(size_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_KEEPALIVE:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_keep_alive(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_REUSEADDR:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_reuse_addr(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_REUSEPORT:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_reuse_port(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_LINGER:
|
||||
{
|
||||
|
@ -864,20 +864,20 @@ set_sol_socket_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_linger(sockfd, (bool)linger_opt->l_onoff,
|
||||
linger_opt->l_linger);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
case SO_BROADCAST:
|
||||
{
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_broadcast(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
default:
|
||||
{
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -893,32 +893,32 @@ set_ipproto_tcp_option(int sockfd, int optname, const void *optval,
|
|||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_tcp_no_delay(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_KEEPIDLE:
|
||||
assert(optlen == sizeof(uint32_t));
|
||||
error = __wasi_sock_set_tcp_keep_idle(sockfd, *(uint32_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_KEEPINTVL:
|
||||
assert(optlen == sizeof(uint32_t));
|
||||
error = __wasi_sock_set_tcp_keep_intvl(sockfd, *(uint32_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_FASTOPEN_CONNECT:
|
||||
assert(optlen == sizeof(int));
|
||||
error =
|
||||
__wasi_sock_set_tcp_fastopen_connect(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case TCP_QUICKACK:
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_tcp_quick_ack(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -936,7 +936,7 @@ set_ipproto_ip_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_ip_multicast_loop(sockfd, false,
|
||||
*(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_ADD_MEMBERSHIP:
|
||||
assert(optlen == sizeof(struct ip_mreq));
|
||||
ip_mreq_opt = (struct ip_mreq *)optval;
|
||||
|
@ -946,7 +946,7 @@ set_ipproto_ip_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_ip_add_membership(
|
||||
sockfd, &imr_multiaddr, ip_mreq_opt->imr_interface.s_addr);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_DROP_MEMBERSHIP:
|
||||
assert(optlen == sizeof(struct ip_mreq));
|
||||
ip_mreq_opt = (struct ip_mreq *)optval;
|
||||
|
@ -956,22 +956,22 @@ set_ipproto_ip_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_ip_drop_membership(
|
||||
sockfd, &imr_multiaddr, ip_mreq_opt->imr_interface.s_addr);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_TTL:
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_ip_ttl(sockfd, *(uint8_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IP_MULTICAST_TTL:
|
||||
assert(optlen == sizeof(int));
|
||||
error =
|
||||
__wasi_sock_set_ip_multicast_ttl(sockfd, *(uint8_t *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -988,13 +988,13 @@ set_ipproto_ipv6_option(int sockfd, int optname, const void *optval,
|
|||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_ipv6_only(sockfd, *(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IPV6_MULTICAST_LOOP:
|
||||
assert(optlen == sizeof(int));
|
||||
error = __wasi_sock_set_ip_multicast_loop(sockfd, true,
|
||||
*(bool *)optval);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IPV6_JOIN_GROUP:
|
||||
assert(optlen == sizeof(struct ipv6_mreq));
|
||||
ipv6_mreq_opt = (struct ipv6_mreq *)optval;
|
||||
|
@ -1005,7 +1005,7 @@ set_ipproto_ipv6_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_ip_add_membership(
|
||||
sockfd, &imr_multiaddr, ipv6_mreq_opt->ipv6mr_interface);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
case IPV6_LEAVE_GROUP:
|
||||
assert(optlen == sizeof(struct ipv6_mreq));
|
||||
ipv6_mreq_opt = (struct ipv6_mreq *)optval;
|
||||
|
@ -1016,11 +1016,11 @@ set_ipproto_ipv6_option(int sockfd, int optname, const void *optval,
|
|||
error = __wasi_sock_set_ip_drop_membership(
|
||||
sockfd, &imr_multiaddr, ipv6_mreq_opt->ipv6mr_interface);
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1042,6 +1042,6 @@ setsockopt(int sockfd, int level, int optname, const void *optval,
|
|||
default:
|
||||
error = __WASI_ERRNO_NOTSUP;
|
||||
HANDLE_ERROR(error);
|
||||
return error;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ $ cmake -DWAMR_BUILD_WASI_NN=1 <other options> ...
|
|||
- `WAMR_BUILD_WASI_NN_TFLITE`. This option designates TensorFlow Lite as the backend.
|
||||
- `WAMR_BUILD_WASI_NN_OPENVINO`. This option designates OpenVINO as the backend.
|
||||
- `WAMR_BUILD_WASI_NN_LLAMACPP`. This option designates Llama.cpp as the backend.
|
||||
- `WAMR_BUILD_WASI_NN_ONNX`. This option designates ONNX Runtime as the backend.
|
||||
|
||||
### Wasm
|
||||
|
||||
|
@ -151,7 +152,7 @@ docker run \
|
|||
|
||||
Supported:
|
||||
|
||||
- Graph encoding: `tensorflowlite`, `openvino` and `ggml`
|
||||
- Graph encoding: `tensorflowlite`, `openvino`, `ggml` and `onnx`
|
||||
- Execution target: `cpu` for all. `gpu` and `tpu` for `tensorflowlite`.
|
||||
- Tensor type: `fp32`.
|
||||
|
||||
|
|
77
core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake
Normal file
77
core/iwasm/libraries/wasi-nn/cmake/Findonnxruntime.cmake
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Copyright 2025 Sony Semiconductor Solutions Corporation.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
# Find ONNX Runtime library
|
||||
#
|
||||
# This module defines the following variables:
|
||||
#
|
||||
# ::
|
||||
#
|
||||
# onnxruntime_FOUND - True if onnxruntime is found
|
||||
# onnxruntime_INCLUDE_DIRS - Include directories for onnxruntime
|
||||
# onnxruntime_LIBRARIES - List of libraries for onnxruntime
|
||||
# onnxruntime_VERSION - Version of onnxruntime
|
||||
#
|
||||
# ::
|
||||
#
|
||||
# Example usage:
|
||||
#
|
||||
# find_package(onnxruntime)
|
||||
# if(onnxruntime_FOUND)
|
||||
# target_link_libraries(app onnxruntime)
|
||||
# endif()
|
||||
|
||||
# First try to find ONNX Runtime using the CMake config file
|
||||
|
||||
# If not found via CMake config, try to find manually
|
||||
find_path(onnxruntime_INCLUDE_DIR
|
||||
NAMES onnxruntime_c_api.h
|
||||
PATHS
|
||||
/usr/include
|
||||
/usr/local/include
|
||||
/opt/onnxruntime/include
|
||||
$ENV{ONNXRUNTIME_ROOT}/include
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../../../..
|
||||
)
|
||||
|
||||
find_library(onnxruntime_LIBRARY
|
||||
NAMES onnxruntime
|
||||
PATHS
|
||||
/usr/lib
|
||||
/usr/local/lib
|
||||
/opt/onnxruntime/lib
|
||||
$ENV{ONNXRUNTIME_ROOT}/lib
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../../../..
|
||||
)
|
||||
|
||||
# Try to determine version from header file
|
||||
if(onnxruntime_INCLUDE_DIR)
|
||||
file(STRINGS "${onnxruntime_INCLUDE_DIR}/onnxruntime_c_api.h" onnxruntime_version_str
|
||||
REGEX "^#define[\t ]+ORT_API_VERSION[\t ]+[0-9]+")
|
||||
|
||||
if(onnxruntime_version_str)
|
||||
string(REGEX REPLACE "^#define[\t ]+ORT_API_VERSION[\t ]+([0-9]+)" "\\1"
|
||||
onnxruntime_VERSION "${onnxruntime_version_str}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
include(FindPackageHandleStandardArgs)
|
||||
find_package_handle_standard_args(onnxruntime
|
||||
REQUIRED_VARS onnxruntime_LIBRARY onnxruntime_INCLUDE_DIR
|
||||
VERSION_VAR onnxruntime_VERSION
|
||||
)
|
||||
|
||||
if(onnxruntime_FOUND)
|
||||
set(onnxruntime_LIBRARIES ${onnxruntime_LIBRARY})
|
||||
set(onnxruntime_INCLUDE_DIRS ${onnxruntime_INCLUDE_DIR})
|
||||
|
||||
if(NOT TARGET onnxruntime)
|
||||
add_library(onnxruntime UNKNOWN IMPORTED)
|
||||
set_target_properties(onnxruntime PROPERTIES
|
||||
IMPORTED_LOCATION "${onnxruntime_LIBRARY}"
|
||||
INTERFACE_INCLUDE_DIRECTORIES "${onnxruntime_INCLUDE_DIRS}"
|
||||
)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
mark_as_advanced(onnxruntime_INCLUDE_DIR onnxruntime_LIBRARY)
|
|
@ -109,3 +109,31 @@ if(WAMR_BUILD_WASI_NN_LLAMACPP EQUAL 1)
|
|||
|
||||
install(TARGETS wasi_nn_llamacpp DESTINATION lib)
|
||||
endif()
|
||||
|
||||
# - onnx
|
||||
if(WAMR_BUILD_WASI_NN_ONNX EQUAL 1)
|
||||
find_package(onnxruntime REQUIRED)
|
||||
enable_language(CXX)
|
||||
|
||||
add_library(
|
||||
wasi_nn_onnx
|
||||
SHARED
|
||||
${WASI_NN_ROOT}/src/wasi_nn_onnx.cpp
|
||||
)
|
||||
|
||||
target_include_directories(
|
||||
wasi_nn_onnx
|
||||
PUBLIC
|
||||
${onnxruntime_INCLUDE_DIR}/onnx
|
||||
${onnxruntime_INCLUDE_DIR}
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
wasi_nn_onnx
|
||||
PUBLIC
|
||||
vmlib
|
||||
onnxruntime
|
||||
)
|
||||
|
||||
install(TARGETS wasi_nn_onnx DESTINATION lib)
|
||||
endif()
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#else
|
||||
#define WASI_NN_IMPORT(name) \
|
||||
__attribute__((import_module("wasi_nn"), import_name(name)))
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#warning "You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -27,7 +27,7 @@ extern "C" {
|
|||
#define WASI_NN_TYPE_NAME(name) WASI_NN_NAME(type_##name)
|
||||
#define WASI_NN_ENCODING_NAME(name) WASI_NN_NAME(encoding_##name)
|
||||
#define WASI_NN_TARGET_NAME(name) WASI_NN_NAME(target_##name)
|
||||
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error);
|
||||
#define WASI_NN_ERROR_TYPE WASI_NN_NAME(error)
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -21,7 +21,8 @@
|
|||
#include "wasm_export.h"
|
||||
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN == 0
|
||||
#warning You are using "wasi_nn", which is a legacy WAMR-specific ABI. It's deperecated and will likely be removed in future versions of WAMR. Please use "wasi_ephemeral_nn" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)
|
||||
#warning \
|
||||
"You are using \"wasi_nn\", which is a legacy WAMR-specific ABI. It's deprecated and will likely be removed in future versions of WAMR. Please use \"wasi_ephemeral_nn\" instead. (For a WASM module, use the wasi_ephemeral_nn.h header instead. For the runtime configurations, enable WASM_ENABLE_WASI_EPHEMERAL_NN/WAMR_BUILD_WASI_EPHEMERAL_NN.)"
|
||||
#endif
|
||||
|
||||
#define HASHMAP_INITIAL_SIZE 20
|
||||
|
@ -33,6 +34,7 @@
|
|||
#define TFLITE_BACKEND_LIB "libwasi_nn_tflite" LIB_EXTENTION
|
||||
#define OPENVINO_BACKEND_LIB "libwasi_nn_openvino" LIB_EXTENTION
|
||||
#define LLAMACPP_BACKEND_LIB "libwasi_nn_llamacpp" LIB_EXTENTION
|
||||
#define ONNX_BACKEND_LIB "libwasi_nn_onnx" LIB_EXTENTION
|
||||
|
||||
/* Global variables */
|
||||
static korp_mutex wasi_nn_lock;
|
||||
|
@ -240,6 +242,17 @@ choose_a_backend()
|
|||
return openvino;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
NN_WARN_PRINTF("%s", dlerror());
|
||||
#endif
|
||||
|
||||
handle = dlopen(ONNX_BACKEND_LIB, RTLD_LAZY);
|
||||
if (handle) {
|
||||
NN_INFO_PRINTF("Using onnx backend");
|
||||
dlclose(handle);
|
||||
return onnx;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
NN_WARN_PRINTF("%s", dlerror());
|
||||
#endif
|
||||
|
@ -363,6 +376,8 @@ graph_encoding_to_backend_lib_name(graph_encoding encoding)
|
|||
return TFLITE_BACKEND_LIB;
|
||||
case ggml:
|
||||
return LLAMACPP_BACKEND_LIB;
|
||||
case onnx:
|
||||
return ONNX_BACKEND_LIB;
|
||||
default:
|
||||
return NULL;
|
||||
}
|
||||
|
|
828
core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp
Normal file
828
core/iwasm/libraries/wasi-nn/src/wasi_nn_onnx.cpp
Normal file
|
@ -0,0 +1,828 @@
|
|||
/*
|
||||
* Copyright 2025 Sony Semiconductor Solutions Corporation.
|
||||
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
*/
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "bh_platform.h"
|
||||
#include "wasi_nn_backend.h"
|
||||
#include "utils/logger.h"
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
/* Maximum number of graphs and execution contexts */
|
||||
#define MAX_GRAPHS 10
|
||||
#define MAX_CONTEXTS 10
|
||||
|
||||
/* ONNX Runtime context structure */
|
||||
typedef struct {
|
||||
OrtEnv *env;
|
||||
OrtSessionOptions *session_options;
|
||||
OrtAllocator *allocator;
|
||||
const OrtApi *ort_api;
|
||||
std::mutex mutex;
|
||||
bool is_initialized;
|
||||
} OnnxRuntimeContext;
|
||||
|
||||
/* Graph structure */
|
||||
typedef struct {
|
||||
OrtSession *session;
|
||||
bool is_initialized;
|
||||
} OnnxRuntimeGraph;
|
||||
|
||||
/* Execution context structure */
|
||||
typedef struct {
|
||||
OrtMemoryInfo *memory_info;
|
||||
std::vector<const char *> input_names;
|
||||
std::vector<const char *> output_names;
|
||||
std::unordered_map<uint32_t, OrtValue *> inputs;
|
||||
std::unordered_map<uint32_t, OrtValue *> outputs;
|
||||
OnnxRuntimeGraph *graph;
|
||||
bool is_initialized;
|
||||
} OnnxRuntimeExecCtx;
|
||||
|
||||
/* Global variables */
|
||||
static OnnxRuntimeContext g_ort_ctx;
|
||||
static OnnxRuntimeGraph g_graphs[MAX_GRAPHS];
|
||||
static OnnxRuntimeExecCtx g_exec_ctxs[MAX_CONTEXTS];
|
||||
|
||||
/* Helper functions */
|
||||
static void
|
||||
check_status_and_log(OrtStatus *status)
|
||||
{
|
||||
if (status != nullptr) {
|
||||
const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
|
||||
NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
|
||||
g_ort_ctx.ort_api->ReleaseStatus(status);
|
||||
}
|
||||
}
|
||||
|
||||
static wasi_nn_error
|
||||
convert_ort_error_to_wasi_nn_error(OrtStatus *status)
|
||||
{
|
||||
if (status == nullptr) {
|
||||
return success;
|
||||
}
|
||||
|
||||
wasi_nn_error err;
|
||||
OrtErrorCode code = g_ort_ctx.ort_api->GetErrorCode(status);
|
||||
const char *msg = g_ort_ctx.ort_api->GetErrorMessage(status);
|
||||
|
||||
NN_ERR_PRINTF("ONNX Runtime error: %s", msg);
|
||||
|
||||
switch (code) {
|
||||
case ORT_INVALID_ARGUMENT:
|
||||
err = invalid_argument;
|
||||
break;
|
||||
case ORT_RUNTIME_EXCEPTION:
|
||||
err = runtime_error;
|
||||
break;
|
||||
case ORT_NOT_IMPLEMENTED:
|
||||
err = unsupported_operation;
|
||||
break;
|
||||
case ORT_INVALID_PROTOBUF:
|
||||
err = invalid_encoding;
|
||||
break;
|
||||
case ORT_MODEL_LOADED:
|
||||
err = too_large;
|
||||
break;
|
||||
case ORT_INVALID_GRAPH:
|
||||
err = invalid_encoding;
|
||||
break;
|
||||
default:
|
||||
err = runtime_error;
|
||||
break;
|
||||
}
|
||||
|
||||
g_ort_ctx.ort_api->ReleaseStatus(status);
|
||||
return err;
|
||||
}
|
||||
|
||||
static tensor_type
|
||||
convert_ort_type_to_wasi_nn_type(ONNXTensorElementDataType ort_type)
|
||||
{
|
||||
switch (ort_type) {
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
|
||||
return fp32;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
|
||||
return fp16;
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
|
||||
return fp64;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
return u8;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
return i32;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
|
||||
return i64;
|
||||
#else
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
|
||||
return up8;
|
||||
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
|
||||
return ip32;
|
||||
#endif
|
||||
default:
|
||||
NN_WARN_PRINTF("Unsupported ONNX tensor type: %d", ort_type);
|
||||
return fp32; // Default to fp32
|
||||
}
|
||||
}
|
||||
|
||||
static ONNXTensorElementDataType
|
||||
convert_wasi_nn_type_to_ort_type(tensor_type type)
|
||||
{
|
||||
switch (type) {
|
||||
case fp32:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
case fp16:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
case fp64:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
|
||||
case u8:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
case i32:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
case i64:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
|
||||
#else
|
||||
case up8:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
|
||||
case ip32:
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
|
||||
#endif
|
||||
default:
|
||||
NN_WARN_PRINTF("Unsupported wasi-nn tensor type: %d", type);
|
||||
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // Default to float
|
||||
}
|
||||
}
|
||||
|
||||
static size_t
|
||||
get_tensor_element_size(tensor_type type)
|
||||
{
|
||||
switch (type) {
|
||||
case fp32:
|
||||
return 4;
|
||||
case fp16:
|
||||
return 2;
|
||||
#if WASM_ENABLE_WASI_EPHEMERAL_NN != 0
|
||||
case fp64:
|
||||
return 8;
|
||||
case u8:
|
||||
return 1;
|
||||
case i32:
|
||||
return 4;
|
||||
case i64:
|
||||
return 8;
|
||||
#else
|
||||
case up8:
|
||||
return 1;
|
||||
case ip32:
|
||||
return 4;
|
||||
#endif
|
||||
default:
|
||||
NN_WARN_PRINTF("Unsupported tensor type: %d", type);
|
||||
return 4; // Default to 4 bytes (float)
|
||||
}
|
||||
}
|
||||
|
||||
/* Backend API implementation */
|
||||
|
||||
extern "C" {
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
init_backend(void **onnx_ctx)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_ort_ctx.mutex);
|
||||
|
||||
if (g_ort_ctx.is_initialized) {
|
||||
*onnx_ctx = &g_ort_ctx;
|
||||
return success;
|
||||
}
|
||||
|
||||
g_ort_ctx.ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
if (!g_ort_ctx.ort_api) {
|
||||
NN_ERR_PRINTF("Failed to get ONNX Runtime API");
|
||||
return runtime_error;
|
||||
}
|
||||
|
||||
NN_INFO_PRINTF("Creating ONNX Runtime environment...");
|
||||
OrtStatus *status = g_ort_ctx.ort_api->CreateEnv(ORT_LOGGING_LEVEL_VERBOSE,
|
||||
"wasi-nn", &g_ort_ctx.env);
|
||||
if (status != nullptr) {
|
||||
const char *error_message = g_ort_ctx.ort_api->GetErrorMessage(status);
|
||||
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
|
||||
NN_ERR_PRINTF("Failed to create ONNX Runtime environment: %s",
|
||||
error_message);
|
||||
g_ort_ctx.ort_api->ReleaseStatus(status);
|
||||
return err;
|
||||
}
|
||||
NN_INFO_PRINTF("ONNX Runtime environment created successfully");
|
||||
|
||||
status =
|
||||
g_ort_ctx.ort_api->CreateSessionOptions(&g_ort_ctx.session_options);
|
||||
if (status != nullptr) {
|
||||
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
|
||||
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
|
||||
NN_ERR_PRINTF("Failed to create ONNX Runtime session options");
|
||||
return err;
|
||||
}
|
||||
|
||||
status = g_ort_ctx.ort_api->SetSessionGraphOptimizationLevel(
|
||||
g_ort_ctx.session_options, ORT_ENABLE_BASIC);
|
||||
if (status != nullptr) {
|
||||
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
|
||||
g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options);
|
||||
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
|
||||
NN_ERR_PRINTF("Failed to set graph optimization level");
|
||||
return err;
|
||||
}
|
||||
|
||||
status =
|
||||
g_ort_ctx.ort_api->GetAllocatorWithDefaultOptions(&g_ort_ctx.allocator);
|
||||
if (status != nullptr) {
|
||||
wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
|
||||
g_ort_ctx.ort_api->ReleaseSessionOptions(g_ort_ctx.session_options);
|
||||
g_ort_ctx.ort_api->ReleaseEnv(g_ort_ctx.env);
|
||||
NN_ERR_PRINTF("Failed to get default allocator");
|
||||
return err;
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_GRAPHS; i++) {
|
||||
g_graphs[i].is_initialized = false;
|
||||
g_graphs[i].session = nullptr;
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_CONTEXTS; i++) {
|
||||
g_exec_ctxs[i].is_initialized = false;
|
||||
g_exec_ctxs[i].memory_info = nullptr;
|
||||
g_exec_ctxs[i].graph = nullptr;
|
||||
g_exec_ctxs[i].input_names.clear();
|
||||
g_exec_ctxs[i].output_names.clear();
|
||||
g_exec_ctxs[i].inputs.clear();
|
||||
g_exec_ctxs[i].outputs.clear();
|
||||
}
|
||||
|
||||
g_ort_ctx.is_initialized = true;
|
||||
*onnx_ctx = &g_ort_ctx;
|
||||
|
||||
NN_INFO_PRINTF("ONNX Runtime backend initialized");
|
||||
return success;
|
||||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
|
||||
deinit_backend(void *onnx_ctx)
|
||||
{
|
||||
OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
|
||||
std::lock_guard<std::mutex> lock(ctx->mutex);
|
||||
|
||||
if (!ctx->is_initialized) {
|
||||
return success;
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_GRAPHS; i++) {
|
||||
if (g_graphs[i].is_initialized) {
|
||||
ctx->ort_api->ReleaseSession(g_graphs[i].session);
|
||||
g_graphs[i].is_initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_CONTEXTS; i++) {
|
||||
if (g_exec_ctxs[i].is_initialized) {
|
||||
for (auto &input : g_exec_ctxs[i].inputs) {
|
||||
ctx->ort_api->ReleaseValue(input.second);
|
||||
}
|
||||
for (auto &output : g_exec_ctxs[i].outputs) {
|
||||
ctx->ort_api->ReleaseValue(output.second);
|
||||
}
|
||||
ctx->ort_api->ReleaseMemoryInfo(g_exec_ctxs[i].memory_info);
|
||||
g_exec_ctxs[i].is_initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
ctx->ort_api->ReleaseSessionOptions(ctx->session_options);
|
||||
ctx->ort_api->ReleaseEnv(ctx->env);
|
||||
ctx->is_initialized = false;
|
||||
|
||||
NN_INFO_PRINTF("ONNX Runtime backend deinitialized");
|
||||
return success;
|
||||
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
load(void *onnx_ctx, graph_builder_array *builder, graph_encoding encoding,
     execution_target target, graph *g)
{
    /* Create an ORT session from an in-memory ONNX model and hand the
     * caller a graph handle. Only the ONNX encoding on the CPU device
     * is supported by this backend. */
    if (encoding != onnx) {
        NN_ERR_PRINTF("Unsupported encoding: %d", encoding);
        return invalid_encoding;
    }
    if (target != cpu) {
        NN_ERR_PRINTF("Only CPU target is supported");
        return unsupported_operation;
    }

    OnnxRuntimeContext *rt = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> guard(rt->mutex);

    /* Claim the first free graph slot. */
    int slot = -1;
    for (int idx = 0; idx < MAX_GRAPHS; idx++) {
        if (!g_graphs[idx].is_initialized) {
            slot = idx;
            break;
        }
    }
    if (slot < 0) {
        NN_ERR_PRINTF("Maximum number of graphs reached");
        return runtime_error;
    }

    if (builder->size == 0 || builder->buf == NULL) {
        NN_ERR_PRINTF("No model data provided");
        return invalid_argument;
    }

    NN_INFO_PRINTF("[ONNX Runtime] Loading model of size %zu bytes...",
                   builder->buf[0].size);

    /* Dump the first 8 bytes of the model for diagnostics. */
    if (builder->buf[0].size > 16) {
        const uint8_t *hdr = (const uint8_t *)builder->buf[0].buf;
        NN_INFO_PRINTF(
            "Model header bytes: %02x %02x %02x %02x %02x %02x %02x %02x",
            hdr[0], hdr[1], hdr[2], hdr[3], hdr[4], hdr[5], hdr[6], hdr[7]);
    }

    OrtStatus *st = rt->ort_api->CreateSessionFromArray(
        rt->env, builder->buf[0].buf, builder->buf[0].size,
        rt->session_options, &g_graphs[slot].session);
    if (st != nullptr) {
        const char *msg = rt->ort_api->GetErrorMessage(st);
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(st);
        NN_ERR_PRINTF("Failed to create ONNX Runtime session: %s", msg);
        rt->ort_api->ReleaseStatus(st);
        return err;
    }

    NN_INFO_PRINTF("ONNX Runtime session created successfully");

    g_graphs[slot].is_initialized = true;
    *g = slot;

    NN_INFO_PRINTF("ONNX model loaded as graph %d", slot);
    return success;
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
load_by_name(void *onnx_ctx, const char *name, uint32_t filename_len, graph *g)
{
    /* Load an ONNX model directly from a file path.
     * filename_len is unused: ORT's CreateSession takes a
     * NUL-terminated path. */
    (void)filename_len;

    OnnxRuntimeContext *ctx = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> lock(ctx->mutex);

    /* Claim the first free graph slot. */
    int graph_index = -1;
    for (int i = 0; i < MAX_GRAPHS; i++) {
        if (!g_graphs[i].is_initialized) {
            graph_index = i;
            break;
        }
    }

    if (graph_index == -1) {
        NN_ERR_PRINTF("Maximum number of graphs reached");
        return runtime_error;
    }

    OrtStatus *status = ctx->ort_api->CreateSession(
        ctx->env, name, ctx->session_options, &g_graphs[graph_index].session);

    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        NN_ERR_PRINTF("Failed to create ONNX Runtime session from file: %s",
                      name);
        /* The OrtStatus is owned by the caller and was previously
         * leaked on this path; release it like load() does. */
        ctx->ort_api->ReleaseStatus(status);
        return err;
    }

    g_graphs[graph_index].is_initialized = true;
    *g = graph_index;

    NN_INFO_PRINTF("ONNX model loaded from file %s as graph %d", name,
                   graph_index);
    return success;
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
init_execution_context(void *onnx_ctx, graph g, graph_execution_context *ctx)
{
    /* Bind a free execution-context slot to graph g and cache the
     * model's input/output tensor names for set_input()/compute(). */
    if (g >= MAX_GRAPHS || !g_graphs[g].is_initialized) {
        NN_ERR_PRINTF("Invalid graph handle: %d", g);
        return invalid_argument;
    }

    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> lock(ort_ctx->mutex);

    /* Claim the first free execution-context slot. */
    int ctx_index = -1;
    for (int i = 0; i < MAX_CONTEXTS; i++) {
        if (!g_exec_ctxs[i].is_initialized) {
            ctx_index = i;
            break;
        }
    }

    if (ctx_index == -1) {
        NN_ERR_PRINTF("Maximum number of execution contexts reached");
        return runtime_error;
    }

    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx_index];
    exec_ctx->graph = &g_graphs[g];
    /* Slots are reusable globals: start from empty name lists so an
     * earlier failed initialization cannot leave stale pointers. */
    exec_ctx->input_names.clear();
    exec_ctx->output_names.clear();

    OrtStatus *status = ort_ctx->ort_api->CreateCpuMemoryInfo(
        OrtArenaAllocator, OrtMemTypeDefault, &exec_ctx->memory_info);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        NN_ERR_PRINTF("Failed to create CPU memory info");
        return err;
    }

    /* Used by every error path below: return all name strings to the
     * ORT allocator, clear the vectors (the original code freed the
     * inputs but left dangling pointers in the vectors, and never
     * freed partially-fetched output names), and drop the memory info
     * so the slot stays cleanly reusable. */
    auto cleanup = [&]() {
        for (const char *name : exec_ctx->input_names) {
            ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
        }
        exec_ctx->input_names.clear();
        for (const char *name : exec_ctx->output_names) {
            ort_ctx->allocator->Free(ort_ctx->allocator, (void *)name);
        }
        exec_ctx->output_names.clear();
        ort_ctx->ort_api->ReleaseMemoryInfo(exec_ctx->memory_info);
    };

    size_t num_input_nodes;
    status = ort_ctx->ort_api->SessionGetInputCount(exec_ctx->graph->session,
                                                    &num_input_nodes);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        cleanup();
        NN_ERR_PRINTF("Failed to get input count");
        return err;
    }

    for (size_t i = 0; i < num_input_nodes; i++) {
        char *input_name;
        status = ort_ctx->ort_api->SessionGetInputName(
            exec_ctx->graph->session, i, ort_ctx->allocator, &input_name);
        if (status != nullptr) {
            wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
            ort_ctx->ort_api->ReleaseStatus(status);
            cleanup();
            NN_ERR_PRINTF("Failed to get input name");
            return err;
        }
        exec_ctx->input_names.push_back(input_name);
    }

    size_t num_output_nodes;
    status = ort_ctx->ort_api->SessionGetOutputCount(exec_ctx->graph->session,
                                                     &num_output_nodes);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        cleanup();
        NN_ERR_PRINTF("Failed to get output count");
        return err;
    }

    for (size_t i = 0; i < num_output_nodes; i++) {
        char *output_name;
        status = ort_ctx->ort_api->SessionGetOutputName(
            exec_ctx->graph->session, i, ort_ctx->allocator, &output_name);
        if (status != nullptr) {
            wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
            ort_ctx->ort_api->ReleaseStatus(status);
            cleanup();
            NN_ERR_PRINTF("Failed to get output name");
            return err;
        }
        exec_ctx->output_names.push_back(output_name);
    }

    exec_ctx->is_initialized = true;
    *ctx = ctx_index;

    NN_INFO_PRINTF("Execution context %d initialized for graph %d", ctx_index,
                   g);
    return success;
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
set_input(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
          tensor *input_tensor)
{
    /* Wrap the caller's tensor buffer in an OrtValue and attach it to
     * the execution context at the given input index.  The OrtValue
     * aliases the caller's buffer; no copy is made. */
    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
        return invalid_argument;
    }

    if (index >= g_exec_ctxs[ctx].input_names.size()) {
        NN_ERR_PRINTF("Invalid input index: %d (max: %zu)", index,
                      g_exec_ctxs[ctx].input_names.size() - 1);
        return invalid_argument;
    }

    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];

    /* NOTE: the original code also queried the model's declared input
     * shape (SessionGetInputTypeInfo/GetDimensions) but never used the
     * result, with unchecked statuses; that dead code is removed here.
     * Copy the caller-provided dimensions into ORT's int64 layout and
     * compute the element count in one pass; a vector removes the
     * malloc-failure path of the original. */
    size_t num_dims = input_tensor->dimensions->size;
    std::vector<int64_t> ort_dims(num_dims);
    size_t total_elements = 1;
    for (size_t i = 0; i < num_dims; i++) {
        ort_dims[i] = input_tensor->dimensions->buf[i];
        total_elements *= input_tensor->dimensions->buf[i];
    }

    ONNXTensorElementDataType ort_type = convert_wasi_nn_type_to_ort_type(
        static_cast<tensor_type>(input_tensor->type));

    OrtValue *input_value = nullptr;
    OrtStatus *status = ort_ctx->ort_api->CreateTensorWithDataAsOrtValue(
        exec_ctx->memory_info, input_tensor->data.buf,
        get_tensor_element_size(static_cast<tensor_type>(input_tensor->type))
            * total_elements,
        ort_dims.data(), num_dims, ort_type, &input_value);

    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        /* Release the status: it was leaked on this path before. */
        ort_ctx->ort_api->ReleaseStatus(status);
        NN_ERR_PRINTF("Failed to create input tensor");
        return err;
    }

    /* Replace any previously-set tensor for this index. */
    if (exec_ctx->inputs.count(index) > 0) {
        ort_ctx->ort_api->ReleaseValue(exec_ctx->inputs[index]);
    }
    exec_ctx->inputs[index] = input_value;

    NN_INFO_PRINTF("Input tensor set for context %d, index %d", ctx, index);
    return success;
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
compute(void *onnx_ctx, graph_execution_context ctx)
{
    /* Run inference on the session bound to this execution context,
     * consuming the tensors registered via set_input() and storing the
     * produced outputs for later get_output() calls. */
    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
        return invalid_argument;
    }

    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];

    /* Every declared model input must have been set beforehand. */
    std::vector<OrtValue *> input_values;
    std::vector<const char *> input_names;
    for (size_t i = 0; i < exec_ctx->input_names.size(); i++) {
        if (exec_ctx->inputs.count(i) == 0) {
            NN_ERR_PRINTF("Input tensor not set for index %zu", i);
            return invalid_argument;
        }
        input_values.push_back(exec_ctx->inputs[i]);
        input_names.push_back(exec_ctx->input_names[i]);
    }

    /* Drop outputs from a previous run before producing new ones. */
    for (auto &output : exec_ctx->outputs) {
        ort_ctx->ort_api->ReleaseValue(output.second);
    }
    exec_ctx->outputs.clear();

    std::vector<OrtValue *> output_values(exec_ctx->output_names.size());

    OrtStatus *status = ort_ctx->ort_api->Run(
        exec_ctx->graph->session, nullptr, input_names.data(),
        input_values.data(), input_values.size(), exec_ctx->output_names.data(),
        exec_ctx->output_names.size(), output_values.data());

    if (status != nullptr) {
        /* Check the status BEFORE publishing output_values: on failure
         * they may be null or partially filled (the original stored
         * them unconditionally).  Also release the status (was leaked). */
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        NN_ERR_PRINTF("Failed to run inference");
        return err;
    }

    for (size_t i = 0; i < output_values.size(); i++) {
        exec_ctx->outputs[i] = output_values[i];
    }

    NN_INFO_PRINTF("Inference computed for context %d", ctx);
    return success;
}
|
||||
|
||||
__attribute__((visibility("default"))) wasi_nn_error
get_output(void *onnx_ctx, graph_execution_context ctx, uint32_t index,
           tensor_data *out_buffer, uint32_t *out_buffer_size)
{
    /* Copy the inference result at `index` into the caller's buffer.
     * On entry *out_buffer_size is the buffer capacity in bytes; on
     * success (and on "buffer too small") it is set to the tensor's
     * size in bytes. */
    if (ctx >= MAX_CONTEXTS || !g_exec_ctxs[ctx].is_initialized) {
        NN_ERR_PRINTF("Invalid execution context handle: %d", ctx);
        return invalid_argument;
    }

    if (index >= g_exec_ctxs[ctx].output_names.size()) {
        NN_ERR_PRINTF("Invalid output index: %d (max: %zu)", index,
                      g_exec_ctxs[ctx].output_names.size() - 1);
        return invalid_argument;
    }

    OnnxRuntimeContext *ort_ctx = (OnnxRuntimeContext *)onnx_ctx;
    std::lock_guard<std::mutex> lock(ort_ctx->mutex);
    OnnxRuntimeExecCtx *exec_ctx = &g_exec_ctxs[ctx];

    /* find() rather than operator[], which would silently insert a
     * null entry into the map when compute() has not produced this
     * output yet. */
    auto it = exec_ctx->outputs.find(index);
    if (it == exec_ctx->outputs.end() || it->second == nullptr) {
        NN_ERR_PRINTF("Output tensor not available for index %d", index);
        return runtime_error;
    }
    OrtValue *output_value = it->second;

    OrtTensorTypeAndShapeInfo *tensor_info = nullptr;
    OrtStatus *status =
        ort_ctx->ort_api->GetTensorTypeAndShape(output_value, &tensor_info);
    if (status != nullptr) {
        /* Every error path now releases the OrtStatus (all were leaked
         * in the original). */
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        NN_ERR_PRINTF("Failed to get tensor type and shape");
        return err;
    }

    ONNXTensorElementDataType element_type;
    status = ort_ctx->ort_api->GetTensorElementType(tensor_info, &element_type);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
        NN_ERR_PRINTF("Failed to get tensor element type");
        return err;
    }

    size_t num_dims;
    status = ort_ctx->ort_api->GetDimensionsCount(tensor_info, &num_dims);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
        NN_ERR_PRINTF("Failed to get tensor dimensions count");
        return err;
    }

    /* std::vector instead of malloc/free removes the allocation-failure
     * path and two manual cleanup branches of the original. */
    std::vector<int64_t> dims(num_dims);
    status = ort_ctx->ort_api->GetDimensions(tensor_info, dims.data(),
                                             num_dims);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
        NN_ERR_PRINTF("Failed to get tensor dimensions");
        return err;
    }

    size_t tensor_size;
    status =
        ort_ctx->ort_api->GetTensorShapeElementCount(tensor_info, &tensor_size);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
        NN_ERR_PRINTF("Failed to get tensor element count");
        return err;
    }

    NN_INFO_PRINTF("Output tensor dimensions: ");
    for (size_t i = 0; i < num_dims; i++) {
        /* Cast for %lld: int64_t need not be long long on all ABIs. */
        NN_INFO_PRINTF(" dim[%zu] = %lld", i, (long long)dims[i]);
    }
    NN_INFO_PRINTF("Total elements: %zu", tensor_size);

    ort_ctx->ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);

    if (tensor_size == 0) {
        NN_ERR_PRINTF("Tensor is empty (zero elements)");
        return runtime_error;
    }

    void *tensor_data = nullptr;
    status = ort_ctx->ort_api->GetTensorMutableData(output_value, &tensor_data);
    if (status != nullptr) {
        wasi_nn_error err = convert_ort_error_to_wasi_nn_error(status);
        ort_ctx->ort_api->ReleaseStatus(status);
        NN_ERR_PRINTF("Failed to get tensor data");
        return err;
    }

    /* Single null check (the original checked tensor_data twice). */
    if (tensor_data == nullptr) {
        NN_ERR_PRINTF("Tensor data pointer is null");
        return runtime_error;
    }

    size_t element_size;
    switch (element_type) {
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
            element_size = sizeof(float);
            break;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
            element_size = sizeof(uint16_t);
            break;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
            element_size = sizeof(double);
            break;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
            element_size = sizeof(int32_t);
            break;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
            element_size = sizeof(int64_t);
            break;
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
            element_size = sizeof(uint8_t);
            break;
        default:
            NN_ERR_PRINTF("Unsupported tensor element type: %d", element_type);
            return unsupported_operation;
    }

    size_t output_size_bytes = tensor_size * element_size;

    NN_INFO_PRINTF("Output tensor size: %zu elements, element size: %zu bytes, "
                   "total: %zu bytes",
                   tensor_size, element_size, output_size_bytes);

    if (*out_buffer_size < output_size_bytes) {
        NN_ERR_PRINTF(
            "Output buffer too small: %u bytes provided, %zu bytes needed",
            *out_buffer_size, output_size_bytes);
        /* Report the required size so the caller can retry. */
        *out_buffer_size = output_size_bytes;
        return invalid_argument;
    }

    if (out_buffer->buf == nullptr) {
        NN_ERR_PRINTF("Output buffer is null");
        return invalid_argument;
    }

    memcpy(out_buffer->buf, tensor_data, output_size_bytes);
    *out_buffer_size = output_size_bytes;

    NN_INFO_PRINTF(
        "Output tensor retrieved for context %d, index %d, size %zu bytes", ctx,
        index, output_size_bytes);
    return success;
}
|
||||
|
||||
} /* End of extern "C" */
|
Loading…
Reference in New Issue
Block a user