Use correct type when calling wasm_runtime_malloc for module_names and graph_paths.

Add check for some wasm_runtime_malloc() calls.
This commit is contained in:
zhanheng1 2026-03-25 14:12:04 +08:00
parent 9d8e31a494
commit cc9520d2e4

View File

@ -1816,20 +1816,32 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry,
if ((sizeof(uint32_t *) * n_graphs) >= UINT32_MAX) { if ((sizeof(uint32_t *) * n_graphs) >= UINT32_MAX) {
LOG_ERROR("Invalid size for wasm_runtime_malloc."); LOG_ERROR("Invalid size for wasm_runtime_malloc.");
return NULL; return false;
} }
registry->n_graphs = n_graphs; registry->n_graphs = n_graphs;
registry->target = if (!(registry->target =
(uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs))
registry->encoding = || !(registry->encoding = (uint32_t **)wasm_runtime_malloc(
(uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); sizeof(uint32_t *) * n_graphs))
registry->loaded = || !(registry->loaded = (uint32_t **)wasm_runtime_malloc(
(uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); sizeof(uint32_t *) * n_graphs))
registry->model_names = || !(registry->model_names =
(uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); (char **)wasm_runtime_malloc(sizeof(char *) * n_graphs))
registry->graph_paths = || !(registry->graph_paths =
(uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); (char **)wasm_runtime_malloc(sizeof(char *) * n_graphs))) {
if (registry->target)
wasm_runtime_free(registry->target);
if (registry->encoding)
wasm_runtime_free(registry->encoding);
if (registry->loaded)
wasm_runtime_free(registry->loaded);
if (registry->model_names)
wasm_runtime_free(registry->model_names);
if (registry->graph_paths)
wasm_runtime_free(registry->graph_paths);
return false;
}
memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->target, 0, sizeof(uint32_t *) * n_graphs);
memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs);
memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs); memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs);