mediapipe
mediapipe copied to clipboard
Mobile SSD models are expected to have exactly 4 outputs, found 2
Have I written custom code (as opposed to using a stock example script provided in MediaPipe)
None
OS Platform and Distribution
Ubuntu 22 in WSL2, Android 12
Python Version
3.10
MediaPipe Model Maker version
2.0.4.1
Task name (e.g. Image classification, Gesture recognition etc.)
object detector
Describe the actual behavior
I used mediapipe_model_maker 2.0.4.1 to train a model and use it in an Android program, but it cannot run and throws an exception.
Describe the expected behaviour
The Android program throws: java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2
Standalone code/steps you may have used to try to get what you need
1. I trained a TFLite model with mediapipe_model_maker.
the code is:
# Trains a MediaPipe Model Maker object detector on a Pascal VOC dataset and
# exports it as a .tflite file for use on Android.
import os
# Suppress TensorFlow INFO/WARNING log spam; must be set before TF is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from mediapipe_model_maker import object_detector
# Dataset folders in Pascal VOC layout (images + XML annotation files).
train_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
# NOTE(review): the validation path is identical to the training path, so the
# "validation" metrics below are computed on the training data — use a held-out
# split to get a meaningful evaluation.
validation_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
# Preprocessed-dataset cache; reused across runs to skip re-parsing the VOC files.
cache_dir = '/mnt/d/workspace/imgupload/img/selected1/tmp'
train_data = object_detector.Dataset.from_pascal_voc_folder(
train_dataset_path,
cache_dir=cache_dir)
validate_data = object_detector.Dataset.from_pascal_voc_folder(
validation_dataset_path,
cache_dir=cache_dir)
# NOTE(review): learning_rate=0.3 is unusually high for fine-tuning — confirm it
# against the Model Maker object-detector defaults before blaming the runtime.
hparams = object_detector.HParams(batch_size=8, learning_rate=0.3, epochs=50, export_dir='exported_model')
options = object_detector.ObjectDetectorOptions(
supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams)
# Trains the model; blocks until all epochs complete.
model = object_detector.ObjectDetector.create(
train_data=train_data,
validation_data=validate_data,
options=options)
loss, coco_metrics = model.evaluate(validate_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"Validation coco metrics: {coco_metrics}")
# NOTE(review): Model Maker 2.x exports a MediaPipe-Tasks-format detector (2
# output tensors). It is meant to be loaded with the MediaPipe Tasks
# ObjectDetector API on Android, NOT with the legacy TFLite Task Library /
# TFLiteObjectDetectionAPIModel path, which requires the 4-output mobile-SSD
# format — that mismatch produces the "expected exactly 4 outputs, found 2"
# exception reported in this issue.
model.export_model('dogs.tflite')
2. I used the model in an Android program; the code is from:
https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android
The Main code is :
package org.tensorflow.lite.examples.detection;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;
import com.example.namespace.R;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.lite.examples.detection.customview.OverlayView;
import org.tensorflow.lite.examples.detection.customview.OverlayView.DrawCallback;
import org.tensorflow.lite.examples.detection.env.BorderedText;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.env.Logger;
import org.tensorflow.lite.examples.detection.tflite.Detector;
import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.examples.detection.tracking.MultiBoxTracker;
/**
 * Camera activity that runs a TFLite object-detection model on each preview frame and feeds the
 * resulting boxes to a {@link MultiBoxTracker} for on-screen tracking.
 *
 * <p>NOTE(review): the detector is created through {@code TFLiteObjectDetectionAPIModel}, which
 * wraps the TFLite Task Library ObjectDetector and requires the legacy 4-output mobile-SSD
 * .tflite format. A model exported by mediapipe_model_maker 2.x (such as "dogs.tflite" below)
 * carries only 2 output tensors, which matches the reported
 * "Mobile SSD models are expected to have exactly 4 outputs, found 2" exception — the model
 * format, not this activity's logic, is the likely root cause. Such models should be loaded with
 * the MediaPipe Tasks ObjectDetector API instead; confirm before tuning anything here.
 */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
