
Add separate config for predict batch size and train batch size.

Open bw4sz opened this issue 11 months ago • 3 comments

Updating model weights takes a lot more GPU memory than a forward pass alone.

predict_tile is slower than it needs to be because it uses trainer.predict, which inherits a dataloader whose batch size is set by the global config: https://github.com/weecology/DeepForest/blob/3dbc8342de766f1f504d3c2da69c1fbc2443da42/src/deepforest/main.py#L348

while the training dataloader gets its batch size from load_dataset:

https://github.com/weecology/DeepForest/blob/3dbc8342de766f1f504d3c2da69c1fbc2443da42/src/deepforest/main.py#L335

The default is 1 because the GPU size at training time is unknown (it should probably be 2).

  1. Make a predict_batch_size and a train_batch_size config arg.
  2. Update defaults to 2 for train and 8 for predict.
  3. Update the config docs.
  4. Write tests showing that the dataloaders for each are yielding batches of the correct size.

I'm unsure about the val dataloader batch size; it could perhaps be higher, but the GPU memory cost isn't clear to me. I think the val batch size should match the predict size, since no weights are updated during validation.
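The steps above could be wired up roughly as follows. This is a hedged sketch, not the current DeepForest API: the config keys `train_batch_size` and `predict_batch_size` and the defaults (2 and 8) are the proposal from this issue, and plain PyTorch `DataLoader`s over a toy dataset stand in for the real image dataloaders. It also sketches the check from step 4, that each dataloader yields batches of its configured size.

```python
from torch.utils.data import DataLoader

# Proposed config keys and defaults from this issue (assumptions, not the
# current DeepForest API).
config = {"train_batch_size": 2, "predict_batch_size": 8}

dataset = list(range(32))  # toy stand-in for an image dataset

# Each dataloader reads its own batch size instead of a shared global one.
train_loader = DataLoader(dataset, batch_size=config["train_batch_size"], shuffle=True)
predict_loader = DataLoader(dataset, batch_size=config["predict_batch_size"])

# Step 4: verify each dataloader yields batches of the configured size.
assert len(next(iter(train_loader))) == config["train_batch_size"]
assert len(next(iter(predict_loader))) == config["predict_batch_size"]
```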

bw4sz avatar Jan 16 '25 15:01 bw4sz

Hi @bw4sz, I saw this issue and thought it looked really interesting! I would like to contribute to it. Any guidance would be appreciated :)

rabelmervin avatar Jan 17 '25 07:01 rabelmervin

Go for it. Do you have access to a GPU? I'm not yet sure whether validation batch_size and predict batch_size should be the same or separate arguments. Make sure to profile the example code. You'll need a large tile to test on; you won't notice much difference on the sample package data.

https://www.dropbox.com/scl/fi/yki42nmplok43isi1queb/2021_TEAK_5_322000_4097000_image.tif?rlkey=aaq4sc3jqa13oo4axuh0vw93d&dl=0

```python
import time

from deepforest import main

def profile_predict_tile(batch_sizes, raster_path):
    model = main.deepforest()
    model.load_model(model_name="weecology/deepforest-tree")

    for batch_size in batch_sizes:
        model.config["batch_size"] = batch_size
        start_time = time.time()
        model.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0.25)
        end_time = time.time()
        print(f"Batch Size: {batch_size}, Time Taken: {end_time - start_time:.2f} seconds")

if __name__ == "__main__":
    raster_path = "<path_to_raster>"  # replace with the path to the tile linked above
    batch_sizes = [1, 2, 4, 8, 16]
    profile_predict_tile(batch_sizes, raster_path)
```
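Wall-clock time alone won't show the memory side of the tradeoff. A minimal helper for recording peak GPU memory between runs, using PyTorch's CUDA allocator statistics, might look like this (the helper is hypothetical, not part of DeepForest, and only reports when a CUDA device is present):

```python
import torch

def report_peak_memory(tag):
    """Print and reset the peak CUDA memory allocated since the last reset."""
    if not torch.cuda.is_available():
        print(f"{tag}: no CUDA device, skipping memory report")
        return
    peak_mb = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f"{tag}: peak GPU memory {peak_mb:.1f} MB")
    torch.cuda.reset_peak_memory_stats()
```

Calling `report_peak_memory(f"batch_size={batch_size}")` after each `predict_tile` call in the loop above would pair memory numbers with the timings.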

bw4sz avatar Jan 19 '25 23:01 bw4sz

Hi @bw4sz, I have completed the profiling test on my RTX 4050 (6GB). Here are the results:

[Screenshot of profiling results]

Please let me know if any further modifications are needed.

Bhavya1604 avatar Feb 28 '25 14:02 Bhavya1604