Add separate config options for the predict batch size and the train batch size.
Updating model weights takes a lot more GPU memory than a forward pass alone.
predict_tile is slower than it needs to be because it uses trainer.predict, which inherits a dataloader whose batch size is set by the global config: https://github.com/weecology/DeepForest/blob/3dbc8342de766f1f504d3c2da69c1fbc2443da42/src/deepforest/main.py#L348
Training gets its batch size from load_dataset:
https://github.com/weecology/DeepForest/blob/3dbc8342de766f1f504d3c2da69c1fbc2443da42/src/deepforest/main.py#L335
The default is 1 because the GPU size at training time is unknown (it probably should be 2).
- Make a predict_batch_size and a train_batch_size config arg
- Update defaults to 2 for train and 8 for predict.
- Update the config docs.
- Write tests showing the dataloaders of each are yielding correct sizes.
I'm unsure about the val dataloader batch size. It could probably be higher, but the GPU memory cost is not clear to me. I think the val batch size should match the predict batch size, since no weights are updated.
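To make the proposal concrete, here is a minimal sketch of how the three batch sizes could be resolved from the config. The key names (train_batch_size, predict_batch_size, val_batch_size), the fallback to the existing batch_size key, and the helper itself are all assumptions for illustration, not the current DeepForest config schema.

```python
# Hypothetical defaults taken from the proposal above (2 for train, 8 for predict).
DEFAULTS = {"train_batch_size": 2, "predict_batch_size": 8}


def resolve_batch_sizes(config):
    """Return the batch size each dataloader would use.

    Assumption: an existing global "batch_size" key, if set, acts as a
    fallback so that old config files keep working.
    """
    legacy = config.get("batch_size")
    train = config.get("train_batch_size", legacy or DEFAULTS["train_batch_size"])
    predict = config.get("predict_batch_size", legacy or DEFAULTS["predict_batch_size"])
    # Validation does not update weights, so it defaults to the predict size.
    val = config.get("val_batch_size", predict)
    return {"train": train, "val": val, "predict": predict}


if __name__ == "__main__":
    print(resolve_batch_sizes({}))
    print(resolve_batch_sizes({"predict_batch_size": 16}))
```

Keeping val_batch_size as an optional override leaves the "same or separate argument" question open: by default it follows the predict size, but it can still be set on its own.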
Hi @bw4sz, I saw this issue and thought it looked really interesting! I would like to contribute to this. Any guidance would be appreciated :)
Go for it. Do you have access to a GPU? I'm not yet sure whether the validation batch size and predict batch size should be the same or separate arguments. Make sure to profile the example code. If you need a large tile to test on, here is one; you won't notice much difference on the sample package data.
https://www.dropbox.com/scl/fi/yki42nmplok43isi1queb/2021_TEAK_5_322000_4097000_image.tif?rlkey=aaq4sc3jqa13oo4axuh0vw93d&dl=0
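A rough way to see why tile size matters for profiling: the number of forward passes is the number of windows divided by the batch size, rounded up. The sketch below estimates that count. The window arithmetic is a simplification (the real windowing code may count edge windows differently), and the 10000 x 10000 pixel tile size is a hypothetical example, not a measured property of the tile linked above.

```python
import math


def windows_along_axis(length, patch_size, patch_overlap):
    """Approximate number of sliding windows along one image axis."""
    stride = int(patch_size * (1 - patch_overlap))
    if length <= patch_size:
        return 1
    return math.ceil((length - patch_size) / stride) + 1


def num_batches(n_windows, batch_size):
    """Number of forward passes needed to cover all windows."""
    return math.ceil(n_windows / batch_size)


if __name__ == "__main__":
    # Hypothetical tile dimensions; patch settings match the profiling script.
    per_axis = windows_along_axis(10000, patch_size=300, patch_overlap=0.25)
    n_windows = per_axis * per_axis
    for batch_size in [1, 2, 4, 8, 16]:
        print(batch_size, num_batches(n_windows, batch_size))
```

With only a handful of windows, as on a small sample image, every batch size needs just a few forward passes, so timing differences are hard to see.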
import time

from deepforest import main


def profile_predict_tile(batch_sizes, raster_path):
    # Load the pretrained tree model once and reuse it for every batch size.
    model = main.deepforest()
    model.load_model(model_name="weecology/deepforest-tree")
    for batch_size in batch_sizes:
        # predict_tile picks up its batch size from the global config.
        model.config["batch_size"] = batch_size
        start_time = time.time()
        model.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0.25)
        end_time = time.time()
        print(f"Batch Size: {batch_size}, Time Taken: {end_time - start_time} seconds")


if __name__ == "__main__":
    raster_path = "<path_to_raster>"  # replace with the path to a local raster
    batch_sizes = [1, 2, 4, 8, 16]
    profile_predict_tile(batch_sizes, raster_path)
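Single timings can be noisy, especially for the first call on a GPU. As one possible refinement, the sketch below adds a warm-up call and takes the median of a few repeats. The helper names (time_batch_sizes, speedups) and the run callable are hypothetical; run(batch_size) stands in for whatever sets the batch size and calls predict_tile.

```python
import statistics
import time


def time_batch_sizes(run, batch_sizes, repeats=3):
    """Return {batch_size: median seconds} for calls to run(batch_size)."""
    timings = {}
    for batch_size in batch_sizes:
        run(batch_size)  # warm-up call, not timed
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()
            run(batch_size)
            samples.append(time.perf_counter() - start)
        timings[batch_size] = statistics.median(samples)
    return timings


def speedups(timings, baseline=1):
    """Express each timing as a speedup relative to the baseline batch size."""
    return {bs: timings[baseline] / t for bs, t in timings.items()}


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without a model or GPU.
    def fake_run(batch_size):
        time.sleep(0.01 / batch_size)

    results = time_batch_sizes(fake_run, [1, 2, 4])
    print(results)
    print(speedups(results))
```

In a real run, fake_run would be replaced by a function that sets the batch size in the model config and calls predict_tile on the large tile.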
Hi @bw4sz, I have completed the profiling test on my RTX 4050 (6 GB). Here are the results:
Please let me know if any further modifications are needed.