From da6019f74947a6fd472ab6d2659904433d9a7eb8 Mon Sep 17 00:00:00 2001 From: YAMAMOTO Takashi Date: Tue, 1 Jul 2025 20:31:00 +0900 Subject: [PATCH] wasi_nn_llamacpp.c: reject invalid graph and execution context (#4422) * return valid graph and execution context instead of using stack garbage. (always 0 for now because we don't implement multiple graph/context for this backend.) * validate user-given graph and execution context values. reject invalid ones. --- .../libraries/wasi-nn/src/wasi_nn_llamacpp.c | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c index 5cad663dd..ce72afa74 100644 --- a/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c +++ b/core/iwasm/libraries/wasi-nn/src/wasi_nn_llamacpp.c @@ -305,6 +305,11 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; + if (backend_ctx->model != NULL) { + // we only implement a single graph + return unsupported_operation; + } + // make sure backend_ctx->config is initialized struct llama_model_params model_params = @@ -323,6 +328,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g) #endif backend_ctx->model = model; + *g = 0; return success; } @@ -363,6 +369,16 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; + if (g != 0 || backend_ctx->model == NULL) { + // we only implement a single graph + return runtime_error; + } + + if (backend_ctx->ctx != NULL) { + // we only implement a single context + return unsupported_operation; + } + struct llama_context_params ctx_params = llama_context_params_from_wasi_nn_llama_config(&backend_ctx->config); struct llama_context *llama_ctx = @@ -373,6 +389,7 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx) } backend_ctx->ctx = llama_ctx; + *exec_ctx = 0; NN_INFO_PRINTF("n_predict = %d, n_ctx = %d", backend_ctx->config.n_predict, llama_n_ctx(backend_ctx->ctx)); @@ -384,6 +401,12 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index, tensor *wasi_nn_tensor) { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; + + if (exec_ctx != 0 || backend_ctx->ctx == NULL) { + // we only implement a single context + return runtime_error; + } + // tensor->data is the prompt string. char *prompt_text = (char *)wasi_nn_tensor->data.buf; uint32_t prompt_text_len = wasi_nn_tensor->data.size; @@ -433,6 +456,11 @@ compute(void *ctx, graph_execution_context exec_ctx) struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; wasi_nn_error ret = runtime_error; + if (exec_ctx != 0 || backend_ctx->ctx == NULL) { + // we only implement a single context + return runtime_error; + } + // reset the generation buffer if (backend_ctx->generation == NULL) { backend_ctx->generation = @@ -554,6 +582,11 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index, { struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx; + if (exec_ctx != 0 || backend_ctx->ctx == NULL) { + // we only implement a single context + return runtime_error; + } + // Compatibility with WasmEdge if (index > 1) { NN_ERR_PRINTF("Invalid output index %d", index);