diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c index fbd6e33f4..7921ec953 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn.c @@ -507,10 +507,35 @@ fail: return res; } +static wasi_nn_error +copyin_and_nul_terminate(wasm_module_inst_t inst, char *name, uint32_t name_len, + char **resultp) +{ + char *nul_terminated_name; + if (!wasm_runtime_validate_native_addr(inst, name, name_len)) { + return invalid_argument; + } + nul_terminated_name = wasm_runtime_malloc(name_len + 1); + if (nul_terminated_name == NULL) { + return runtime_error; + } + bh_memcpy_s(nul_terminated_name, name_len + 1, name, name_len); + nul_terminated_name[name_len] = '\0'; /* ensure NUL termination */ + if (strlen(nul_terminated_name) != name_len) { + /* reject names containing '\0' for now */ + wasm_runtime_free(nul_terminated_name); + return invalid_argument; + } + *resultp = nul_terminated_name; + return success; +} + wasi_nn_error wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, graph *g) { + WASINNContext *wasi_nn_ctx = NULL; + char *nul_terminated_name = NULL; wasi_nn_error res; wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); @@ -518,25 +543,21 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, return runtime_error; } - if (!wasm_runtime_validate_native_addr(instance, name, name_len)) { - NN_ERR_PRINTF("name is invalid"); - return invalid_argument; - } - if (!wasm_runtime_validate_native_addr(instance, g, (uint64)sizeof(graph))) { NN_ERR_PRINTF("graph is invalid"); return invalid_argument; } - if (name_len == 0 || name[name_len] != '\0') { - NN_ERR_PRINTF("Invalid filename"); - return invalid_argument; + res = copyin_and_nul_terminate(instance, name, name_len, + &nul_terminated_name); + if (res != success) { + goto fail; } - NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", name); + NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME %s...", nul_terminated_name); - WASINNContext *wasi_nn_ctx = lock_ctx(instance); + wasi_nn_ctx = lock_ctx(instance); if (wasi_nn_ctx == NULL) { res = busy; goto fail; @@ -547,14 +568,20 @@ wasi_nn_load_by_name(wasm_exec_env_t exec_env, char *name, uint32_t name_len, goto fail; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name, res, - wasi_nn_ctx->backend_ctx, name, name_len, g); + wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, + g); if (res != success) goto fail; wasi_nn_ctx->is_model_loaded = true; res = success; fail: - unlock_ctx(wasi_nn_ctx); + if (nul_terminated_name != NULL) { + wasm_runtime_free(nul_terminated_name); + } + if (wasi_nn_ctx != NULL) { + unlock_ctx(wasi_nn_ctx); + } return res; } @@ -563,6 +590,9 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, int32_t name_len, char *config, int32_t config_len, graph *g) { + WASINNContext *wasi_nn_ctx = NULL; + char *nul_terminated_name = NULL; + char *nul_terminated_config = NULL; wasi_nn_error res; wasm_module_inst_t instance = wasm_runtime_get_module_inst(exec_env); @@ -570,30 +600,27 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, return runtime_error; } - if (!wasm_runtime_validate_native_addr(instance, name, name_len)) { - NN_ERR_PRINTF("name is invalid"); - return invalid_argument; - } - if (!wasm_runtime_validate_native_addr(instance, g, (uint64)sizeof(graph))) { NN_ERR_PRINTF("graph is invalid"); return invalid_argument; } - if (name_len == 0 || name[name_len] != '\0') { - NN_ERR_PRINTF("Invalid filename"); - return invalid_argument; + res = copyin_and_nul_terminate(instance, name, name_len, + &nul_terminated_name); + if (res != success) { + goto fail; + } + res = copyin_and_nul_terminate(instance, config, config_len, + &nul_terminated_config); + if (res != success) { + goto fail; } - if (!config || config_len == 0 || config[config_len] != '\0') { - NN_ERR_PRINTF("Invalid config"); - return invalid_argument; - } + NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", + nul_terminated_name, nul_terminated_config); - NN_DBG_PRINTF("[WASI NN] LOAD_BY_NAME_WITH_CONFIG %s %s...", name, config); - - WASINNContext *wasi_nn_ctx = lock_ctx(instance); + wasi_nn_ctx = lock_ctx(instance); if (wasi_nn_ctx == NULL) { res = busy; goto fail; @@ -605,15 +632,23 @@ wasi_nn_load_by_name_with_config(wasm_exec_env_t exec_env, char *name, ; call_wasi_nn_func(wasi_nn_ctx->backend, load_by_name_with_config, res, - wasi_nn_ctx->backend_ctx, name, name_len, config, - config_len, g); + wasi_nn_ctx->backend_ctx, nul_terminated_name, name_len, + nul_terminated_config, config_len, g); if (res != success) goto fail; wasi_nn_ctx->is_model_loaded = true; res = success; fail: - unlock_ctx(wasi_nn_ctx); + if (nul_terminated_name != NULL) { + wasm_runtime_free(nul_terminated_name); + } + if (nul_terminated_config != NULL) { + wasm_runtime_free(nul_terminated_config); + } + if (wasi_nn_ctx != NULL) { + unlock_ctx(wasi_nn_ctx); + } return res; }