private static final Logger LOGGER = new Logger();
// Configuration values for the prepackaged SSD model.
// NOTE(review): 300x300 matches the stock detect.tflite; verify it matches the input size of the
// custom Model Maker export before use.
private static final int TF_OD_API_INPUT_SIZE = 300;
// NOTE(review): Model Maker typically exports a float model — a quantized=true flag may not match
// the custom .tflite; confirm against the exported model's metadata.
private static final boolean TF_OD_API_IS_QUANTIZED = true;
//private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
// Asset name of the model to load (swapped from the stock SSD to the custom export).
private static final String TF_OD_API_MODEL_FILE = "dogs.tflite";
// Asset with one class label per line, indexed by the model's class output.
private static final String TF_OD_API_LABELS_FILE = "labelmap.txt";
private static final DetectorMode MODE = DetectorMode.TF_OD_API;
// Minimum detection confidence to track a detection.
private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.5f;
// false: the frame is stretched (not letterboxed) into the square model input.
private static final boolean MAINTAIN_ASPECT = false;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
// Debug switch: dump each cropped model input to storage when true.
private static final boolean SAVE_PREVIEW_BITMAP = false;
private static final float TEXT_SIZE_DIP = 10;
OverlayView trackingOverlay;
private Integer sensorOrientation;
private Detector detector;
private long lastProcessingTimeMs;
// Full-resolution preview frame (ARGB), rebuilt every frame.
private Bitmap rgbFrameBitmap = null;
// Square crop of the frame at the model's input size.
private Bitmap croppedBitmap = null;
// Copy of the crop used only for the debug overlay drawing.
private Bitmap cropCopyBitmap = null;
// Guard so only one frame is processed at a time (set/cleared across threads).
private boolean computingDetection = false;
private long timestamp = 0;
// Forward and inverse transforms between preview-frame and model-input coordinates.
private Matrix frameToCropTransform;
private Matrix cropToFrameTransform;
private MultiBoxTracker tracker;
private BorderedText borderedText;
/**
 * Called once the camera preview size is known. Creates the detector from assets, allocates the
 * frame/crop bitmaps, computes the frame<->crop transforms, and wires the tracking overlay.
 * Finishes the activity (after a toast) if the detector fails to initialize — this is where the
 * reported IllegalArgumentException surfaces, via the IOException-wrapped create() call.
 */
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(this);
int cropSize = TF_OD_API_INPUT_SIZE;
try {
detector =
TFLiteObjectDetectionAPIModel.create(
this,
TF_OD_API_MODEL_FILE,
TF_OD_API_LABELS_FILE,
TF_OD_API_INPUT_SIZE,
TF_OD_API_IS_QUANTIZED);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
e.printStackTrace();
LOGGER.e(e, "Exception initializing Detector!");
Toast toast =
Toast.makeText(
getApplicationContext(), "Detector could not be initialized", Toast.LENGTH_SHORT);
toast.show();
finish();
}
previewWidth = size.getWidth();
previewHeight = size.getHeight();
// Rotation needed to map the sensor image onto the current screen orientation.
sensorOrientation = rotation - getScreenOrientation();
LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
// Invert once here so detections can be mapped back to frame coordinates per result.
frameToCropTransform.invert(cropToFrameTransform);
trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
trackingOverlay.addCallback(
new DrawCallback() {
@Override
public void drawCallback(final Canvas canvas) {
tracker.draw(canvas);
if (isDebug()) {
tracker.drawDebug(canvas);
}
}
});
tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
}
/**
 * Per-frame entry point: crops the latest preview frame to the model input, runs detection on
 * the background thread, filters results by confidence, maps boxes back to frame coordinates,
 * and hands them to the tracker. Frames arriving while a detection is in flight are dropped.
 */
