jonb377
It would be easiest to debug if we had a standalone reproduction, for example if the issue also occurs with a simple `nn.Linear`. I can work on a minimal repro, but just...
That sounds fine, please go ahead and share the repo. I suspect this is just related to matmul precision on TPU and CPU. I will do some tests to verify...
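To illustrate the matmul-precision hypothesis above, here is a minimal, standard-library-only sketch. It assumes the reduced-precision path rounds matmul inputs to bfloat16 before multiplying, which is a simplification of what an accelerator may do by default; the helper names `to_bfloat16` and `dot_bf16_inputs` are hypothetical and not part of any library. It shows only the size of error such rounding can introduce, not the cause of the issue in this thread.

```python
import random
import struct


def to_bfloat16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (returned as a Python float)."""
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 low bits that bfloat16 drops.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(a, b):
    """Reference dot product at Python float (double) precision."""
    return sum(x * y for x, y in zip(a, b))


def dot_bf16_inputs(a, b):
    """Dot product with both inputs rounded to bfloat16 first (assumed model)."""
    return sum(to_bfloat16(x) * to_bfloat16(y) for x, y in zip(a, b))


random.seed(0)
a = [random.uniform(-1, 1) for _ in range(1024)]
b = [random.uniform(-1, 1) for _ in range(1024)]

exact = dot(a, b)
reduced = dot_bf16_inputs(a, b)
print(f"full precision:  {exact:.6f}")
print(f"bfloat16 inputs: {reduced:.6f}")
print(f"abs difference:  {abs(exact - reduced):.2e}")
```

If the gap between CPU and TPU outputs is much larger than what this kind of rounding produces, precision alone is unlikely to explain it.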
Hi @mfatih7, I'll have some time this afternoon to look. Will keep you posted!
@mfatih7 I was able to reproduce the numbers you reported, and increasing the TPU matmul precision didn't help. I'm curious to understand the use case a bit more. I noticed...
Thanks @mfatih7 for the context! Just to confirm my high-level understanding:
- The actual model is being trained on SPMD TPU.
- Inputs to the training are generated from another...
One concern I have is that the `make_optimizer_prime` call happens on every invocation of `get_checkpoint_template`, so we're doing an extra dummy optimizer step each checkpoint. Depending on the optimizer, this...
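One possible way to avoid the repeated priming is to build the template once and reuse it. The sketch below is hypothetical: `CheckpointTemplateCache`, `build_template`, and `make_optimizer_prime_stub` are stand-ins invented for illustration, since the real `get_checkpoint_template` and `make_optimizer_prime` code is not shown here. It also assumes the template is only needed for its structure, so a cached copy stays valid across checkpoints.

```python
class CheckpointTemplateCache:
    """Builds the checkpoint template on first use and reuses it afterwards."""

    def __init__(self, build_template):
        # build_template: hypothetical zero-argument callable returning the template.
        self._build_template = build_template
        self._template = None

    def get(self):
        # Only the first call pays for building (and therefore priming).
        if self._template is None:
            self._template = self._build_template()
        return self._template


prime_calls = 0


def make_optimizer_prime_stub():
    """Stand-in for the real priming step; it only counts invocations."""
    global prime_calls
    prime_calls += 1
    return {"step": 0}


def build_template():
    """Hypothetical template builder that primes the optimizer state."""
    return {"model": {}, "optimizer": make_optimizer_prime_stub()}


cache = CheckpointTemplateCache(build_template)
for _ in range(3):  # three simulated checkpoint saves
    template = cache.get()

print(prime_calls)  # → 1
```

Whether caching is safe depends on how the template is consumed; if the restore path needs current optimizer values rather than just the structure, this approach would not apply.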
@mfatih7 I've done a couple of tests related to the dataloading:
- Inputs agree between CPU and TPU before passing through the frozen model.
- After passing through the frozen...
Hey @mfatih7, thanks for raising the issue! Could you please share the repro? Based on the logs, the `optimizer_state_dict` isn't being tracked in the checkpoint metadata. A few things to...
Thanks! Can you share the `get_checkpoint_template` code as well?
Hmm, I don't see anything glaringly wrong with the approach. A few questions:
- Can you share some more details about when you hit the error? You mentioned it works...