feat: "formatted files"

This commit is contained in:
Ahmedounet 2022-06-13 17:51:14 +02:00
parent be33067436
commit e1b2e5a60a
5 changed files with 47 additions and 19 deletions

View File

@ -44,14 +44,20 @@ _load(graph_builder_array builder, graph_encoding encoding)
return success; return success;
} }
uint32_t
uint32_t set_input() _set_input(tensor input_tensor)
{ {
auto *input = interpreter->typed_input_tensor<float>(0);
for (int i=0 ; i<input_tensor.size() ; i++ ) for (int i = 0; i < input_tensor.dimensions[0]; i++) {
{ input[i] = (float)input_tensor.data[i];
input[i]= input_tensor[i]; }
}
if (input == nullptr) {
return invalid_argument;
}
else {
return success;
}
} }

View File

@ -12,6 +12,9 @@ extern "C" {
uint32_t uint32_t
_load(graph_builder_array builder, graph_encoding encoding); _load(graph_builder_array builder, graph_encoding encoding);
uint32_t
_set_input(tensor input_tensor);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -46,16 +46,35 @@ void
wasi_nn_init_execution_context() wasi_nn_init_execution_context()
{} {}
void wasi_nn_set_input(wasm_exec_env_t exec_env ,graph_execution_context context, uint32_t index, tensor tensor) uint32_t
wasi_nn_set_input(
wasm_exec_env_t exec_env, graph_execution_context context, uint32_t index,
uint32_t *input_tensor_size, uint32_t input_tensor_type,
uint32_t *input_tensor) // Replaced struct by values inside of
// it as WASMR does not support structs
{ {
printf("Inside wasi_nn_set_input!\n\n"); printf("Inside wasi_nn_set_input!\n\n");
// interpreter->AllocateTensors(); wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
tensor_data data =
(tensor_data)wasm_runtime_addr_app_to_native(instance, input_tensor);
tensor_dimensions dimensions =
(tensor_dimensions)wasm_runtime_addr_app_to_native(instance,
input_tensor_size);
tensor_type type = (tensor_type)wasm_runtime_addr_app_to_native(
instance, input_tensor_type);
tensor tensor_struct = { .dimensions = dimensions,
.type = type,
.data = data };
return _set_input(tensor_struct);
} }
void
void wasi_nn_compute() wasi_nn_compute()
{ {}
void void
wasi_nn_get_output() wasi_nn_get_output()
@ -68,6 +87,7 @@ wasi_nn_get_output()
static NativeSymbol native_symbols_wasi_nn[] = { static NativeSymbol native_symbols_wasi_nn[] = {
REG_NATIVE_FUNC(load, "(ii)i"), REG_NATIVE_FUNC(load, "(ii)i"),
REG_NATIVE_FUNC(set_input, "(ii*i*)i"),
}; };
uint32_t uint32_t

View File

@ -12,12 +12,7 @@ typedef uint32_t buffer_size;
typedef uint32_t graph_execution_context; typedef uint32_t graph_execution_context;
typedef enum { typedef enum { success = 0, invalid_argument, missing_memory, busy } nn_erno;
success = 0,
invalid_argument,
missing_memory,
busy
} nn_erno;
typedef uint32_t *tensor_dimensions; typedef uint32_t *tensor_dimensions;
@ -45,8 +40,10 @@ load(graph_builder_array builder, graph_encoding encoding);
void void
init_execution_context(); init_execution_context();
void uint32_t
set_input(); set_input(graph_execution_context context, uint32_t index,
uint32_t *input_tensor_size, uint32_t input_tensor_type,
uint32_t *input_tensor);
void void
compute(); compute();

View File

@ -30,6 +30,8 @@ generate_float(int iteration, double seed1, float seed2)
load(arr, 1); load(arr, 1);
float ret; float ret;
set_input(0, 0, size, 3, arr);
printf("calling into WASM function: %s\n", __FUNCTION__); printf("calling into WASM function: %s\n", __FUNCTION__);
for (int i = 0; i < iteration; i++) { for (int i = 0; i < iteration; i++) {