@Override
protected void processImage() {
++timestamp;
final long currTimestamp = timestamp;
trackingOverlay.postInvalidate();
// No mutex needed as this method is not reentrant.
if (computingDetection) {
readyForNextImage();
return;
}
computingDetection = true;
LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");
rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);
// Release the camera buffer immediately; we work on our own bitmap copies from here.
readyForNextImage();
final Canvas canvas = new Canvas(croppedBitmap);
canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
// For examining the actual TF input.
if (SAVE_PREVIEW_BITMAP) {
ImageUtils.saveBitmap(croppedBitmap);
}
runInBackground(
new Runnable() {
@Override
public void run() {
LOGGER.i("Running detection on image " + currTimestamp);
final long startTime = SystemClock.uptimeMillis();
final List<Detector.Recognition> results = detector.recognizeImage(croppedBitmap);
lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
final Canvas canvas = new Canvas(cropCopyBitmap);
final Paint paint = new Paint();
paint.setColor(Color.RED);
paint.setStyle(Style.STROKE);
paint.setStrokeWidth(2.0f);
float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
switch (MODE) {
case TF_OD_API:
minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
break;
}
final List<Detector.Recognition> mappedRecognitions =
new ArrayList<Detector.Recognition>();
for (final Detector.Recognition result : results) {
final RectF location = result.getLocation();
if (location != null && result.getConfidence() >= minimumConfidence) {
// Draw in crop coordinates first, then map the box back into frame coordinates.
canvas.drawRect(location, paint);
cropToFrameTransform.mapRect(location);
result.setLocation(location);
mappedRecognitions.add(result);
}
}
tracker.trackResults(mappedRecognitions, currTimestamp);
trackingOverlay.postInvalidate();
// Re-open the gate so the next preview frame can be processed.
computingDetection = false;
runOnUiThread(
new Runnable() {
@Override
public void run() {
showFrameInfo(previewWidth + "x" + previewHeight);
showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
showInference(lastProcessingTimeMs + "ms");
}
});
}
});
}
@Override
protected int getLayoutId() {
return R.layout.tfe_od_camera_connection_fragment_tracking;
}
@Override
protected Size getDesiredPreviewFrameSize() {
return DESIRED_PREVIEW_SIZE;
}
// Which detection model to use: by default uses Tensorflow Object Detection API frozen
// checkpoints.
private enum DetectorMode {
TF_OD_API;
}
/** Toggles NNAPI delegation on the detector; surfaces unsupported-operation errors as a toast. */
@Override
protected void setUseNNAPI(final boolean isChecked) {
runInBackground(
() -> {
try {
detector.setUseNNAPI(isChecked);
} catch (UnsupportedOperationException e) {
LOGGER.e(e, "Failed to set \"Use NNAPI\".");
runOnUiThread(
() -> {
Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show();
});
}
});
}
/** Updates the interpreter thread count on the background thread. */
@Override
protected void setNumThreads(final int numThreads) {
runInBackground(() -> detector.setNumThreads(numThreads));
}
}
Other info / Complete Logs
The error log is:
Error getting native address of native library: task_vision_jni
java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2
at org.tensorflow.lite.task.vision.detector.ObjectDetector.initJniWithByteBuffer(Native Method)
at org.tensorflow.lite.task.vision.detector.ObjectDetector.access$100(ObjectDetector.java:88)
at org.tensorflow.lite.task.vision.detector.ObjectDetector$3.createHandle(ObjectDetector.java:223)
at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
at org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromBufferAndOptions(ObjectDetector.java:219)
at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.<init>(TFLiteObjectDetectionAPIModel.java:87)
at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:81)
at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:103)
at org.tensorflow.lite.examples.detection.CameraActivity$7.onPreviewSizeChosen(CameraActivity.java:448)
at org.tensorflow.lite.examples.detection.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:360)
at org.tensorflow.lite.examples.detection.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:365)
at org.tensorflow.lite.examples.detection.CameraConnectionFragment.-$$Nest$mopenCamera(Unknown Source:0)
at org.tensorflow.lite.examples.detection.CameraConnectionFragment$3.onSurfaceTextureAvailable(CameraConnectionFragment.java:174)
at android.view.TextureView.getTextureLayer(TextureView.java:410)
at android.view.TextureView.draw(TextureView.java:353)
at android.view.View.updateDisplayListIfDirty(View.java:21885)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.draw(View.java:23021)
at android.view.View.updateDisplayListIfDirty(View.java:21885)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at androidx.coordinatorlayout.widget.CoordinatorLayout.drawChild(CoordinatorLayout.java:1246)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.draw(View.java:23021)
at android.view.View.updateDisplayListIfDirty(View.java:21885)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
at android.view.View.updateDisplayListIfDirty(View.java:21876)
at android.view.View.draw(View.java:22743)
at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
2024-07-31 00:44:09.336 31317-31317 TaskJniUtils org...lite.examples.objectdetection E at android.view.View.draw(View.java:23021)
at com.android.internal.policy.DecorView.draw(DecorView.java:891)
at android.view.View.updateDisplayListIfDirty(View.java:21885)
at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:534)
at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:542)
at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:625)
at android.view.ViewRootImpl.draw(ViewRootImpl.java:4657)
at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:4375)
at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:3486)
at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:2277)
at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:9037)
at android.view.Choreographer$CallbackRecord.run(Choreographer.java:1142)
at android.view.Choreographer.doCallbacks(Choreographer.java:946)
at android.view.Choreographer.doFrame(Choreographer.java:875)
at android.view.Choreographer$FrameDisplayEventReceiver.run(Choreographer.java:1127)
at android.os.Handler.handleCallback(Handler.java:938)
at android.os.Handler.dispatchMessage(Handler.java:99)
at android.os.Looper.loopOnce(Looper.java:210)
at android.os.Looper.loop(Looper.java:299)
at android.app.ActivityThread.main(ActivityThread.java:8293)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:556)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1045)