Intermediate training metrics for TAPIR
I'm looking to replicate training of TAPIR on Kubric MOVi-E with adjusted compute resources and need insights on intermediate training metrics.
What were the intermediate training statistics during your runs? Specifically:
- At which epoch did you start observing improved loss values?
- What were the loss values at higher epoch numbers (before final convergence)?
Your setup: 64 TPUs × batch size 4 = 256 global batch, 50k steps, 1e-3 learning rate
Our proposed setup: 4 GPUs × batch size 2 = 8 global batch, 1.6M steps, 3.125e-5 learning rate
(We've scaled the learning rate and step count to account for our reduced compute capacity while keeping the total number of training samples seen equivalent. Would appreciate confirmation that this adaptation looks reasonable.)
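For reference, the adapted numbers above follow from linear learning-rate scaling with batch size and keeping the total sample count fixed. A minimal sketch of that arithmetic (assuming linear LR scaling, which is a common heuristic rather than anything confirmed for TAPIR):

```python
# Linear-scaling arithmetic for the adapted setup (assumption: LR scales
# linearly with global batch size; total samples seen is held constant).
ref_batch, ref_steps, ref_lr = 256, 50_000, 1e-3
new_batch = 4 * 2                    # 4 GPUs x per-device batch of 2

scale = new_batch / ref_batch        # 8 / 256 = 1/32
new_lr = ref_lr * scale              # 1e-3 / 32 = 3.125e-5
new_steps = int(ref_steps / scale)   # 50k * 32 = 1.6M steps

print(new_lr, new_steps)             # 3.125e-05 1600000
```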
Here are some intermediate results:
Each epoch (one full pass over the train set) is ~1250 steps, so at this point the model has seen the train set around 36 times.
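Backing out the epoch arithmetic from the numbers above (assumptions: a train set of ~10,000 clips, implied by ~1250 steps/epoch at a global batch of 8, and ~45k elapsed steps, implied by ~36 passes over the data):

```python
# Epoch arithmetic implied by the figures above (dataset size and elapsed
# step count are back-derived assumptions, not confirmed values).
global_batch = 4 * 2
train_set_size = 10_000                            # assumption
steps_per_epoch = train_set_size // global_batch   # ~1250
epochs_seen = 45_000 // steps_per_epoch            # ~36
```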
I wouldn't expect it to be trivial to make this work with a small batch size. At the very least, I expect you would need to re-tune the Adam parameters. It would probably be better to use gradient accumulation if you want to train with fewer devices.
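The gradient-accumulation pattern suggested here can be sketched as follows. This is a framework-free toy (a 1-D least-squares fit with made-up data), not TAPIR's actual training loop, so the arithmetic is visible: gradients from 32 micro-batches are averaged before one parameter update, matching what a 32× larger batch would produce.

```python
# Toy gradient-accumulation sketch (hypothetical setup, not TAPIR code):
# accumulate the averaged gradient over ACCUM_STEPS micro-batches, then
# take a single parameter update, emulating a larger effective batch.
ACCUM_STEPS = 32
lr = 0.1
w = 0.0  # fit y = w * x with true w = 2

# 64 micro-batches of one (x, y) pair each
data = [(x, 2.0 * x) for x in [0.5, 1.0, 1.5, 2.0] * 16]

grad, updates = 0.0, 0
for i, (x, y) in enumerate(data):
    # d/dw of 0.5 * (w*x - y)^2, pre-divided so the sum is an average
    grad += (w * x - y) * x / ACCUM_STEPS
    if (i + 1) % ACCUM_STEPS == 0:
        w -= lr * grad   # one update per 32 micro-batches
        grad = 0.0
        updates += 1
```

In PyTorch the same pattern is typically `(loss / accum_steps).backward()` on every micro-batch, with `optimizer.step()` and `optimizer.zero_grad()` called only every `accum_steps` iterations.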
Here's a training curve for the position (Huber) loss for a model trained for 100k steps (note that our internal training scripts have progressed beyond what we released for the paper, and the released ones aren't carefully maintained). I'm not sure the loss scales match what you're using here, but I don't expect that to matter much: as long as the Huber loss dominates the total loss, I would expect the model to train.
Thanks. I've added gradient accumulation (4 GPUs × batch size 2 × 32 accumulation steps = 256 effective batch) and the losses do seem to go down.
Do you know why the loss scales in your graph and ours differ so much, given that we're using the same loss file as the repository (the PyTorch version, not JAX), and whether this would make a difference?
Hi, just wanted to follow up on this
Hi @swarnim-j ,
Sorry, we don't have the bandwidth to look into the problem. We've attached the intermediate evaluation metrics for TAPVid-DAVIS strided evaluation as a further reference. Hope you can compare them with yours and analyze further. Thanks.