
Intermediate training metrics for TAPIR

Open swarnim-j opened this issue 1 year ago • 4 comments

I'm looking to replicate training of TAPIR on Kubric MOVi-E with adjusted compute resources and need insights on intermediate training metrics.

What were the intermediate training statistics during your runs? Specifically:

  • At which epoch did you start observing improved loss values?
  • What were the loss values at higher epoch numbers (before final convergence)?

  • Your setup: 64 TPUs × batch size 4 = 256 total batch, 50k steps, 1e-3 learning rate
  • Our proposed setup: 4 GPUs × batch size 2 = 8 total batch, 1.6M steps, 3.125e-5 learning rate

(We've scaled the training parameters to account for our reduced compute capacity while keeping the total number of training samples seen equivalent. Would appreciate confirmation that this adaptation looks reasonable.)
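For concreteness, here's the arithmetic behind our scaling (a sketch assuming the linear-scaling rule, i.e. learning rate scales with batch size and step count scales inversely; this is our own derivation, not taken from the tapnet scripts):

```python
# Our own derivation (not from the tapnet scripts): apply the linear-scaling
# rule, so the total number of samples seen stays the same.
ref_batch, ref_steps, ref_lr = 256, 50_000, 1e-3   # original TAPIR setup
our_batch = 4 * 2                                  # 4 GPUs x batch size 2

scale = ref_batch / our_batch                      # 32x smaller batch
our_steps = int(ref_steps * scale)                 # 1,600,000 steps
our_lr = ref_lr / scale                            # 3.125e-05

assert ref_batch * ref_steps == our_batch * our_steps  # same samples seen
print(our_steps, our_lr)                           # 1600000 3.125e-05
```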

Here are some intermediate results: [image]

Each epoch (one pass over the entire train set) is ~1250 steps, so at this point the model has seen the train set around 36 times.

swarnim-j avatar Dec 26 '24 04:12 swarnim-j

I wouldn't expect it to be trivial to make this work with a small batch size. At the very least, I expect you would need to re-tune the Adam parameters. It's probably better to use gradient accumulation if you want to train with fewer devices.
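For reference, a minimal gradient-accumulation sketch in PyTorch (dummy model and random data; this is not the tapnet training script, just the accumulation pattern):

```python
import torch
from torch import nn

# Minimal gradient-accumulation sketch with a dummy model and random data;
# NOT the tapnet training script, just the accumulation pattern.
model = nn.Linear(8, 2)
loss_fn = nn.HuberLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

per_device_batch = 2
accum_steps = 32  # effective batch per device = 2 * 32 = 64

optimizer.zero_grad()
for step in range(accum_steps * 4):                   # a few optimizer updates
    x = torch.randn(per_device_batch, 8)              # stand-in for model inputs
    target = torch.randn(per_device_batch, 2)         # stand-in for point targets
    loss = loss_fn(model(x), target) / accum_steps    # average over micro-batches
    loss.backward()                                   # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                              # one update per 32 micro-batches
        optimizer.zero_grad()
```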

Here's a training curve for the position (Huber) loss from a model trained for 100K steps (note that our internal training scripts have progressed beyond what we released for the paper, and the released ones aren't carefully maintained). I'm not sure the loss scales match what you're using here, but I don't expect that to matter much; as long as the Huber loss dominates the total loss, I would expect the model to train.

[image: position (Huber) loss training curve over 100K steps]
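For context on the loss itself: a Huber loss is quadratic for small position errors and linear for large ones, so outlier points dominate it less. A toy illustration (the delta and reduction here are arbitrary, not necessarily tapnet's settings):

```python
import torch
import torch.nn.functional as F

# Toy illustration of a Huber loss on predicted point positions;
# delta and reduction are arbitrary, not necessarily tapnet's settings.
pred = torch.tensor([[10.0, 12.0], [50.0, 40.0]])    # predicted (x, y) positions
target = torch.tensor([[10.5, 11.0], [30.0, 40.0]])  # ground-truth positions

loss = F.huber_loss(pred, target, delta=1.0, reduction="mean")
print(loss)  # quadratic below delta, linear above it
```

Note that the reduction (mean vs. sum over points and frames) changes the absolute scale of the curve without changing what is being optimized.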

cdoersch avatar Jan 02 '25 10:01 cdoersch

Thanks. I've added gradient accumulation (4 GPUs × 2 batch size × 32 accumulation steps = 256 effective batch) and the losses do seem to go down.

Do you know why the loss scales in your graph and ours might be so different, given that we're using the same loss file as the repository (the PyTorch version, not JAX), and whether this would make a difference?

[image: our training loss curves]

swarnim-j avatar Jan 09 '25 15:01 swarnim-j

Hi, just wanted to follow up on this.

swarnim-j avatar Jan 26 '25 21:01 swarnim-j

Hi @swarnim-j,

Sorry, we don't have the bandwidth to look into the problem. We're attaching the intermediate evaluation metrics for TAPVid-DAVIS strided evaluation as an additional reference; hopefully you can compare them with yours and analyze from there. Thanks.

yangyi02 avatar Apr 04 '25 18:04 yangyi02