tflite-support
Add C++ Tests for Image Segmenter Category Mask
The current C++ category mask test mocks an empty frame buffer to test the Postprocess function. There is no test that calls the Segment function with a frame buffer constructed from real image data and checks the resulting category mask. Could this be added in C++? It would also let the iOS and Python API tests compare their category mask results against the values from this C++ test.
@lu-wang-g @mbrenon @khanhlvg
My understanding is that we have a test for the category mask but not for the confidence mask. Your question is how to add a test for the confidence mask; is that correct?
Here is how we write the confidence mask test internally in the Java layer. Can you do the same for the iOS tests?
```java
// Excerpt from the internal Java test class. Constants (INPUT_IMAGE, MODEL_FILE,
// OUTPUT_IMAGE_WIDTH, OUTPUT_IMAGE_HEIGHT) and helpers (getBitmapFromAssets,
// createGoldenColoredLabels) are defined elsewhere in that class.
@Test
public void segment_succeedsWithConfidenceMask() throws Exception {
  TensorImage image =
      TensorImage.fromBitmap(
          getBitmapFromAssets(INPUT_IMAGE, ImageType.JPEG, Config.ARGB_8888));
  ImageSegmenterOptions options =
      ImageSegmenterOptions.builder().setOutputType(OutputType.CONFIDENCE_MASK).build();
  ImageSegmenter imageSegmenter =
      ImageSegmenter.createFromFileAndOptions(
          ApplicationProvider.getApplicationContext(), MODEL_FILE, options);
  List<Segmentation> results = imageSegmenter.segment(image);
  assertSegmentationResultWithConfidenceMasksIsCorrect(
      results, image, ImageProcessingOptions.builder().build());
}

/**
 * Verifies the segmentation results with confidence masks by comparing them with the
 * corresponding category mask. The correctness of the category mask has been verified in other
 * tests, such as segment_succeedsWithCategoryMask. The category mask is obtained by running
 * another {@link ImageSegmenter}, which outputs a category mask, on the same input image with the
 * same image processing options.
 */
private static void assertSegmentationResultWithConfidenceMasksIsCorrect(
    List<Segmentation> results, TensorImage image, ImageProcessingOptions imageOptions)
    throws Exception {
  // There should be only one Segmentation result.
  assertThat(results).hasSize(1);
  Segmentation result = results.get(0);
  // The output type of the Segmentation result should be CONFIDENCE_MASK.
  assertThat(result.getOutputType()).isEqualTo(OutputType.CONFIDENCE_MASK);
  // The Segmentation result should contain 21 masks corresponding to 21 colored labels.
  assertThat(result.getMasks()).hasSize(21);
  // Verify the shape and type of each mask.
  for (TensorImage mask : result.getMasks()) {
    assertThat(mask.getWidth()).isEqualTo(OUTPUT_IMAGE_WIDTH);
    assertThat(mask.getHeight()).isEqualTo(OUTPUT_IMAGE_HEIGHT);
    assertThat(mask.getColorSpaceType()).isEqualTo(ColorSpaceType.GRAYSCALE);
  }
  // Create an ImageSegmenter that outputs a category mask, and verify that the confidence mask
  // results match the category mask result, which has been verified in the test
  // segment_succeedsWithCategoryMask.
  ImageSegmenter imageSegmenterCategory =
      ImageSegmenter.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
  List<Segmentation> resultsCategory = imageSegmenterCategory.segment(image, imageOptions);
  assertConfidenceMasksMatchCategoryMask(
      result.getMasks(), resultsCategory.get(0).getMasks().get(0));
  // Verify the colored labels in the Segmentation results match the golden ones.
  assertThat(result.getColoredLabels()).containsExactlyElementsIn(createGoldenColoredLabels());
}

private static void assertConfidenceMasksMatchCategoryMask(
    List<TensorImage> confidenceMasks, TensorImage categoryMask) {
  // Convert the confidenceMasks in TensorImage into a list of float arrays.
  List<float[]> confidenceArray = new ArrayList<>();
  for (TensorImage mask : confidenceMasks) {
    confidenceArray.add(mask.getTensorBuffer().getFloatArray());
  }
  // Convert the categoryMask in TensorImage into an int array of label indices.
  int[] categoryArray = categoryMask.getTensorBuffer().getIntArray();
  // Verify that for every pixel position, the highest confidence corresponds to the label
  // indicated by the category mask.
  for (int i = 0; i < categoryArray.length; i++) {
    float maxConfidence = confidenceArray.get(categoryArray[i])[i];
    for (float[] element : confidenceArray) {
      assertThat(element[i]).isAtMost(maxConfidence);
    }
  }
}
```
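Since the plan is to port this logic to the C and iOS tests, it may help to see the core pixel check stripped of `TensorImage` and Truth. The following is a minimal sketch, not code from the repository: `ConfidenceMaskCheck` and `firstMismatch` are hypothetical names, and it assumes each mask has already been flattened into a one-dimensional array with the same pixel ordering, as the test above obtains via `getFloatArray()` and `getIntArray()`.

```java
import java.util.List;

/**
 * Hypothetical, dependency-free version of the confidence-vs-category check.
 * confidenceMasks.get(k)[i] is the confidence of label k at pixel i, and
 * categoryMask[i] is the label index the category mask assigns to pixel i.
 */
final class ConfidenceMaskCheck {
  private ConfidenceMaskCheck() {}

  /**
   * Returns the index of the first pixel whose category-mask label does not hold the
   * highest confidence, or -1 if every pixel agrees. Ties count as agreement, matching
   * isAtMost() in the Java test.
   */
  static int firstMismatch(List<float[]> confidenceMasks, int[] categoryMask) {
    // All masks must cover the same number of pixels.
    for (float[] mask : confidenceMasks) {
      if (mask.length != categoryMask.length) {
        throw new IllegalArgumentException("confidence and category mask sizes differ");
      }
    }
    for (int i = 0; i < categoryMask.length; i++) {
      int label = categoryMask[i];
      if (label < 0 || label >= confidenceMasks.size()) {
        throw new IllegalArgumentException("label out of range at pixel " + i);
      }
      // The label chosen by the category mask should have the largest confidence.
      float expectedMax = confidenceMasks.get(label)[i];
      for (float[] mask : confidenceMasks) {
        if (mask[i] > expectedMax) {
          return i;
        }
      }
    }
    return -1;
  }
}
```

Returning the first mismatching pixel index rather than a boolean makes a failing test report where the masks disagree, and the nested loop translates directly to C++ or Swift.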
Sorry, my bad: I meant the confidence masks and linked to the wrong function. Thanks, I will use this logic for the C and iOS tests.