wasi-nn: Support uint8 quantized networks (#2433)

Support (non-full) uint8 quantized networks.
Inputs and outputs are still required to be `float`. The (de)quantization is done internally by wasi-nn.

Example generated from `quantized_model.py`:
![Screenshot from 2023-08-07 17-57-05](https://github.com/bytecodealliance/wasm-micro-runtime/assets/80318361/91f12ff6-870c-427a-b1dc-e307f7d1f5ee)

Visualization with [netron](https://netron.app/).
This commit is contained in:
tonibofarull 2023-08-11 01:55:40 +02:00 committed by GitHub
parent a550f4d9f7
commit 0b0af1b3df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 176 additions and 17 deletions

View File

@ -0,0 +1,2 @@
**/*.wasm
**/*.tflite

View File

@ -285,14 +285,37 @@ tensorflowlite_set_input(void *tflite_ctx, graph_execution_context ctx,
return invalid_argument;
}
auto *input =
tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
index);
if (input == NULL)
return missing_memory;
if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information. Using float as default");
float *it =
tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<float>(
index);
int size = model_tensor_size * sizeof(float);
bh_memcpy_s(it, size, input_tensor->data, size);
}
else { // TODO: Assumming uint8 quantized networks.
TfLiteAffineQuantization *quant_info =
(TfLiteAffineQuantization *)tensor->quantization.params;
if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
NN_ERR_PRINTF("Quantization per channel is not supported");
return runtime_error;
}
uint8_t *it =
tfl_ctx->interpreters[ctx].interpreter->typed_input_tensor<uint8_t>(
index);
float scale = quant_info->scale->data[0];
float zero_point = (float)quant_info->zero_point->data[0];
NN_DBG_PRINTF("input tensor: (scale, offset) = (%f, %f)", scale,
zero_point);
float *input_tensor_f = (float *)input_tensor->data;
for (uint32_t i = 0; i < model_tensor_size; ++i) {
it[i] = (uint8_t)(input_tensor_f[i] / scale + zero_point);
}
}
bh_memcpy_s(input, model_tensor_size * sizeof(float), input_tensor->data,
model_tensor_size * sizeof(float));
return success;
}
@ -325,6 +348,7 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
NN_DBG_PRINTF("Number of tensors (%d)", num_output_tensors);
if (index + 1 > num_output_tensors) {
NN_ERR_PRINTF("Index %d is invalid.", index);
return runtime_error;
}
@ -343,15 +367,37 @@ tensorflowlite_get_output(void *tflite_ctx, graph_execution_context ctx,
return missing_memory;
}
float *tensor_f =
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
index);
for (uint32_t i = 0; i < model_tensor_size; ++i)
NN_DBG_PRINTF("output: %f", tensor_f[i]);
if (tensor->quantization.type == kTfLiteNoQuantization) {
NN_DBG_PRINTF("No quantization information");
float *ot =
tfl_ctx->interpreters[ctx].interpreter->typed_output_tensor<float>(
index);
int size = model_tensor_size * sizeof(float);
bh_memcpy_s(output_tensor, size, ot, size);
}
else { // TODO: Assumming uint8 quantized networks.
TfLiteAffineQuantization *quant_info =
(TfLiteAffineQuantization *)tensor->quantization.params;
if (quant_info->scale->size != 1 || quant_info->zero_point->size != 1) {
NN_ERR_PRINTF("Quantization per channel is not supported");
return runtime_error;
}
uint8_t *ot = tfl_ctx->interpreters[ctx]
.interpreter->typed_output_tensor<uint8_t>(index);
float scale = quant_info->scale->data[0];
float zero_point = (float)quant_info->zero_point->data[0];
NN_DBG_PRINTF("output tensor: (scale, offset) = (%f, %f)", scale,
zero_point);
float *output_tensor_f = (float *)output_tensor;
for (uint32_t i = 0; i < model_tensor_size; ++i) {
output_tensor_f[i] = (ot[i] - zero_point) * scale;
}
}
*output_tensor_size = model_tensor_size;
bh_memcpy_s(output_tensor, model_tensor_size * sizeof(float), tensor_f,
model_tensor_size * sizeof(float));
return success;
}

