Mirror of https://github.com/bytecodealliance/wasm-micro-runtime.git (synced 2025-05-15 06:01:14 +00:00)
wasi-nn: Support uint8 quantized networks (#2433)
Support (non-full) uint8 quantized networks. Inputs and outputs are still required to be `float`; the (de)quantization is done internally by wasi-nn. An example model, generated by `quantized_model.py`, can be visualized with [netron](https://netron.app/).
This commit is contained in:
parent a550f4d9f7
commit 0b0af1b3df
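As context for the change below: on set_input the patch maps floats to uint8 with q = x / scale + zero_point, and on get_output maps back with x = (q - zero_point) * scale. A minimal numpy sketch of that round trip, with illustrative scale/zero-point values that are not taken from this commit:

import numpy as np

# Illustrative per-tensor quantization parameters; real models store these
# in TfLiteAffineQuantization alongside each tensor.
scale, zero_point = 0.05, 128.0

x = np.array([1.5, -0.7, 3.2], dtype=np.float32)

# set_input path: float -> uint8 (truncating cast, as in the C code)
q = (x / scale + zero_point).astype(np.uint8)

# get_output path: uint8 -> float
x_back = (q.astype(np.float32) - zero_point) * scale

print(q)       # [158 114 192]
print(x_back)  # ~[ 1.5 -0.7  3.2], equal up to quantization error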
core/iwasm/libraries/wasi-nn/.gitignore (vendored, new file, +2)
@@ -0,0 +1,2 @@
+**/*.wasm
+**/*.tflite
core/iwasm/libraries/wasi-nn/src/wasi_nn_tensorflowlite.cpp

@@ -285,14 +285,37 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
         return invalid_argument;
     }
 
-    auto *input =
-        tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
-            index);
-    if (input == NULL)
-        return missing_memory;
-
-    bh_memcpy_s(input, model_tensor_size * sizeof(float), input_tensor->data,
-                model_tensor_size * sizeof(float));
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information. Using float as default");
+        float *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
+                index);
+
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(it, size, input_tensor->data, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *it =
+            tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<uint8_t>(
+                index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *input_tensor_f = (float *)input_tensor->data;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
+        }
+    }
 
     return success;
 }
@@ -325,6 +348,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
     NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
 
     if (index + 1 > num_output_tensors) {
+        NN_ERR_PRINTF("Index %d is invalid.", index);
         return runtime_error;
     }
 
@@ -343,15 +367,37 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
         return missing_memory;
     }
 
-    float *tensor_f =
-        tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
-            index);
-    for (uint32_t i = 0; i < model_tensor_size; ++i)
-        NN_DBG_PRINTF("output: %f", tensor_f[i]);
+    if (tensor->quantization.type == kTfLiteNoQuantization) {
+        NN_DBG_PRINTF("No quantization information");
+        float *ot =
+            tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
+                index);
+
+        int size = model_tensor_size * sizeof(float);
+        bh_memcpy_s(output_tensor, size, ot, size);
+    }
+    else { // TODO: Assuming uint8 quantized networks.
+        TfLiteAffineQuantization *quant_info =
+            (TfLiteAffineQuantization *)tensor->quantization.params;
+        if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
+            NN_ERR_PRINTF("Quantization per channel is not supported");
+            return runtime_error;
+        }
+        uint8_t *ot = tfl_ctx->interpreters[ctx]
+                          .interpreter->typed_output_tensor<uint8_t>(index);
+
+        float scale = quant_info->scale->data[0];
+        float zero_point = (float)quant_info->zero_point->data[0];
+        NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
+                      zero_point);
+
+        float *output_tensor_f = (float *)output_tensor;
+        for (uint32_t i = 0; i < model_tensor_size; ++i) {
+            output_tensor_f[i] = (ot[i] - zero_point) * scale;
+        }
+    }
 
     *output_tensor_size = model_tensor_size;
-    bh_memcpy_s(output_tensor, model_tensor_size * sizeof(float), tensor_f,
-                model_tensor_size * sizeof(float));
 
     return success;
 }
@@ -30,7 +30,6 @@ RUN make -j "$(grep -c ^processor /proc/cpuinfo)"
 
 FROM ubuntu:22.04
 
-COPY --from=base /home/wamr/product-mini/platforms/linux/build/libvmlib.so /libvmlib.so
 COPY --from=base /home/wamr/product-mini/platforms/linux/build/iwasm /iwasm
 
 ENTRYPOINT [ "/iwasm" ]
@@ -44,7 +44,6 @@ RUN mkdir -p /etc/OpenCL/vendors && \
 ENV NVIDIA_VISIBLE_DEVICES=all
 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
 
-COPY --from=base /home/wamr/product-mini/platforms/linux/build/libvmlib.so /libvmlib.so
 COPY --from=base /home/wamr/product-mini/platforms/linux/build/iwasm /iwasm
 
 ENTRYPOINT [ "/iwasm" ]
@@ -1,6 +1,10 @@
+#!/bin/sh
+
 # Copyright (C) 2019 Intel Corporation. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+CURR_PATH=$(cd $(dirname $0) && pwd -P)
+
 # WASM application that uses WASI-NN
 /opt/wasi-sdk/bin/clang \
@@ -13,9 +17,25 @@
 
 # TFLite models to use in the tests
-cd models
+cd ${CURR_PATH}/models
 python3 average.py
 python3 max.py
 python3 mult_dimension.py
 python3 mult_outputs.py
 python3 sum.py
+
+# Specific tests for TPU
+cd ${CURR_PATH}
+/opt/wasi-sdk/bin/clang \
+    -Wl,--allow-undefined \
+    -Wl,--strip-all,--no-entry \
+    --sysroot=/opt/wasi-sdk/share/wasi-sysroot \
+    -I../include -I../src/utils \
+    -o test_tensorflow_quantized.wasm \
+    test_tensorflow_quantized.c utils.c
+
+cd ${CURR_PATH}/models
+python3 quantized.py
+
+cd ${CURR_PATH}
core/iwasm/libraries/wasi-nn/test/models/quantized.py (new file, +30)
@@ -0,0 +1,30 @@
+# Copyright (C) 2019 Intel Corporation. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import tensorflow as tf
+import numpy as np
+import pathlib
+
+model = tf.keras.Sequential([
+    tf.keras.layers.InputLayer(input_shape=[5, 5, 1]),
+    tf.keras.layers.AveragePooling2D(
+        pool_size=(5, 5), strides=None, padding="valid", data_format=None)
+])
+
+
+def representative_dataset():
+    for _ in range(1000):
+        data = np.random.randint(0, 25, (1, 5, 5, 1))
+        yield [data.astype(np.float32)]
+
+
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.representative_dataset = representative_dataset
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8  # or tf.int8
+converter.inference_output_type = tf.uint8  # or tf.int8
+tflite_model = converter.convert()
+
+tflite_models_dir = pathlib.Path("./")
+tflite_model_file = tflite_models_dir / "quantized_model.tflite"
+tflite_model_file.write_bytes(tflite_model)
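To double-check the per-tensor assumption the runtime makes above (scale->size == 1 and zero_point->size == 1), the generated model can be inspected with the TFLite interpreter. A small sketch, not part of the commit:

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="quantized_model.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
# 'quantization' is a (scale, zero_point) pair; dtype should be uint8
print(inp["dtype"], inp["quantization"])
print(out["dtype"], out["quantization"])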
core/iwasm/libraries/wasi-nn/test/test_tensorflow_quantized.c (new file, +63)

@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2019 Intel Corporation. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <string.h>
+#include <math.h>
+
+#include "utils.h"
+#include "logger.h"
+
+#undef EPSILON
+#define EPSILON 1e-2
+
+void
+test_average_quantized(execution_target target)
+{
+    int dims[] = { 1, 5, 5, 1 };
+    input_info input = create_input(dims);
+
+    uint32_t output_size = 0;
+    float *output =
+        run_inference(target, input.input_tensor, input.dim, &output_size,
+                      "./models/quantized_model.tflite", 1);
+
+    NN_INFO_PRINTF("Output size: %d", output_size);
+    NN_INFO_PRINTF("Result: average is %f", output[0]);
+    // NOTE: 11.95 instead of 12 because of errors due to quantization
+    assert(fabs(output[0] - 11.95) < EPSILON);
+
+    free(input.dim);
+    free(input.input_tensor);
+    free(output);
+}
+
+int
+main()
+{
+    char *env = getenv("TARGET");
+    if (env == NULL) {
+        NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\"");
+        return 1;
+    }
+    execution_target target;
+    if (strcmp(env, "cpu") == 0)
+        target = cpu;
+    else if (strcmp(env, "gpu") == 0)
+        target = gpu;
+    else if (strcmp(env, "tpu") == 0)
+        target = tpu;
+    else {
+        NN_ERR_PRINTF("Wrong target!");
+        return 1;
+    }
+    NN_INFO_PRINTF("################### Testing quantized model...");
+    test_average_quantized(target);
+
+    NN_INFO_PRINTF("Tests: passed!");
+    return 0;
+}
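Per the usage string in main() above, the test reads its execution target from the TARGET environment variable; with the iwasm binary from the images above, an invocation along the lines of iwasm --env="TARGET=cpu" test_tensorflow_quantized.wasm would be expected. Treat the exact flags as an assumption rather than the committed invocation; they depend on how the wasi-nn native library is loaded.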