nn-cli: add an option to use load_by_name (#4490)

by specifying a name for --load-graph.
for example,
```
--load-graph=name=foo
```
This commit is contained in:
YAMAMOTO Takashi 2025-07-17 10:47:20 +09:00 committed by GitHub
parent 79408e59cc
commit 248e10b79e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -31,18 +31,18 @@ load_graph(char *options)
const char *id = "default"; const char *id = "default";
wasi_ephemeral_nn_graph_builder *builders = NULL; wasi_ephemeral_nn_graph_builder *builders = NULL;
size_t nbuilders = 0; size_t nbuilders = 0;
const char *name = NULL;
enum { enum {
opt_id, opt_id,
opt_file, opt_file,
opt_name,
opt_encoding, opt_encoding,
opt_target, opt_target,
}; };
static char *const keylistp[] = { static char *const keylistp[] = {
[opt_id] = "id", [opt_id] = "id", [opt_file] = "file",
[opt_file] = "file", [opt_name] = "name", [opt_encoding] = "encoding",
[opt_encoding] = "encoding", [opt_target] = "target", NULL,
[opt_target] = "target",
NULL,
}; };
while (*options) { while (*options) {
extern char *suboptarg; extern char *suboptarg;
@ -74,6 +74,13 @@ load_graph(char *options)
exit(1); exit(1);
} }
break; break;
case opt_name:
if (value == NULL) {
fprintf(stderr, "no value for %s\n", saved);
exit(2);
}
name = value;
break;
case opt_encoding: case opt_encoding:
if (value == NULL) { if (value == NULL) {
fprintf(stderr, "no value for %s\n", saved); fprintf(stderr, "no value for %s\n", saved);
@ -94,13 +101,25 @@ load_graph(char *options)
} }
} }
if (name != NULL && nbuilders != 0) {
fprintf(stderr, "name and file are exclusive\n");
exit(1);
}
wasi_ephemeral_nn_error nnret; wasi_ephemeral_nn_error nnret;
wasi_ephemeral_nn_graph g; wasi_ephemeral_nn_graph g;
nnret = wasi_ephemeral_nn_load(builders, nbuilders, encoding, target, &g); if (name != NULL) {
size_t i; /* we ignore encoding and target */
for (i = 0; i < nbuilders; i++) { nnret = wasi_ephemeral_nn_load_by_name(name, strlen(name), &g);
wasi_ephemeral_nn_graph_builder *b = &builders[i]; }
unmap_file(b->buf, b->size); else {
nnret =
wasi_ephemeral_nn_load(builders, nbuilders, encoding, target, &g);
size_t i;
for (i = 0; i < nbuilders; i++) {
wasi_ephemeral_nn_graph_builder *b = &builders[i];
unmap_file(b->buf, b->size);
}
} }
if (nnret != wasi_ephemeral_nn_error_success) { if (nnret != wasi_ephemeral_nn_error_success) {
fprintf(stderr, "load failed with %d\n", (int)nnret); fprintf(stderr, "load failed with %d\n", (int)nnret);