
Support Predict on Classification Models

Open · danielafrimi opened this issue 1 year ago · 1 comment

So far, the changes have been made only for ResNet (via the HasPredict interface).

  1. A basic script to run a classifier with predict functionality.

  2. Created ClassificationPipeline

    • _decode_model_output - calls argmax on the model output; the confidence has to be obtained through softmax activations. In addition, a ClassificationPrediction is created per prediction, with confidence, label and image_shape attributes.
    • _instantiate_image_prediction
    • _combine_image_prediction_to_images
    • _combine_image_prediction_to_video -> we don't need it for the classifier, so what should the implementation be? We won't call predict_video.
  3. Created Prediction objects for classification (implementing the show, save and draw functions).

  4. Created processing steps matching the validation transforms (implemented Resize and CenterCrop), plus default_resnet_imagenet_processing_params(), which creates the image_processor.
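The decode step in item 2 and the CenterCrop step in item 4 can be sketched in plain Python as below. This is a minimal, self-contained illustration under stated assumptions, not the code from this PR: ClassificationPrediction here is a hypothetical dataclass mirroring the attributes listed above (confidence, label, image_shape), images are plain H x W lists, and Resize is omitted.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class ClassificationPrediction:
    # Hypothetical stand-in for the PR's prediction object (one per image).
    confidence: float
    label: int
    image_shape: Tuple[int, int]


def center_crop(image: List[List[float]], size: int) -> List[List[float]]:
    """Crop a size x size window from the middle of a 2D (H x W) image."""
    height, width = len(image), len(image[0])
    top = (height - size) // 2
    left = (width - size) // 2
    return [row[left:left + size] for row in image[top:top + size]]


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_model_output(batch_logits, image_shapes) -> List[ClassificationPrediction]:
    """Turn raw logits into one ClassificationPrediction per image.

    argmax picks the label; softmax supplies a confidence in [0, 1],
    because raw logits are not probabilities.
    """
    predictions = []
    for logits, shape in zip(batch_logits, image_shapes):
        probs = softmax(logits)
        label = max(range(len(probs)), key=probs.__getitem__)
        predictions.append(ClassificationPrediction(probs[label], label, shape))
    return predictions


if __name__ == "__main__":
    image = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
    print(center_crop(image, 2))  # [[5.0, 6.0], [9.0, 10.0]]
    preds = decode_model_output([[1.0, 3.0, 0.5]], [(224, 224)])
    print(preds[0].label, preds[0].confidence)
```

Since softmax is monotonic, taking argmax over the probabilities gives the same label as argmax over the raw logits; softmax is only needed to turn the winning score into a confidence.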

Up Next

  1. Create an intermediate class, BaseClassificationModel, that inherits from SgModule and serves as the base class for all the classifier models (which inherit from SgModule). BaseClassificationModel will implement the HasPredict interface functions. - Done
  2. Draw the top-k predictions with their confidences on the image itself.
  3. default_resnet_imagenet_processing_params -> make it generic for classification models (more processing functions need to be added).
  4. Create a drawing function in the visualization directory (used by the draw function of the Prediction object). - Done
  5. Run black + flake8. - Done
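The intermediate-base-class idea (item 1) and top-k drawing (item 2) can be sketched together as below. This is a hypothetical illustration, not super-gradients code: Module is a placeholder standing in for SgModule, TinyClassifier is a toy model, and format_top_k only produces the text lines a draw function might render; the real HasPredict signatures are not reproduced here.

```python
import math
from typing import List, Sequence, Tuple


class Module:
    """Hypothetical placeholder for SgModule; only defines the forward contract."""

    def forward(self, x):
        raise NotImplementedError


class BaseClassificationModel(Module):
    """Intermediate base class: every classifier inheriting from it gets predict()."""

    def predict(self, x, top_k: int = 1) -> List[Tuple[int, float]]:
        logits = self.forward(x)
        # Numerically stable softmax over the class logits.
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Rank class indices by confidence, highest first, and keep top_k.
        ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        return [(i, probs[i]) for i in ranked[:top_k]]


class TinyClassifier(BaseClassificationModel):
    """Toy subclass: only forward() is implemented; predict() is inherited."""

    def __init__(self, weights: Sequence[float]):
        self.weights = list(weights)

    def forward(self, x):
        return [w * x for w in self.weights]


def format_top_k(pairs, class_names) -> List[str]:
    # Text lines such as "dog: 0.63" that a draw function could put on the image.
    return [f"{class_names[i]}: {p:.2f}" for i, p in pairs]


if __name__ == "__main__":
    model = TinyClassifier([0.5, 2.0, 1.0])
    print(format_top_k(model.predict(1.0, top_k=2), ["cat", "dog", "bird"]))
```

The design point is that subclasses only implement forward(); predict() and the top-k logic live once in the shared base class instead of being repeated in every classifier.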

danielafrimi · Jun 15 '23 12:06

It looks like the code you've pushed is not properly formatted. We use black and pre-commit hooks to ensure that code is formatted identically for everyone. See https://github.com/Deci-AI/super-gradients/blob/master/CONTRIBUTING.md#pre-commit-hooks for how to configure the hooks and run black.

BloodAxe · Jun 16 '23 09:06

Merged in https://github.com/Deci-AI/super-gradients/pull/1220 because this branch has no signed commits.

Louis-Dupont · Jun 27 '23 12:06