Mirror of https://github.com/bytecodealliance/wasm-micro-runtime.git
Synced 2025-09-06 01:41:35 +00:00
wasi_nn_llamacpp.c: reject invalid graph and execution context (#4422)
* return valid graph and execution context instead of using stack garbage.
  (always 0 for now because we don't implement multiple graph/context for
  this backend.)
* validate user-given graph and execution context values. reject invalid
  ones.
parent ebf1404ad1
commit da6019f749
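The same guard pattern repeats in every hunk below: because this backend
implements exactly one graph and one execution context, the only valid
handle value is 0, and a handle is meaningful only once the corresponding
object has been created. A minimal sketch of that invariant, assuming a
struct LlamaContext and wasi_nn_error reduced to just the members and
values these checks touch (the real definitions have more):

    /*
     * Sketch only, NOT the real wasi_nn_llamacpp.c code: types are
     * reduced to what the validation logic needs.
     */
    #include <stddef.h>
    #include <stdint.h>

    typedef uint32_t graph;
    typedef uint32_t graph_execution_context;

    typedef enum {
        success = 0,
        runtime_error,
        unsupported_operation,
    } wasi_nn_error;

    struct LlamaContext {
        void *model; /* the single loaded model, or NULL */
        void *ctx;   /* the single llama context, or NULL */
    };

    /* a graph handle is valid only if it is 0 and a model exists */
    static wasi_nn_error
    validate_graph(const struct LlamaContext *backend_ctx, graph g)
    {
        if (g != 0 || backend_ctx->model == NULL)
            return runtime_error;
        return success;
    }

    /* the same rule applies to execution context handles */
    static wasi_nn_error
    validate_exec_ctx(const struct LlamaContext *backend_ctx,
                      graph_execution_context exec_ctx)
    {
        if (exec_ctx != 0 || backend_ctx->ctx == NULL)
            return runtime_error;
        return success;
    }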
@@ -305,6 +305,11 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (backend_ctx->model != NULL) {
+        // we only implement a single graph
+        return unsupported_operation;
+    }
+
     // make sure backend_ctx->config is initialized
     struct llama_model_params model_params =
@@ -323,6 +328,7 @@ __load_by_name_with_configuration(void *ctx, const char *filename, graph *g)
 #endif
 
     backend_ctx->model = model;
+    *g = 0;
 
     return success;
 }
@@ -363,6 +369,16 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (g != 0 || backend_ctx->model == NULL) {
+        // we only implement a single graph
+        return runtime_error;
+    }
+
+    if (backend_ctx->ctx != NULL) {
+        // we only implement a single context
+        return unsupported_operation;
+    }
+
     struct llama_context_params ctx_params =
         llama_context_params_from_wasi_nn_llama_config(&backend_ctx->config);
     struct llama_context *llama_ctx =
@@ -373,6 +389,7 @@ init_execution_context(void *ctx, graph g, graph_execution_context *exec_ctx)
     }
 
     backend_ctx->ctx = llama_ctx;
+    *exec_ctx = 0;
 
     NN_INFO_PRINTF("n_predict = %d, n_ctx = %d", backend_ctx->config.n_predict,
                    llama_n_ctx(backend_ctx->ctx));
@@ -384,6 +401,12 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
           tensor *wasi_nn_tensor)
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // tensor->data is the prompt string.
     char *prompt_text = (char *)wasi_nn_tensor->data.buf;
     uint32_t prompt_text_len = wasi_nn_tensor->data.size;
@@ -433,6 +456,11 @@ compute(void *ctx, graph_execution_context exec_ctx)
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
     wasi_nn_error ret = runtime_error;
 
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // reset the generation buffer
     if (backend_ctx->generation == NULL) {
         backend_ctx->generation =
@@ -554,6 +582,11 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
 {
     struct LlamaContext *backend_ctx = (struct LlamaContext *)ctx;
 
+    if (exec_ctx != 0 || backend_ctx->ctx == NULL) {
+        // we only implement a single context
+        return runtime_error;
+    }
+
     // Compatibility with WasmEdge
     if (index > 1) {
         NN_ERR_PRINTF("Invalid output index %d", index);
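Taken together, the hunks make the handle behavior observable from the
caller's side: load and init_execution_context now deterministically
write 0, and any other handle value is rejected rather than silently
accepted. The following check is hypothetical (not a test from the
repository); ctx is assumed to be an already-initialized backend
context, and the model filename is illustrative:

    /*
     * Illustrative only: exercises the new checks through the backend
     * entry points touched by this diff. The surrounding harness that
     * produces ctx and the file "model.gguf" are assumptions.
     */
    #include <assert.h>

    static void
    check_handle_validation(void *ctx)
    {
        graph g;
        graph_execution_context ec;

        /* handles are written deterministically (always 0) instead of
           leaving the caller with stack garbage */
        assert(__load_by_name_with_configuration(ctx, "model.gguf", &g)
               == success);
        assert(g == 0);

        /* a second load is rejected: this backend keeps a single graph */
        assert(__load_by_name_with_configuration(ctx, "model.gguf", &g)
               == unsupported_operation);

        assert(init_execution_context(ctx, g, &ec) == success);
        assert(ec == 0);

        /* a bogus handle is rejected instead of being silently accepted */
        assert(compute(ctx, 12345) == runtime_error);
    }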