diff --git a/core/iwasm/common/wasm_runtime_common.c b/core/iwasm/common/wasm_runtime_common.c index 61b8e7796..e86a2a3ea 100644 --- a/core/iwasm/common/wasm_runtime_common.c +++ b/core/iwasm/common/wasm_runtime_common.c @@ -1816,20 +1816,32 @@ wasm_runtime_wasi_nn_registry_set_args(WASINNRegistry *registry, if ((sizeof(uint32_t *) * n_graphs) >= UINT32_MAX) { LOG_ERROR("Invalid size for wasm_runtime_malloc."); - return NULL; + return false; } registry->n_graphs = n_graphs; - registry->target = - (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->encoding = - (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->loaded = - (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->model_names = - (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); - registry->graph_paths = - (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs); + if (!(registry->target = + (uint32_t **)wasm_runtime_malloc(sizeof(uint32_t *) * n_graphs)) + || !(registry->encoding = (uint32_t **)wasm_runtime_malloc( + sizeof(uint32_t *) * n_graphs)) + || !(registry->loaded = (uint32_t **)wasm_runtime_malloc( + sizeof(uint32_t *) * n_graphs)) + || !(registry->model_names = + (char **)wasm_runtime_malloc(sizeof(char *) * n_graphs)) + || !(registry->graph_paths = + (char **)wasm_runtime_malloc(sizeof(char *) * n_graphs))) { + if (registry->target) + wasm_runtime_free(registry->target); + if (registry->encoding) + wasm_runtime_free(registry->encoding); + if (registry->loaded) + wasm_runtime_free(registry->loaded); + if (registry->model_names) + wasm_runtime_free(registry->model_names); + if (registry->graph_paths) + wasm_runtime_free(registry->graph_paths); + return false; + } memset(registry->target, 0, sizeof(uint32_t *) * n_graphs); memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs); memset(registry->loaded, 0, sizeof(uint32_t *) * n_graphs);