object_detection_flutter icon indicating copy to clipboard operation
object_detection_flutter copied to clipboard

Object detection Example with float32 model

Open funwithflutter opened this issue 3 years ago • 7 comments

Sorry don't know what to label this issue as. I think it's more likely an error my side than something wrong with the package. Any help will be appreciated (I'm new to TensorFlow in general). Also thanks for the amazing package!

I'm trying to use this model: https://tfhub.dev/intel/lite-model/midas/v2_1_small/1/lite/1 It computes depth from an image.

And as far as I can see I'm doing all the necessary steps. I copied the code from your image classification example, and also double checked with the Android example provided in the link above (and as far as I can see I'm doing the same steps).

I'm using the tflite flutter helper package.

I'm getting a failed precondition in Quiver at the following point (when I call interpreter.run):

checkState(tfLiteTensorCopyFromBuffer(_tensor, ptr.cast(), bytes.length) ==
        TfLiteStatus.ok);

Stacktrace:

flutter: #0      checkState
package:quiver/check.dart:73
am15h/tflite_flutter_plugin#1      Tensor.setTo
package:tflite_flutter/src/tensor.dart:150
am15h/tflite_flutter_plugin#2      Interpreter.runForMultipleInputs
package:tflite_flutter/src/interpreter.dart:194
am15h/tflite_flutter_plugin#3      Interpreter.run
package:tflite_flutter/src/interpreter.dart:165
am15h/tflite_flutter_plugin#4      Classifier.predict
package:tensorflow_poc/classifier.dart:113
am15h/tflite_flutter_plugin#5      _MyHomePageState._predict
package:tensorflow_poc/main.dart:69
am15h/tflite_flutter_plugin#6      _MyHomePageState.getImage.<anonymous closure>
package:tensorflow_poc/main.dart:63
am15h/tflite_flutter_plugin#7      State.setState
package:flutter/…/widgets/framework.dart:1267
am15h/tflite_flutter_plugin#8      _MyHomePageState.getImage
package:tensorflow_poc/main.dart:57
<asynchronous suspension>

Something that also has me confused is that interpreter.getInputTensor(0).type returns TfLiteType.float32, but I expected this to be uint8 from the model description.

Below is my classifier class (I'm using this classifier in the Image Classification example from this package):

import 'dart:math';

import 'package:image/image.dart';
import 'package:collection/collection.dart';
import 'package:logger/logger.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

abstract class Classifier {
  Interpreter interpreter;
  InterpreterOptions _interpreterOptions;

  var logger = Logger();

  List<int> _inputShape;
  List<int> _outputShape;

  TensorImage _inputImage;
  TensorBuffer _outputBuffer;

  TfLiteType _outputType;

  String get modelName;

  NormalizeOp get preProcessNormalizeOp;

  Classifier({int numThreads}) {
    _interpreterOptions = InterpreterOptions();

    if (numThreads != null) {
      _interpreterOptions.threads = numThreads;
    }

    loadModel();
  }

  Future<void> loadModel() async {
    try {
      interpreter =
          await Interpreter.fromAsset(modelName, options: _interpreterOptions);
      print('Interpreter Created Successfully');
      _inputShape = interpreter.getInputTensor(0).shape; // {1, 256, 256, 3}
      _outputShape = interpreter.getOutputTensor(0).shape; // {1, 256, 256}
      _outputType = interpreter.getOutputTensor(0).type; // TfLiteType.float32
      print('_inputShape[0]: ${_inputShape[0]}');
      print('_inputShape[1]: ${_inputShape[1]}');
      print('_inputShape[2]: ${_inputShape[2]}');
      print('_inputShape[3]: ${_inputShape[3]}');
      print('_outputShape[0]: ${_outputShape[0]}');
      print('_outputShape[1]: ${_outputShape[1]}');
      print('_outputShape[2]: ${_outputShape[2]}');
      print('_outputType: $_outputType');
      print(
          '_intputType: ${interpreter.getInputTensor(0).type}'); // TfLiteType.float32, but expected this to be uint8
      _outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
      _probabilityProcessor =
          TensorProcessorBuilder().add(postProcessNormalizeOp).build();
    } catch (e) {
      print('Unable to create interpreter, Caught Exception: ${e.toString()}');
    }
  }

  Future<void> loadLabels() async {
    labels = await FileUtil.loadLabels(_labelsFileName);
    if (labels.length == _labelsLength) {
      print('Labels loaded successfully');
    } else {
      print('Unable to load labels');
    }
  }

  TensorImage _preProcess() {
    int cropSize = min(_inputImage.height, _inputImage.width);
    return ImageProcessorBuilder()
        .add(ResizeWithCropOrPadOp(cropSize, cropSize))
        .add(ResizeOp(
            _inputShape[1], _inputShape[2], ResizeMethod.NEAREST_NEIGHBOUR))
        .add(preProcessNormalizeOp)
        .build()
        .process(_inputImage);
  }

  void predict(Image image) {
    try {
      if (interpreter == null) {
        throw StateError('Cannot run inference, Intrepreter is null');
      }
      final pres = DateTime.now().millisecondsSinceEpoch;
      _inputImage = TensorImage.fromImage(image);
      print('input image data type: ${_inputImage.dataType}');
      _inputImage = _preProcess();
      print('input image width: ${_inputImage.width}');
      print('input image height: ${_inputImage.height}');
      print('input image data type: ${_inputImage.dataType}');
      final pre = DateTime.now().millisecondsSinceEpoch - pres;
      print('Time to load image: $pre ms');
      print('input buffer: ${_inputImage.buffer}');
      print('output buffer: ${_outputBuffer.getBuffer()}');
      final runs = DateTime.now().millisecondsSinceEpoch;

      interpreter.run(_inputImage.buffer, _outputBuffer.buffer); // THROWS
      final run = DateTime.now().millisecondsSinceEpoch - runs;

      print('Time to run inference: $run ms');

      print(_outputBuffer.getDoubleList());
    } catch (e, st) {
      logger.e('error', e, st);
      print(st);
    }
  }

  void close() {
    if (interpreter != null) {
      interpreter.close();
    }
  }
}

And implementation class:

import 'package:tensorflow_poc/classifier.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

class ClassifierQuant extends Classifier {
  ClassifierQuant({int numThreads: 1}) : super(numThreads: numThreads);

  @override
  String get modelName => 'lite-model_midas_v2_1_small_1_lite_1.tflite';

  @override
  NormalizeOp get preProcessNormalizeOp => NormalizeOp(0, 1);
}

funwithflutter avatar Mar 18 '21 23:03 funwithflutter