flutter-tflite
YOLOv5 support
Hi,
I've been trying to enable object detection in Flutter using a custom-trained YOLOv5n model. I have been using the https://pub.dev/packages/flutter_vision package, but it currently supports only Android, and I'm not sure when (if ever) iOS support will be added. In any case, inference was painfully slow and not really usable for my use case. I have been trying to integrate my app with the flutter-tflite library directly, but I think I'm having trouble setting up the output tensors. I have been trying to adapt the 'live_object_detection_ssd_mobilenet' example to work with my model, but it doesn't work. I understand this is because the two models are very different; I just wish there were a YOLOv5 example so I'd know how to do it.
This is the Netron properties info for my model. Can anyone guide me on how to get this working? If there is an example, I could follow that as well.
Also, will the inference time be fast enough?
I went through several threads and pages online before posting this issue so as not to waste anyone's time. I hope I can finally put these struggles behind me.
Thanks!
I'm also attaching the snippet from the example where inference happens for reference.
// Copyright 2023 The Flutter team. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
import 'dart:async';
import 'dart:io';
import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/foundation.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart' as image_lib;
import 'package:gods_eye/models/recognition.dart';
import 'package:gods_eye/utils/image_utils.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
///////////////////////////////////////////////////////////////////////////////
// **WARNING:** This is not production code and is only intended to be used for
// demonstration purposes.
//
// The following Detector example works by spawning a background isolate and
// communicating with it over Dart's SendPort API. It is presented below as a
// demonstration of the feature "Background Isolate Channels" and shows using
// plugins from a background isolate. The [Detector] operates on the root
// isolate and the [_DetectorServer] operates on a background isolate.
//
// Here is an example of the protocol they use to communicate:
//
// _________________ ________________________
// [:Detector] [:_DetectorServer]
// ----------------- ------------------------
// | |
// |<---------------(init)------------------------|
// |----------------(init)----------------------->|
// |<---------------(ready)---------------------->|
// | |
// |----------------(detect)--------------------->|
// |<---------------(busy)------------------------|
// |<---------------(result)----------------------|
// | . . . |
// |----------------(detect)--------------------->|
// |<---------------(busy)------------------------|
// |<---------------(result)----------------------|
//
///////////////////////////////////////////////////////////////////////////////
/// All the command codes that can be sent and received between [Detector] and
/// [_DetectorServer].
enum _Codes {
init,
busy,
ready,
detect,
result,
}
/// A command sent between [Detector] and [_DetectorServer].
class _Command {
const _Command(this.code, {this.args});
final _Codes code;
final List<Object>? args;
}
/// A simple Detector that handles object detection via a background service.
///
/// All the heavy operations like pre-processing, detection, etc.
/// are executed in a background isolate.
/// This class just sends and receives messages to the isolate.
class Detector {
// static const String _modelPath = 'assets/models/ssd_mobilenet.tflite';
// static const String _labelPath = 'assets/models/labelmap.txt';
static const String _modelPath = 'assets/models/yolov5n.tflite';
static const String _labelPath = 'assets/models/yolov5n_labels.txt';
Detector._(this._isolate, this._interpreter, this._labels);
final Isolate _isolate;
late final Interpreter _interpreter;
late final List<String> _labels;
// To be used by detector (from UI) to send message to our Service ReceivePort
late final SendPort _sendPort;
bool _isReady = false;
// Stream of detection results (recognitions and timing stats) sent back
// from the background isolate.
final StreamController<Map<String, dynamic>> resultsStream =
StreamController<Map<String, dynamic>>();
/// Loads the model and labels and launches the server on a background isolate.
static Future<Detector> start() async {
final ReceivePort receivePort = ReceivePort();
// sendPort - used by the service isolate to send messages to our ReceivePort
final Isolate isolate =
await Isolate.spawn(_DetectorServer._run, receivePort.sendPort);
final Detector result = Detector._(
isolate,
await _loadModel(),
await _loadLabels(),
);
receivePort.listen((message) {
result._handleCommand(message as _Command);
});
return result;
}
static Future<Interpreter> _loadModel() async {
final interpreterOptions = InterpreterOptions();
// Use XNNPACK Delegate
if (Platform.isAndroid) {
interpreterOptions.addDelegate(XNNPackDelegate());
}
return Interpreter.fromAsset(
_modelPath,
options: interpreterOptions..threads = 4,
);
}
static Future<List<String>> _loadLabels() async {
return (await rootBundle.loadString(_labelPath)).split('\n');
}
/// Starts CameraImage processing
void processFrame(CameraImage cameraImage) {
if (_isReady) {
_sendPort.send(_Command(_Codes.detect, args: [cameraImage]));
}
}
/// Handler invoked when a message is received from the port communicating
/// with the detector server.
void _handleCommand(_Command command) {
switch (command.code) {
case _Codes.init:
_sendPort = command.args?[0] as SendPort;
// ----------------------------------------------------------------------
// Before using platform channels and plugins from background isolates we
// need to register it with its root isolate. This is achieved by
// acquiring a [RootIsolateToken] which the background isolate uses to
// invoke [BackgroundIsolateBinaryMessenger.ensureInitialized].
// ----------------------------------------------------------------------
RootIsolateToken rootIsolateToken = RootIsolateToken.instance!;
_sendPort.send(_Command(_Codes.init, args: [
rootIsolateToken,
_interpreter.address,
_labels,
]));
case _Codes.ready:
_isReady = true;
case _Codes.busy:
_isReady = false;
case _Codes.result:
_isReady = true;
resultsStream.add(command.args?[0] as Map<String, dynamic>);
default:
debugPrint('Detector unrecognized command: ${command.code}');
}
}
/// Kills the background isolate and its detector server.
void stop() {
_isolate.kill();
}
}
/// The portion of the [Detector] that runs on the background isolate.
///
/// This is where we use the new feature Background Isolate Channels, which
/// allows us to use plugins from background isolates.
class _DetectorServer {
/// Input size of image (height = width = 640 for this yolov5n model)
static const int mlModelInputSize = 640;
/// Result confidence threshold
static const double confidence = 0.5;
Interpreter? _interpreter;
List<String>? _labels;
_DetectorServer(this._sendPort);
final SendPort _sendPort;
// ----------------------------------------------------------------------
// Here the plugin is used from the background isolate.
// ----------------------------------------------------------------------
/// The main entrypoint for the background isolate sent to [Isolate.spawn].
static void _run(SendPort sendPort) {
ReceivePort receivePort = ReceivePort();
final _DetectorServer server = _DetectorServer(sendPort);
receivePort.listen((message) async {
final _Command command = message as _Command;
await server._handleCommand(command);
});
// receivePort.sendPort - used by the UI isolate to send commands to the service ReceivePort
sendPort.send(_Command(_Codes.init, args: [receivePort.sendPort]));
}
/// Handle the [command] received from the [ReceivePort].
Future<void> _handleCommand(_Command command) async {
switch (command.code) {
case _Codes.init:
// ----------------------------------------------------------------------
// The [RootIsolateToken] is required for
// [BackgroundIsolateBinaryMessenger.ensureInitialized] and must be
// obtained on the root isolate and passed into the background isolate via
// a [SendPort].
// ----------------------------------------------------------------------
RootIsolateToken rootIsolateToken =
command.args?[0] as RootIsolateToken;
// ----------------------------------------------------------------------
// [BackgroundIsolateBinaryMessenger.ensureInitialized] for each
// background isolate that will use plugins. This sets up the
// [BinaryMessenger] that the Platform Channels will communicate with on
// the background isolate.
// ----------------------------------------------------------------------
BackgroundIsolateBinaryMessenger.ensureInitialized(rootIsolateToken);
_interpreter = Interpreter.fromAddress(command.args?[1] as int);
_labels = command.args?[2] as List<String>;
_sendPort.send(const _Command(_Codes.ready));
case _Codes.detect:
_sendPort.send(const _Command(_Codes.busy));
_convertCameraImage(command.args?[0] as CameraImage);
default:
debugPrint('_DetectorServer unrecognized command ${command.code}');
}
}
void _convertCameraImage(CameraImage cameraImage) {
var preConversionTime = DateTime.now().millisecondsSinceEpoch;
convertCameraImageToImage(cameraImage).then((image) {
if (image != null) {
if (Platform.isAndroid) {
image = image_lib.copyRotate(image, angle: 90);
}
final results = analyseImage(image, preConversionTime);
_sendPort.send(_Command(_Codes.result, args: [results]));
}
});
}
Map<String, dynamic> analyseImage(
image_lib.Image? image, int preConversionTime) {
var conversionElapsedTime =
DateTime.now().millisecondsSinceEpoch - preConversionTime;
var preProcessStart = DateTime.now().millisecondsSinceEpoch;
// Pre-process the image: resize it to the model input size
// [mlModelInputSize, mlModelInputSize]
final imageInput = image_lib.copyResize(
image!,
width: mlModelInputSize,
height: mlModelInputSize,
);
// Creating matrix representation, [mlModelInputSize, mlModelInputSize, 3]
final imageMatrix = List.generate(
imageInput.height,
(y) => List.generate(
imageInput.width,
(x) {
final pixel = imageInput.getPixel(x, y);
return [pixel.r, pixel.g, pixel.b];
},
),
);
var preProcessElapsedTime =
DateTime.now().millisecondsSinceEpoch - preProcessStart;
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
final output = _runInference(imageMatrix);
// Location
final locationsRaw = output.first.first as List<List<double>>;
final List<Rect> locations = locationsRaw
.map((list) => list.map((value) => (value * mlModelInputSize)).toList())
.map((rect) => Rect.fromLTRB(rect[1], rect[0], rect[3], rect[2]))
.toList();
// Classes
final classesRaw = output.elementAt(1).first as List<double>;
final classes = classesRaw.map((value) => value.toInt()).toList();
// Scores
final scores = output.elementAt(2).first as List<double>;
// Number of detections
final numberOfDetectionsRaw = output.last.first as double;
final numberOfDetections = numberOfDetectionsRaw.toInt();
final List<String> classification = [];
for (var i = 0; i < numberOfDetections; i++) {
classification.add(_labels![classes[i]]);
}
/// Generate recognitions
List<Recognition> recognitions = [];
for (int i = 0; i < numberOfDetections; i++) {
// Prediction score
var score = scores[i];
// Label string
var label = classification[i];
if (score > confidence) {
recognitions.add(
Recognition(i, label, score, locations[i]),
);
}
}
var inferenceElapsedTime =
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
var totalElapsedTime =
DateTime.now().millisecondsSinceEpoch - preConversionTime;
return {
"recognitions": recognitions,
"stats": <String, String>{
'Conversion time:': conversionElapsedTime.toString(),
'Pre-processing time:': preProcessElapsedTime.toString(),
'Inference time:': inferenceElapsedTime.toString(),
'Total prediction time:': totalElapsedTime.toString(),
'Frame': '${image.width} X ${image.height}',
},
};
}
/// Object detection main function
List<List<Object>> _runInference(
List<List<List<num>>> imageMatrix,
) {
// Set input tensor [1, 640, 640, 3]
final input = [imageMatrix];
// Set output tensor
// Locations: [1, 10, 4]
// Classes: [1, 10],
// Scores: [1, 10],
// Number of detections: [1]
final output = {
0: [List<List<num>>.filled(10, List<num>.filled(4, 0))],
1: [List<num>.filled(10, 0)],
2: [List<num>.filled(10, 0)],
3: [0.0],
};
_interpreter!.runForMultipleInputs([input], output);
return output.values.toList();
}
}
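A YOLOv5 model does not produce the four SSD-style outputs (locations, classes, scores, count) that the snippet above allocates. Assuming a standard Ultralytics TFLite export without built-in NMS, there is typically a single float output of shape [1, N, 5 + numClasses], where N = 25200 for a 640x640 input (3 anchors x (80*80 + 40*40 + 20*20) grid cells) and each row is [cx, cy, w, h, objectness, class scores...] with coordinates normalized to [0, 1]. Check this against the Netron output info of your own model before relying on it. The sketch below decodes such rows; `Detection` and `decodeYoloV5` are hypothetical names, not part of tflite_flutter or the Flutter example.

```dart
/// Hypothetical result type for this sketch (not the example's Recognition).
class Detection {
  final int classIndex;
  final double score;
  // Box in pixels of the model input: left, top, right, bottom.
  final double left, top, right, bottom;
  const Detection(
      this.classIndex, this.score, this.left, this.top, this.right, this.bottom);
}

/// Decodes raw YOLOv5 rows of the form
/// [cx, cy, w, h, objectness, classScore0, classScore1, ...].
///
/// Assumes coordinates are normalized to [0, 1] and scales them by
/// [inputSize]. Overlapping boxes are NOT removed here; NMS is still needed.
List<Detection> decodeYoloV5(
  List<List<double>> rows, {
  required int inputSize,
  double scoreThreshold = 0.5,
}) {
  final detections = <Detection>[];
  for (final row in rows) {
    final objectness = row[4];
    // Class scores are <= 1, so a low objectness can never pass the threshold.
    if (objectness < scoreThreshold) continue;

    // Pick the class with the highest class score.
    var bestClass = 0;
    var bestClassScore = row[5];
    for (var c = 1; c < row.length - 5; c++) {
      if (row[5 + c] > bestClassScore) {
        bestClassScore = row[5 + c];
        bestClass = c;
      }
    }

    final score = objectness * bestClassScore;
    if (score < scoreThreshold) continue;

    // Convert normalized center/size to pixel corner coordinates.
    final cx = row[0] * inputSize;
    final cy = row[1] * inputSize;
    final w = row[2] * inputSize;
    final h = row[3] * inputSize;
    detections.add(Detection(
        bestClass, score, cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2));
  }
  return detections;
}
```

With tflite_flutter, the output buffer for such a model could be allocated as `[List.generate(25200, (_) => List<double>.filled(85, 0.0))]` (85 = 5 + 80 COCO classes; adjust for a custom model) and filled via `Interpreter.run`, after which `output[0]` is the `rows` argument. The resulting boxes could then be mapped onto the example's `Recognition` objects.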
Following up on this: I got it working by modifying the output tensors, but inference is extremely slow.
I then tried exporting my YOLOv5 model to TFLite with non-maximum suppression (NMS) included. However, when the interpreter starts up, I get this error:
Select TensorFlow op(s), included in the given model, is(are) not supported by this interpreter.
I'm not sure what to update, how to update it, or where. The documentation has quite a few steps, and they confuse me. Can anyone help?
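Instead of baking NMS into the exported graph (which is likely what pulls in the unsupported Select TensorFlow ops here), one option is to export without NMS and run it in Dart after decoding the raw output. Below is a minimal sketch of greedy, per-class NMS. `ScoredBox`, `iou` and `nonMaxSuppression` are hypothetical names, and the 0.45 IoU threshold is an assumed default borrowed from YOLOv5's usual detection setting.

```dart
import 'dart:math' as math;

/// Hypothetical box type for this sketch: pixel corners, score and class.
class ScoredBox {
  final double left, top, right, bottom, score;
  final int classIndex;
  const ScoredBox(
      this.left, this.top, this.right, this.bottom, this.score, this.classIndex);

  double get area =>
      math.max(0.0, right - left) * math.max(0.0, bottom - top);
}

/// Intersection-over-union of two axis-aligned boxes.
double iou(ScoredBox a, ScoredBox b) {
  final interW =
      math.max(0.0, math.min(a.right, b.right) - math.max(a.left, b.left));
  final interH =
      math.max(0.0, math.min(a.bottom, b.bottom) - math.max(a.top, b.top));
  final inter = interW * interH;
  final union = a.area + b.area - inter;
  return union <= 0 ? 0.0 : inter / union;
}

/// Greedy per-class NMS: visit boxes from highest to lowest score and keep a
/// box only if it does not overlap an already-kept box of the same class by
/// more than [iouThreshold].
List<ScoredBox> nonMaxSuppression(List<ScoredBox> boxes,
    {double iouThreshold = 0.45}) {
  final sorted = [...boxes]..sort((a, b) => b.score.compareTo(a.score));
  final kept = <ScoredBox>[];
  for (final candidate in sorted) {
    final suppressed = kept.any((k) =>
        k.classIndex == candidate.classIndex &&
        iou(k, candidate) > iouThreshold);
    if (!suppressed) kept.add(candidate);
  }
  return kept;
}
```

This greedy loop is quadratic in the number of boxes, so filtering by score first (before NMS) keeps it cheap on the ~25k raw YOLOv5 rows.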
For faster tensor operations on images you can use this library. It's discontinued but still works with images 3.0.0 (just fork it to your personal GitHub). I did a speed comparison here https://github.com/tensorflow/flutter-tflite/issues/171 which shows an almost 8x speedup on CPU (it could be higher on GPU).
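Independently of that library, a common cause of slow pre-processing is building nested `List<List<List<num>>>` structures pixel by pixel, as `analyseImage` does in the snippet above. A sketch of the alternative is to write pixels straight into a flat typed buffer. It assumes a float32 model that expects RGB values scaled to [0, 1] (typical for YOLOv5 exports, unlike the uint8 SSD example) and interleaved RGB input bytes, e.g. from the `image` package's `getBytes` (check the channel-order options for your version). `rgbBytesToFloat32` is a hypothetical helper.

```dart
import 'dart:typed_data';

/// Converts interleaved RGB bytes (length = width * height * 3) into a flat
/// Float32List in NHWC order, scaling values from [0, 255] to [0, 1].
Float32List rgbBytesToFloat32(Uint8List rgb, int width, int height) {
  final expected = width * height * 3;
  if (rgb.length != expected) {
    throw ArgumentError('expected $expected bytes, got ${rgb.length}');
  }
  // One contiguous allocation instead of width * height small lists.
  final out = Float32List(expected);
  for (var i = 0; i < expected; i++) {
    out[i] = rgb[i] / 255.0;
  }
  return out;
}
```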
Regarding your second question: this happens when you go from PyTorch to ONNX to TFLite. I'm not sure why, but some operations become unsupported during conversion.
So please find a TensorFlow model, convert it to TFLite and use it; you shouldn't face this issue.
Thanks for the reply, @saurabhkumar8112. Since MediaPipe might take a while, I think this might be my best shot. We'll see how that goes! Any idea how to make these inferences use the GPU? Does it use the GPU by default?
@mp051998 Thinking about what you mentioned, I noticed the same issue was raised in #167. To be honest, try another model that can be converted to TFLite. If you still want to use the same model, would you consider containerizing the model and deploying it to the app some other way instead of using TFLite and flutter-tflite?
I had the same issue. How did you solve this problem?
Could you share your working code?
I ditched YOLO altogether. After months of trying to make YOLO work (and learning how much of a pain their licensing is, even for v5), I decided to switch to SSD. SSD works much better and is a lot faster. Accuracy might take a hit, but training a better model and employing multiple levels of inference can help.
Also, just to add more context: YOLO inference sometimes took 5+ seconds, while SSD averaged 100-200 ms.
I agree. That was my conclusion too, but running the model on the phone was a priority, so I moved to SSD!