YOLO icon indicating copy to clipboard operation
YOLO copied to clipboard

Incorrect Output Predictions for Image Classification Model

Open janeSmith99221 opened this issue 4 months ago • 1 comments

I am encountering incorrect predictions from my image classification model. The model accepts input of shape [1, 3, 224, 224] and is expected to output class probabilities in the shape [1, 44]. However, the predicted outputs are inconsistent or incorrect when compared to the expected class labels.

Details: Model Type: Image Classification. Input Shape: [1, 3, 224, 224] (single image input). Output Shape: [1, 44] (predicted class probabilities for 44 classes).

here is ImageClassification code

`class ImageClassification( private val context: Context, private val modelPath: String, private val labelPath: String?, private val classificationListener: ClassificationListener, private val message: (String) -> Unit ) {

private var interpreter: Interpreter
private var labels = mutableListOf<String>()

private var tensorWidth = 0
private var tensorHeight = 0
private var numClass = 0

private val imageProcessor = ImageProcessor.Builder()
    .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
    .add(CastOp(INPUT_IMAGE_TYPE))
    .build()

init {

    val options = Interpreter.Options().apply{
       this.setNumThreads(4)
    }

    val model = FileUtil.loadMappedFile(context, modelPath)
    interpreter = Interpreter(model, options)

    labels.addAll(extractNamesFromMetadata(model))
    if (labels.isEmpty()) {
        if (labelPath == null) {
            message("Model not contains metadata, provide LABELS_PATH in Constants.kt")
            labels.addAll(MetaData.TEMP_CLASSES)
        } else {
            labels.addAll(extractNamesFromLabelFile(context, labelPath))
        }
    }

    val inputShape = interpreter.getInputTensor(0)?.shape()
    val outputShape = interpreter.getOutputTensor(0)?.shape()

    if (inputShape != null) {
        tensorWidth = inputShape[1]
        tensorHeight = inputShape[2]

        // If in case input shape is in format of [1, 3, ..., ...]
        if (inputShape[1] == 3) {
            tensorWidth = inputShape[2]
            tensorHeight = inputShape[3]
        }
    }
    if (outputShape != null) {
        numClass = outputShape[1]
    }
    Log.i("TODO", "input shape: ${inputShape.contentToString()}")
    Log.i("TODO", "output shape: ${outputShape.contentToString()}")
    Log.i("TODO", "image width: $tensorWidth")
    Log.i("TODO", "image height: $tensorHeight")
    Log.i("TODO", "model classes: $numClass")
}

fun close() {
    interpreter.close()
}

fun invoke(frame: Bitmap) {
    if (tensorWidth == 0) return
    if (tensorHeight == 0) return

    var inferenceTime = SystemClock.uptimeMillis()

    val resizedBitmap = Bitmap.createScaledBitmap(frame, tensorWidth, tensorHeight, false)

    val tensorImage = TensorImage(INPUT_IMAGE_TYPE)
    tensorImage.load(resizedBitmap)
    val processedImage = imageProcessor.process(tensorImage)
    val imageBuffer = processedImage.buffer
    val output = TensorBuffer.createFixedSize(intArrayOf(1 , numClass) , OUTPUT_IMAGE_TYPE)
    interpreter.run(imageBuffer, output.buffer)

    val outputArray = output.floatArray
    Log.i("TODO", "invoke: model output:  ${outputArray.contentToString()}")
    val predictions = mutableListOf<Prediction>()

    outputArray.forEachIndexed { index, float ->
        if (float > CONFIDENCE_THRESHOLD) {
            predictions.add(
                Prediction(
                    id = index,
                    name = labels[index],
                    score = float
                )
            )
        }
    }

    predictions.sortByDescending { it.score }

    inferenceTime = SystemClock.uptimeMillis() - inferenceTime
    classificationListener.onResult(predictions, inferenceTime)
}

interface ClassificationListener {
    fun onResult(data: List<Prediction>, inferenceTime: Long)
}


companion object {
    private const val INPUT_MEAN = 0f
    private const val INPUT_STANDARD_DEVIATION = 255f
    private val INPUT_IMAGE_TYPE = DataType.FLOAT32
    private val OUTPUT_IMAGE_TYPE = DataType.FLOAT32
    private const val CONFIDENCE_THRESHOLD = 0.01F
}

}`

janeSmith99221 avatar Oct 11 '24 11:10 janeSmith99221