
TFLM inference results abnormal

Open Unbinilium opened this issue 1 year ago • 6 comments

Hello, recently I encountered an issue when deploying a model from YOLO-World to a device using TFLM. With the same INT8 per-channel quantized TFLite model and the same image tensor as input, there is a significant discrepancy between the output tensors produced by TFLM inference and by tensorflow.lite.Interpreter.

image

As shown in the figure, the model has 6 outputs. The blue and orange histograms show the INT8 output tensors obtained from tensorflow.lite.Interpreter and from TFLM inference, respectively. In the INT8 space, the proportion of mismatched values exceeds 1/3.
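(For reference, a mismatch ratio like this can be computed directly in the INT8 domain with numpy; this is only a minimal sketch, and the array names below are placeholders standing in for the two runtimes' output tensors.)

import numpy as np

# Placeholder arrays standing in for the two runtimes' INT8 outputs (same shape/dtype).
tfl_out = np.zeros((1, 160, 160, 16), dtype=np.int8)
tflm_out = np.zeros((1, 160, 160, 16), dtype=np.int8)

mismatch = tfl_out != tflm_out
print("mismatch ratio:", mismatch.mean())  # fraction of elements that differ
diff = tfl_out.astype(np.int32) - tflm_out.astype(np.int32)
print("max abs diff:", np.abs(diff).max())  # largest per-element offset in INT8 steps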

image

However, after the complex post-processing, the observed result shows only a few pixels of offset in the bounding boxes.

I also modified the model's flatbuffer to expose the outputs of certain intermediate tensors:

image

It can be observed that errors already occur in the shallow layers. As the network deepens, the accumulated error may lead to inaccuracies in the final results.
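For anyone wanting to reproduce this kind of tensor-level inspection, one way to expose an intermediate tensor as an extra subgraph output is to edit the flatbuffer with TensorFlow's flatbuffer utilities. This is only a sketch, assuming flatbuffer_utils (tensorflow/lite/tools/flatbuffer_utils.py in the TF repo) is available in your installation; the tensor index below is hypothetical and the file names are placeholders:

import numpy as np
from tensorflow.lite.tools import flatbuffer_utils  # assumption: shipped with your TF install

extra_tensor_index = 42  # hypothetical: index of the intermediate tensor to observe

# Load the model into the mutable (object API) representation.
model = flatbuffer_utils.read_model("yolo_world.tflite")

# Append the tensor index to the subgraph's output list so runtimes expose it.
# (outputs may be a list or a numpy array depending on the flatbuffers version.)
subgraph = model.subgraphs[0]
subgraph.outputs = list(subgraph.outputs) + [extra_tensor_index]

# Serialize the modified model back to a .tflite file.
flatbuffer_utils.write_model(model, "yolo_world_debug.tflite")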

image

Although TFLM's hacked logistic implementation is different from TFLite's (perhaps other kernels differ as well; they are stored in the repository under the same file names and paths, which is easy to misread unless you open the files and compare the implementations), the +/-1 offsets that appear right after the convolution still leave me a little confused.

Is this a mistake in TFLM? If you have any debugging suggestions or solutions, please let me know. Thanks!

Test environment:

  • tensorflow 2.16.2
  • tflite-micro https://github.com/tensorflow/tflite-micro/commit/7a0249686f412551634a5058ddd6d2ec3f224203
  • clang 14.0.0
  • python 3.10.12

Possibly related issues:

  • https://github.com/tensorflow/tflite-micro/issues/2319

Unbinilium avatar Jul 17 '24 08:07 Unbinilium

I installed the public version of tflite_micro 0.dev20240715200401 from PyPI and modified the test script from https://github.com/tensorflow/tflite-micro/issues/2319 to reproduce the issue:

import numpy as np

import tflite_micro as tflm
from tflite_micro.python.tflite_micro import runtime
import tensorflow as tf


print(tflm.__version__)
print(tf.__version__)


print("Checking TFLM post-installation...")
tflm.postinstall_check.passed()


print("Loading the model...")
with open("yolo_world.tflite", "rb") as f:
    tflite_model = f.read()


print("Analyzing the model...")
tf.lite.experimental.Analyzer.analyze(model_content=tflite_model)

# TFLite interpreter, preserving intermediate tensors so they can be compared per layer.
tfl_interpreter = tf.lite.Interpreter(model_content=tflite_model, experimental_preserve_all_tensors=True)
tfl_interpreter.allocate_tensors()

# TFLM interpreter with the matching preserve-all-tensors configuration.
tflm_interpreter = runtime.Interpreter.from_bytes(
    tflite_model, intrepreter_config=runtime.InterpreterConfig.kPreserveAllTensors
)


tfl_input_details = tfl_interpreter.get_input_details()
tflm_input_details = tflm_interpreter.get_input_details(0)

input_shape = tfl_input_details[0]["shape"]
input_dtype = tfl_input_details[0]["dtype"]

# Random input covering the INT8 range [-127, 127].
dummy_input = np.random.randint(-127, 128, size=input_shape, dtype=input_dtype)


tfl_interpreter.set_tensor(tfl_input_details[0]["index"], dummy_input)
tflm_interpreter.set_input(dummy_input, 0)

print("Invoking...")
tfl_interpreter.invoke()
tflm_interpreter.invoke()

print("Comparing the results...")
# Walk every tensor in subgraph 0 and compare TFLite's preserved value with TFLM's.
for i in range(0, tfl_interpreter._interpreter.NumTensors(0)):
    print(f"Tensor {i}: {tfl_interpreter.get_tensor_details()[i]['name']}")
    tflm_tensor = tflm_interpreter.GetTensor(i, 0)["tensor_data"]

    try:
        tfl_tensor = tfl_interpreter.get_tensor(i, 0)
    except ValueError:
        # TFLite refuses to return tensors whose data it did not materialize.
        print("  TFL: N/A")
        print(f" TFLM: shape={tflm_tensor.shape}, dtype={tflm_tensor.dtype}")
        print("")
        continue

    is_match = np.allclose(tfl_tensor, tflm_tensor, atol=1e-3)
    print(f"  TFL: shape={tfl_tensor.shape}, dtype={tfl_tensor.dtype}")
    print(f" TFLM: shape={tflm_tensor.shape}, dtype={tflm_tensor.dtype}")
    print(f" MATCH: {'YES' if is_match else 'NO'}")
    print("")

Logs: log.txt
Model file: yolo_world.tflite.zip

Unbinilium avatar Jul 17 '24 10:07 Unbinilium

Hey @Unbinilium, if you have access to the non-TFLite model, could you try converting it with the converter._experimental_disable_per_channel_quantization_for_dense_layers = True flag and see if the issue is resolved?

Ah wait, I see tensorflow 2.16, so this might not help...

ArmRyan avatar Jul 17 '24 11:07 ArmRyan

Hi @ArmRyan, thanks for your suggestion. I'll enable this option, re-export the TFLite model from the pth checkpoint, and test again later.

In fact, I initially suspected that this option might be causing the problem too.

  • https://github.com/tensorflow/tensorflow/blob/377f47694fa790e98db6665b9adecde00b5e0d68/tensorflow/lite/python/lite.py#L674

However, this change was previously tested only on the device side (with the TFLite model compiled via ethos-u-vela), and it didn't seem to help (possibly there are issues with ethos-u-vela as well).
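For completeness, here is a sketch of where these flags go during conversion. This is not the exact export pipeline used here (the model comes from a pth checkpoint); the saved-model path, input shape, and representative dataset are placeholders, and the underscore-prefixed attributes are private/experimental APIs that may change between TF versions:

import tensorflow as tf

def rep_data():
    # Hypothetical representative dataset generator used for INT8 calibration.
    for _ in range(100):
        yield [tf.random.uniform((1, 320, 320, 3), dtype=tf.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("yolo_world_saved_model")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep_data
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

# Experimental switches discussed in this thread (private API, may differ per TF version):
converter._experimental_disable_per_channel_quantization_for_dense_layers = True
# converter._experimental_disable_per_channel = True  # forces per-tensor quantization everywhere

tflite_model = converter.convert()
with open("yolo_world_per_tensor.tflite", "wb") as f:
    f.write(tflite_model)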

Unbinilium avatar Jul 17 '24 13:07 Unbinilium

With the converter._experimental_disable_per_channel flag set to True in tensorflow 2.16.1 (which may have the same effect as converter._experimental_disable_per_channel_quantization_for_dense_layers = True), the newly converted model still shows wrong results in the shallow layers, starting at the first Conv2D.

...

Tensor 297: tfl.pseudo_qconst251
  TFL: shape=(16, 3, 3, 3), dtype=int8
 TFLM: shape=(16, 3, 3, 3), dtype=int8
 MATCH: YES

Tensor 298: model_94/tf.compat.v1.pad/Pad
  TFL: shape=(1, 322, 322, 3), dtype=int8
 TFLM: shape=(1, 322, 322, 3), dtype=int8
 MATCH: YES

Tensor 299: model_94/tf.math.add/Add;model_94/tf.nn.convolution/convolution;Const_319
  TFL: shape=(1, 160, 160, 16), dtype=int8
 TFLM: shape=(1, 160, 160, 16), dtype=int8
 MATCH: NO

Tensor 300: model_94/tf.math.sigmoid/Sigmoid
  TFL: shape=(1, 160, 160, 16), dtype=int8
 TFLM: shape=(1, 160, 160, 16), dtype=int8
 MATCH: NO

...

Correspondingly, its position in the graph is:

image

I guess I should step through the TFLM C++ code to see whether an error occurs when some parameter is read from the flatbuffer.
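Before diving into the C++ code, it may also be worth dumping the quantization parameters of the suspect tensors on the TFLite side, since a scale or zero-point mismatch would already explain a constant +/-1 offset. A minimal sketch using the public get_tensor_details() API; tensor index 299 is the first Conv2D output from the log above, and the file name assumes the model extracted from the attached zip:

import tensorflow as tf

with open("yolo_world_disable_per_channel.tflite", "rb") as f:
    interpreter = tf.lite.Interpreter(model_content=f.read())
interpreter.allocate_tensors()

detail = interpreter.get_tensor_details()[299]  # first Conv2D output in the log above
qp = detail["quantization_parameters"]
print(detail["name"], detail["shape"], detail["dtype"])
print("scales:", qp["scales"])
print("zero_points:", qp["zero_points"])
print("quantized_dimension:", qp["quantized_dimension"])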

Full log: log_per_tensor_quant.txt
Model: yolo_world_disable_per_channel.tflite.zip

Unbinilium avatar Jul 18 '24 02:07 Unbinilium

In addition, turning off per-channel quantization does seem to reduce the accumulation of errors in the outputs:

image

Unbinilium avatar Jul 18 '24 03:07 Unbinilium

It might be useful to try out the layer_by_layer_debugger script. It can be used to compare the output of each layer against the TFLite output to determine where the output starts to differ. The tool has some rough edges, but seems like it would work well for this use case.

rascani avatar Jul 24 '24 16:07 rascani

"This issue is being marked as stale due to inactivity. Remove label or comment to prevent closure in 5 days."

github-actions[bot] avatar Aug 19 '24 10:08 github-actions[bot]

"This issue is being closed because it has been marked as stale for 5 days with no further activity."

github-actions[bot] avatar Aug 25 '24 10:08 github-actions[bot]

Hi @rascani, sorry for the delay and thanks for your suggestions. Using this script, I changed the activation functions and some operators, and after strictly pinning the version of each module, the consistency between the TFL and TFLM outputs has improved a lot (there are still some differences), but it is now good enough for normal use cases.

Unbinilium avatar Sep 16 '24 13:09 Unbinilium