feat: ouptut and compute functions

This commit is contained in:
Ahmedounet 2022-06-21 10:37:02 +02:00
parent 0bbbb15ca7
commit de7bfdda53
3 changed files with 63 additions and 7 deletions

View File

@ -72,13 +72,45 @@ wasi_nn_set_input(
return _set_input(tensor_struct); return _set_input(tensor_struct);
} }
void uint32_t
wasi_nn_compute() wasi_nn_get_output(wasm_exec_env_t exec_env, graph_execution_context context, uint32_t index, uint32_t *input_tensor_size, uint32_t input_tensor_type,
{} uint32_t *input_tensor)
{
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
uint32_t w_index = wasm_runtime_addr_app_to_native(instance, index);
graph_execution_context graph_context = wasm_runtime_addr_app_to_native(instance, context);
tensor_data data =
(tensor_data)wasm_runtime_addr_app_to_native(instance, input_tensor);
tensor_dimensions dimensions =
(tensor_dimensions)wasm_runtime_addr_app_to_native(instance,
input_tensor_size);
tensor_type type = (tensor_type)wasm_runtime_addr_app_to_native(
instance, input_tensor_type);
tensor tensor_struct = { .dimensions = dimensions,
.type = type,
.data = data };
return _get_output(graph_context, w_index ,tensor_struct);
}
uint32_t
wasi_nn_compute(wasm_exec_env_t exec_env , graph_execution_context context)
{
wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env);
graph_execution_context context = wasm_runtime_addr_app_to_native(instance, context);
return _compute(_compute(context));
}
void
wasi_nn_get_output()
{}
/* clang-format off */ /* clang-format off */
#define REG_NATIVE_FUNC(func_name, signature) \ #define REG_NATIVE_FUNC(func_name, signature) \

View File

@ -59,4 +59,22 @@ _set_input(tensor input_tensor)
return success; return success;
} }
uint32_t _compute( graph_execution_context context ){
interpreter->Invoke();
return success;
}
uint32_t _get_output(graph_execution_context context, uint32_t index, tensor tensor){
auto* output = interpreter->typed_output_tensor<float>(0);
return success;
}

View File

@ -15,6 +15,12 @@ _load(graph_builder_array builder, graph_encoding encoding);
uint32_t uint32_t
_set_input(tensor input_tensor); _set_input(tensor input_tensor);
uint32_t _get_output(graph_execution_context context, uint32_t index, tensor tensor);
uint32_t
_compute( graph_execution_context context );
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif