diff --git a/core/iwasm/libraries/wasi-nn/lib_run_inference.cpp b/core/iwasm/libraries/wasi-nn/lib_run_inference.cpp
index b9f13ecef..a88f184ee 100644
--- a/core/iwasm/libraries/wasi-nn/lib_run_inference.cpp
+++ b/core/iwasm/libraries/wasi-nn/lib_run_inference.cpp
@@ -44,14 +44,20 @@ _load(graph_builder_array builder, graph_encoding encoding)
     return success;
 }
 
-
-uint32_t set_input()
+uint32_t
+_set_input(tensor input_tensor)
 {
+    auto *input = interpreter->typed_input_tensor<float>(0);
-    for (int i=0 ; i<
+    interpreter->AllocateTensors();
+    wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
+    tensor_data data =
+        (tensor_data)wasm_runtime_addr_app_to_native(instance, input_tensor);
+    tensor_dimensions dimensions =
+        (tensor_dimensions)wasm_runtime_addr_app_to_native(instance,
+                                                           input_tensor_size);
+
+    tensor_type type = (tensor_type)wasm_runtime_addr_app_to_native(
+        instance, input_tensor_type);
+
+    tensor tensor_struct = { .dimensions = dimensions,
+                             .type = type,
+                             .data = data };
+
+    return _set_input(tensor_struct);
 }
-
-void wasi_nn_compute()
-{
+void
+wasi_nn_compute()
+{}
 
 void
 wasi_nn_get_output()
@@ -68,6 +87,7 @@ wasi_nn_get_output()
 
 static NativeSymbol native_symbols_wasi_nn[] = {
     REG_NATIVE_FUNC(load, "(ii)i"),
+    REG_NATIVE_FUNC(set_input, "(ii*i*)i"),
 };
 
 uint32_t
diff --git a/core/iwasm/libraries/wasi-nn/wasi_nn.h b/core/iwasm/libraries/wasi-nn/wasi_nn.h
index 24a83d8fc..e1c4fd3d4 100644
--- a/core/iwasm/libraries/wasi-nn/wasi_nn.h
+++ b/core/iwasm/libraries/wasi-nn/wasi_nn.h
@@ -12,12 +12,7 @@
 typedef uint32_t buffer_size;
 
 typedef uint32_t graph_execution_context;
 
-typedef enum {
-    success = 0,
-    invalid_argument,
-    missing_memory,
-    busy
-} nn_erno;
+typedef enum { success = 0, invalid_argument, missing_memory, busy } nn_erno;
 
 typedef uint32_t *tensor_dimensions;
@@ -45,8 +40,10 @@ load(graph_builder_array builder, graph_encoding encoding);
 void
 init_execution_context();
 
-void
-set_input();
+uint32_t
+set_input(graph_execution_context context, uint32_t index,
+          uint32_t *input_tensor_size, uint32_t input_tensor_type,
+          uint32_t *input_tensor);
 
 void
 compute();
diff --git a/samples/basic/wasm-apps/testapp.c b/samples/basic/wasm-apps/testapp.c
index 4e264bdad..50747abc1 100644
--- a/samples/basic/wasm-apps/testapp.c
+++ b/samples/basic/wasm-apps/testapp.c
@@ -30,6 +30,8 @@ generate_float(int iteration, double seed1, float seed2)
     load(arr, 1);
 
     float ret;
+    set_input(0, 0, size, 3, arr);
+
     printf("calling into WASM function: %s\n", __FUNCTION__);
 
     for (int i = 0; i < iteration; i++) {
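
The native symbol added in the first hunk is registered with the signature string "(ii*i*)i". In WAMR's notation each 'i' is a 32-bit integer argument, each '*' is an application address that the runtime bounds-checks and translates to a native pointer before the call, and the trailing 'i' is the i32 return value; REG_NATIVE_FUNC presumably maps the exported name set_input onto the wasi_nn_-prefixed host function, matching wasi_nn_compute and wasi_nn_get_output above. Since the wrapper itself is only partially visible in the hunk, below is a minimal sketch of how such a wrapper could be wired, not the code from the patch: the parameter names follow the new declaration in wasi_nn.h, the tensor/tensor_data/tensor_dimensions/tensor_type typedefs are assumed to come from that header, and because '*' arguments already arrive as native pointers the sketch packs them straight into the tensor struct.

/*
 * Sketch only -- not the code from the patch. Assumes the tensor typedefs
 * from wasi_nn.h and the _set_input() helper added in the hunk above.
 */
#include <stdint.h>
#include "wasm_export.h" /* wasm_exec_env_t, NativeSymbol */
#include "wasi_nn.h"

uint32_t _set_input(tensor input_tensor); /* backend helper from the hunk */

static uint32_t
wasi_nn_set_input(wasm_exec_env_t exec_env,
                  graph_execution_context context, /* 'i' */
                  uint32_t index,                  /* 'i' */
                  uint32_t *input_tensor_size,     /* '*': dims, already native */
                  uint32_t input_tensor_type,      /* 'i' */
                  uint32_t *input_tensor)          /* '*': data, already native */
{
    /* context and index are accepted to match the registered signature but
       are not needed by this single-graph sketch */
    (void)exec_env;
    (void)context;
    (void)index;

    tensor tensor_struct = { .dimensions = (tensor_dimensions)input_tensor_size,
                             .type = (tensor_type)input_tensor_type,
                             .data = (tensor_data)input_tensor };
    return _set_input(tensor_struct);
}

The hunk above calls wasm_runtime_addr_app_to_native() explicitly; that is the route to take when an argument is registered as a plain 'i' offset rather than '*', and either way the host only ever dereferences a translated native pointer.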
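
On the guest side, the updated declaration in wasi_nn.h takes the execution context, the input index, a pointer to the dimensions array, the tensor type value and a pointer to the data, which is what the one-line change in testapp.c exercises with set_input(0, 0, size, 3, arr). A hypothetical stand-alone guest snippet along the same lines, with made-up shape and data (the dims, the buffer and the literal type value 3 are illustrative; the actual tensor_type mapping lives in wasi_nn.h):

#include <stdint.h>
#include "wasi_nn.h"

/* Hypothetical guest-side helper: feeds a small dummy buffer to input 0 of
 * execution context 0, mirroring the set_input(0, 0, size, 3, arr) call
 * added to testapp.c above. */
uint32_t
feed_dummy_input(void)
{
    uint32_t dims[] = { 1, 10 };   /* tensor shape: 1 x 10 */
    static float data[10] = { 0 }; /* dummy input values   */

    /* context 0, input index 0, dims, type value 3, data buffer.
     * The header types the data parameter as uint32_t *, hence the cast. */
    return set_input(0, 0, dims, 3, (uint32_t *)data);
}

When building such an app, set_input has to be left as an undefined symbol, e.g. via -Wl,--allow-undefined, so that it resolves to the registered native at instantiation time.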