DeepForest
Create a workflow for downstream classification of crops.
Users often lack bounding-box annotations for entire images and instead have only individual point locations. It would be helpful to have a fine-tuning workflow that lets users take a torchvision classification model (e.g. ResNet-50) and apply it to the locations predicted by an existing RetinaNet model. We would need a workflow to train a classification model and to deliver crops from a DeepForest model for prediction.
For example, I have used this approach to predict whether trees are alive or dead.
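One way to turn point annotations into classifier training data is to cut a fixed-size window centered on each point. The sketch below assumes the points have already been converted to pixel (x, y) coordinates of the image array (for example with rasterio's `index` method on the tile). The helper names and the 40-pixel default `crop_size` are hypothetical choices for illustration. Real crowns vary in size, so the window size would need tuning.

```python
import numpy as np


def point_to_window(x, y, crop_size, width, height):
    """Return (xmin, ymin, xmax, ymax) of a crop_size window centered on (x, y).

    The window is shifted, not shrunk, so it stays inside a width x height image
    (it is only truncated if the image is smaller than crop_size).
    """
    half = crop_size // 2
    # Clamp the top-left corner so the window never leaves the image.
    xmin = min(max(int(x) - half, 0), max(width - crop_size, 0))
    ymin = min(max(int(y) - half, 0), max(height - crop_size, 0))
    xmax = min(xmin + crop_size, width)
    ymax = min(ymin + crop_size, height)
    return xmin, ymin, xmax, ymax


def crop_points(image, points, crop_size=40):
    """Cut one chip per (x, y, label) point from an (H, W, C) image array.

    Hypothetical helper: points are assumed to be in pixel coordinates.
    """
    height, width = image.shape[:2]
    chips, labels = [], []
    for x, y, label in points:
        xmin, ymin, xmax, ymax = point_to_window(x, y, crop_size, width, height)
        chips.append(image[ymin:ymax, xmin:xmax])
        labels.append(label)
    return chips, labels
```

Shifting the window at the image edges, rather than padding, keeps every chip the same size, which simplifies batching for the classifier.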
Model definition
- https://github.com/weecology/DeepTreeAttention/blob/main/src/models/dead.py
Predict locations of 'Tree'
- https://github.com/weecology/DeepTreeAttention/blob/6c8dd01e839c1179961f4d6cd3794fc41f218ed0/src/generate.py#L17
Crop the predicted tree locations into chips
- https://github.com/weecology/DeepTreeAttention/blob/6c8dd01e839c1179961f4d6cd3794fc41f218ed0/src/patches.py#L4
Classify each of those chips using the alive/dead model
- https://github.com/weecology/DeepTreeAttention/blob/6c8dd01e839c1179961f4d6cd3794fc41f218ed0/src/predict.py#L179
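The cropping step in the middle of this pipeline can be sketched with plain NumPy slicing. This is not the code from the linked `patches.py`. It is a minimal sketch assuming boxes in pixel coordinates, such as the `xmin`, `ymin`, `xmax`, `ymax` columns of a DeepForest prediction DataFrame. The `pad` margin is a hypothetical option.

```python
import numpy as np


def crop_boxes(image, boxes, pad=0):
    """Cut each predicted box out of an (H, W, C) image array.

    boxes: iterable of (xmin, ymin, xmax, ymax) in pixel coordinates.
    pad: hypothetical extra margin, in pixels, added around every box.
    Returns chips in the same order as boxes so predictions can be mapped back.
    """
    height, width = image.shape[:2]
    chips = []
    for xmin, ymin, xmax, ymax in boxes:
        # Expand outward to whole pixels, add the margin, then clip to the image.
        x0 = max(int(np.floor(xmin)) - pad, 0)
        y0 = max(int(np.floor(ymin)) - pad, 0)
        x1 = min(int(np.ceil(xmax)) + pad, width)
        y1 = min(int(np.ceil(ymax)) + pad, height)
        chips.append(image[y0:y1, x0:x1])
    return chips
```

The chips have different sizes, so they would still need resizing (e.g. with torchvision transforms) before being batched into a ResNet.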
This particular example is nested within a larger project. It would be great to generalize it and formalize it for the DeepForest repo.
I imagine a set of functions like:
```python
# Proposed API only: predict_crops and these functions do not exist yet.
from main import predict_crops

crop_classifier = predict_crops.train_classifier(annotation_dir="<path to shapefiles of annotations>")
dataloader = predict_crops.create_crops("<path to giant unlabeled tile>.tif")
predictions = crop_classifier.predict_dataloader(dataloader)
```
Training would be driven by a config block that points to a folder following the https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html conventions.
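ImageFolder expects one subfolder per class, `root/<class_name>/<file>`. A sketch of how labeled crops could be laid out to match is below. The function name and the index-based file names are hypothetical. Each chip could then be written to its path with an image library (for example `PIL.Image.fromarray(chip).save(path)`), after which `torchvision.datasets.ImageFolder(root)` assigns class indices from the sorted folder names.

```python
from pathlib import Path


def imagefolder_paths(root, labels, suffix=".png", create_dirs=True):
    """Return one output path per crop in the ImageFolder layout.

    Layout: root/<label>/<index><suffix>. Hypothetical helper for illustration.
    """
    root = Path(root)
    paths = []
    for i, label in enumerate(labels):
        class_dir = root / str(label)
        if create_dirs:
            # One directory per class, as ImageFolder requires.
            class_dir.mkdir(parents=True, exist_ok=True)
        paths.append(class_dir / f"{i}{suffix}")
    return paths
```

Using the crop's index as the file name keeps names unique within a run. A real workflow would probably encode the source tile and point ID instead.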
- train_classifier would convert a shapefile of point locations into labeled crops and fine-tune a ResNet model on them.
- create_crops would run a DeepForest baseline model (e.g. the release tree model via use_release) on an unlabeled tile, cut crops at the predicted locations, and return a PyTorch DataLoader.
- predict_dataloader would take the DataLoader and return a shapefile of bounding box locations with their predicted classes.
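The last step, pairing each box with the classifier's output, can be sketched independently of PyTorch. This assumes the classifier's logits have been collected into an `(N, num_classes)` array in the same order as the crops. The function and the `predicted_class` / `class_score` field names are hypothetical. The resulting records could then be turned into a GeoDataFrame and written to a shapefile.

```python
import numpy as np


def attach_predictions(boxes, logits, class_names):
    """Pair each (xmin, ymin, xmax, ymax) box with its predicted class and score.

    logits: (N, num_classes) classifier outputs, in the same order as boxes.
    Field names in the returned dicts are hypothetical.
    """
    logits = np.asarray(logits, dtype=float)
    if len(boxes) != len(logits):
        raise ValueError("boxes and logits must have the same length")
    # Numerically stable softmax over the class dimension.
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    best = probs.argmax(axis=1)
    records = []
    for (xmin, ymin, xmax, ymax), idx, p in zip(boxes, best, probs):
        records.append({
            "xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax,
            "predicted_class": class_names[idx],
            "class_score": float(p[idx]),
        })
    return records
```

The length check matters because `zip` would otherwise silently drop unmatched boxes.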
@bw4sz I would like to work on this issue. Could you guide me on how to get started?