google_ml_kit_flutter
Google ML Kit image labeling cannot be used with models without a final softmax or sigmoid layer (that is, those outputting raw logits)
Context: how ML Kit image labeling works
Google ML Kit image labeling allows developers to use a TensorFlow Lite model to make predictions. The output of the prediction is an array of values indicating the probability that the image is of the corresponding class. Probability values are expected to be between 0.0 and 1.0 for each class. Supported neural models have a final output layer applying either:
- a `softmax` over all class probability values; the sum of the probabilities across all classes is then guaranteed to be 1.0.
- less often, a `sigmoid` on each class probability value, to constrain each value to the [0.0, 1.0] range.
ML Kit lets developers specify a `confidenceThreshold` value to filter out classes/labels whose probability is less than the specified value, 0.5 by default. Note that `confidenceThreshold` MUST be between 0.0 and 1.0, inclusive.
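As an illustration (this NumPy sketch is mine, not part of ML Kit), here is what a compliant final layer computes; both activations map unbounded logits into the [0.0, 1.0] range that `confidenceThreshold` is compared against:

```python
import numpy as np

# Raw logits, as a headless classifier might output them: unbounded, possibly negative.
logits = np.array([2.3, -1.7, 0.4, -5.0])

# softmax: values in [0.0, 1.0], summing to exactly 1.0 across all classes.
softmax = np.exp(logits) / np.sum(np.exp(logits))

# sigmoid: each value independently constrained to [0.0, 1.0] (no sum constraint).
sigmoid = 1.0 / (1.0 + np.exp(-logits))

print(softmax, softmax.sum())  # approx. [0.856 0.016 0.128 0.001] 1.0
print(sigmoid)                 # approx. [0.909 0.154 0.599 0.007]
```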
Describe the problem
Not all TF Lite models generate values in the [0.0, 1.0] range. Many models produce arbitrary logit values, and the developer is expected to apply a softmax to them manually after the fact. For these models, output values may be negative, and there is no way to tell ML Kit to keep all values, including those that are < 0.0. As a result, it is impossible to use ML Kit with these models.
The only solution is to forget ML Kit altogether and use e.g. tflite_flutter (https://pub.dev/packages/tflite_flutter). Unfortunately, tflite_flutter only supports recent Android versions (because of this issue), while ML Kit supports all Android versions since Android 5.0.
To Reproduce
Below are repro steps showing two things:
- ML Kit does seem to give access to raw logit values, which is good. This can be verified by looking at the `confidence` value of each `ImageLabel` returned by `processImage`. Note this value can be greater than 1.0.
- ML Kit does NOT allow the developer to retrieve image labels with negative probability (aka confidence) values. This can be verified by counting the size of the image label list returned by `processImage`: the list will unfortunately not include labels whose confidence value is negative. Negative logits are always filtered out.
Repro steps:
- Go to https://tfhub.dev/tensorflow/efficientnet/lite4/classification/2 and download the TensorFlow (not TensorFlow Lite) model provided in the leftmost tab.
- Convert the model to TensorFlow Lite, e.g. using the `tflite_convert` tool provided by the TensorFlow Python package.
- Create a minimal Flutter application with a camera stream feed and load the TF Lite model with the following code:
```dart
LocalLabelerOptions customImageLabelerOptions = LocalLabelerOptions(
    confidenceThreshold: 0,
    modelPath: "assets/efficientnet_lite4_classification_2.tflite",
    maxCount: 1000000);
imageLabeler = ImageLabeler(options: customImageLabelerOptions);
```
In the above code, I am using:
- `confidenceThreshold: 0.0`: I am trying to keep all logits here. ⚠️ That's the problem: `confidenceThreshold` must be between 0.0 and 1.0, and ML Kit will filter out image labels that fall below the confidence threshold.
- `maxCount: 1000000`: I want to keep all classes; 1000000 here effectively means infinity.
If you run the application with the above-mentioned model, you will see that `processImage()` returns a list of `ImageLabel` with arbitrary values (even greater than 1.0), but no negative values. The size of the list will not be the total number of classes.
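For reference, a sketch of the conversion step plus a sanity check on the raw model output; the file paths are assumptions, and only the standard `tf.lite.TFLiteConverter` / `tf.lite.Interpreter` APIs are used. Running the converted model directly shows that its outputs are unbounded logits, negative values included, which is exactly what ML Kit then silently drops:

```python
import numpy as np
import tensorflow as tf

# Convert the downloaded SavedModel (directory name is an assumption)
# to TF Lite, as in step 2 of the repro.
converter = tf.lite.TFLiteConverter.from_saved_model("efficientnet_lite4_classification_2")
tflite_model = converter.convert()
with open("assets/efficientnet_lite4_classification_2.tflite", "wb") as f:
    f.write(tflite_model)

# Sanity check: run the TF Lite model on a random input and inspect the raw output.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

dummy = np.random.rand(*inp["shape"]).astype(inp["dtype"])
interpreter.set_tensor(inp["index"], dummy)
interpreter.invoke()
logits = interpreter.get_tensor(out["index"])

# The minimum is typically negative: these are raw logits, not probabilities.
print(logits.min(), logits.max())
```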
Expected behavior
Raw logits should all be exposed (not just those with a positive value) if we set, e.g., `confidenceThreshold` to `double.negativeInfinity` (which is currently not allowed).
Shall we use tflite_flutter for these models? If so, the ML Kit image labeling documentation should be clearer: it should state that the supported use case is, e.g., softmax output and nothing else. The advantages of ML Kit vs tflite_flutter should also be made clearer.
Setup
- OS: Windows 11
- Device: Google Pixel 6a
- Device OS: Android 13
- Flutter/Dart version: 3.13.3
- Plugin version: 0.9.0
Thanks @andynewman10 for your detailed issue report. I think you should report this issue directly to Google. That is outside the scope of this plugin.
In the README of this plugin we explain that this plugin is just a wrapper around the native ML Kit API by Google. Using `MethodChannels` we pass the info to the native API (developed by Google), take the response, and return it back to Flutter. More details are in the first section of the README.
I have seen a couple of the things you describe; that is why I wrote this tutorial on how to create a custom model using TensorFlow that is compatible with ML Kit's requirements. It does not cover everything, but you can find some guidance in it.
One of the things I still need to add is the need to apply a softmax to the response coming from ML Kit. The ML Kit API is a very high-level API; there are not that many methods you need to call, unlike when using the TF framework directly. If you need the ML Kit API to be modified, that is something you need to suggest to Google.
My recommendation is that you try their native example apps. See if you can get better results with your custom model in their native example app. If their example apps work fine and you still have an issue using this plugin, then come back here. But if the issue is also reproducible in their example app, then move the conversation to their repo.
Thanks for your reply. I found that it is actually possible to add a layer to a trained model that shifts values by a specified amount. This little hack can be used to ensure all logits are positive in all cases, so switching to tflite_flutter is not necessary (at least, not for this specific reason).
Great, could you share which layer you added to the model, how you added it, and some sample code, so I can update the tutorial and close this issue? Thanks.
Sure, I use:
```python
model2 = tf.keras.Sequential([
    model,
    tf.keras.layers.Rescaling(1.0, offset=1000.0)
])
```
`model.add(tf.keras.layers.Rescaling(1.0, offset=1000.0))` should also work, but I only tested the case above (successfully).
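For completeness, here is a hedged sketch of the full workaround (the output file name, the 380x380 input resolution, and the use of tensorflow_hub to rebuild the model are my assumptions; the `Rescaling` layer and the TF Hub URL come from this thread):

```python
import tensorflow as tf
import tensorflow_hub as hub

# Rebuild the classifier from TF Hub (URL from the repro steps) and shift
# its raw logits by +1000.0 so that none of them is negative.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(380, 380, 3)),
    hub.KerasLayer("https://tfhub.dev/tensorflow/efficientnet/lite4/classification/2"),
    tf.keras.layers.Rescaling(1.0, offset=1000.0),  # output = logits + 1000.0
])

# Re-convert to TF Lite; with confidenceThreshold: 0, ML Kit no longer
# filters any label, since every shifted logit is now positive.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
with open("efficientnet_lite4_shifted.tflite", "wb") as f:
    f.write(converter.convert())

# On the app side, probabilities can be recovered from the returned
# confidences c_i with a numerically stable softmax:
#   p_i = exp(c_i - max(c)) / sum_j exp(c_j - max(c))
# softmax is shift-invariant, so the +1000.0 offset drops out of the result.
```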
This issue is stale because it has been open for 30 days with no activity.