View File

@ -30,7 +30,6 @@ RUN make -j "$(grep -c ^processor /proc/cpuinfo)"
FROM ubuntu:22.04
COPY --from=base /home/wamr/product-mini/platforms/linux/build/libvmlib.so /libvmlib.so
COPY --from=base /home/wamr/product-mini/platforms/linux/build/iwasm /iwasm
ENTRYPOINT [ "/iwasm" ]

View File

@ -44,7 +44,6 @@ RUN mkdir -p /etc/OpenCL/vendors && \
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
COPY --from=base /home/wamr/product-mini/platforms/linux/build/libvmlib.so /libvmlib.so
COPY --from=base /home/wamr/product-mini/platforms/linux/build/iwasm /iwasm
ENTRYPOINT [ "/iwasm" ]

View File

@ -1,6 +1,10 @@
#!/bin/sh
# Copyright (C) 2019 Intel Corporation. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
CURR_PATH=$(cd $(dirname $0) && pwd -P)
# WASM application that uses WASI-NN
/opt/wasi-sdk/bin/clang \
@ -13,9 +17,25 @@
# TFLite models to use in the tests
cd models
cd ${CURR_PATH}/models
python3 average.py
python3 max.py
python3 mult_dimension.py
python3 mult_outputs.py
python3 sum.py
# Specific tests for TPU
cd ${CURR_PATH}
/opt/wasi-sdk/bin/clang \
-Wl,--allow-undefined \
-Wl,--strip-all,--no-entry \
--sysroot=/opt/wasi-sdk/share/wasi-sysroot \
-I../include -I../src/utils \
-o test_tensorflow_quantized.wasm \
test_tensorflow_quantized.c utils.c
cd ${CURR_PATH}/models
python3 quantized.py
cd ${CURR_PATH}

View File

@ -0,0 +1,30 @@
# Copyright (C) 2019 Intel Corporation. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import tensorflow as tf
import numpy as np
import pathlib
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=[5, 5, 1]),
tf.keras.layers.AveragePooling2D(
pool_size=(5, 5), strides=None, padding="valid", data_format=None)
])
def representative_dataset():
for _ in range(1000):
data = np.random.randint(0, 25, (1, 5, 5, 1))
yield [data.astype(np.float32)]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8 # or tf.int8
converter.inference_output_type = tf.uint8 # or tf.int8
tflite_model = converter.convert()
tflite_models_dir = pathlib.Path("./")
tflite_model_file = tflite_models_dir / "quantized_model.tflite"
tflite_model_file.write_bytes(tflite_model)

View File

@ -0,0 +1,63 @@
/*
* Copyright (C) 2019 Intel Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <string.h>
#include <math.h>
#include "utils.h"
#include "logger.h"
#undef EPSILON
#define EPSILON 1e-2
void
test_average_quantized(execution_target target)
{
int dims[] = { 1, 5, 5, 1 };
input_info input = create_input(dims);
uint32_t output_size = 0;
float *output =
run_inference(target, input.input_tensor, input.dim, &output_size,
"./models/quantized_model.tflite", 1);
NN_INFO_PRINTF("Output size: %d", output_size);
NN_INFO_PRINTF("Result: average is %f", output[0]);
// NOTE: 11.95 instead of 12 because of errors due quantization
assert(fabs(output[0] - 11.95) < EPSILON);
free(input.dim);
free(input.input_tensor);
free(output);
}
int
main()
{
char *env = getenv("TARGET");
if (env == NULL) {
NN_INFO_PRINTF("Usage:\n--env=\"TARGET=[cpu|gpu|tpu]\"");
return 1;
}
execution_target target;
if (strcmp(env, "cpu") == 0)
target = cpu;
else if (strcmp(env, "gpu") == 0)
target = gpu;
else if (strcmp(env, "tpu") == 0)
target = tpu;
else {
NN_ERR_PRINTF("Wrong target!");
return 1;
}
NN_INFO_PRINTF("################### Testing quantized model...");
test_average_quantized(target);
NN_INFO_PRINTF("Tests: passed!");
return 0;
}