
reproducible numerics for loss, weights, and gradients for single node (8 GPUs)

Open weifengpy opened this issue 1 year ago • 2 comments

by default, torchtitan uses FSDP2 mixed precision (param_dtype=bfloat16, reduce_dtype=float32)
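For reference, a minimal sketch of what that default looks like with the FSDP2 APIs (not torchtitan's actual wiring; `model.layers` is assumed here for illustration):

```python
# Sketch: FSDP2 mixed precision with bf16 compute params and fp32 gradient reduction.
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # parameters cast to bf16 for forward/backward compute
    reduce_dtype=torch.float32,   # gradients reduce-scattered/all-reduced in fp32
)

def apply_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
    # Shard each transformer block, then the root module (typical FSDP2 pattern).
    for block in model.layers:    # hypothetical attribute for this sketch
        fully_shard(block, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    return model
```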

for low-precision dtypes (float8 and int8), it's natural to compare the loss curve against bfloat16 and see how well they match (also a good idea to compare weight norms and gradient norms)

for bfloat16 itself, multiple runs yield different loss curves, and the nondeterminism should be understood and documented (e.g. NCCL gradient reduction order, attention kernels, seeding). Otherwise it's hard to tell whether numeric differences come from the low-precision dtypes
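As a rough illustration of the knobs involved, here is a sketch of the settings one would pin down when chasing run-to-run differences (seeding and algorithm determinism); note that NCCL reduction order is not controlled by these flags and may still vary:

```python
# Sketch: pinning down the usual sources of nondeterminism in a single-node run.
import os
import torch

def make_deterministic(seed: int = 42) -> None:
    torch.manual_seed(seed)                              # seed CPU RNG
    torch.cuda.manual_seed_all(seed)                     # seed CUDA RNGs on all devices
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"    # required for deterministic cuBLAS
    torch.use_deterministic_algorithms(True)             # error on nondeterministic kernels
    torch.backends.cudnn.benchmark = False               # avoid autotuned, run-dependent algos
```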

I plotted gradient norms (the plotted quantity is loss = sum over p.grad for p in model.parameters()) using Llama3-8B on 8 GPUs, with deterministic model init and a deterministic data loader
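A sketch of that kind of per-step gradient "fingerprint" metric (my reconstruction, not the exact script used above):

```python
# Sketch: scalar fingerprints of the gradients (sum and norm) logged per step,
# so repeated runs can be overlaid and compared.
import torch

@torch.no_grad()
def grad_fingerprint(model: torch.nn.Module) -> tuple[float, float]:
    # Caveat: under FSDP2 each rank holds only a shard of the gradients (DTensors),
    # so a faithful global value needs a cross-rank reduction or DTensor.full_tensor().
    total_sum = 0.0
    sq_norm = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad.detach().float()
        total_sum += g.sum().item()
        sq_norm += g.pow(2).sum().item()
    return total_sum, sq_norm ** 0.5
```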

for bfloat16, gradients are quite different in repeated runs

[screenshot, 2024-09-30 5:15 PM: gradient curves diverge across repeated bfloat16 runs]

turning off gradient norm clipping helps a lot, but could not explain all of the divergence

[screenshot, 2024-09-30 5:17 PM: gradient curves with clipping disabled, still not identical]
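One reason clipping amplifies divergence: the clip scale is computed from the global gradient norm, so a tiny nondeterministic wiggle in any one gradient perturbs the update applied to every parameter. A minimal sketch of that coupling (torchtitan exposes a max-norm setting in its training config; the exact knob is not shown here):

```python
# Sketch: global-norm clipping rescales all gradients by one factor derived
# from the total norm, coupling every parameter's update to every gradient.
import torch

def clip_and_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                  max_norm: float | None = 1.0) -> None:
    if max_norm is not None:  # set max_norm=None to disable clipping, as in the experiment above
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    optimizer.zero_grad()
```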

filing the issue here; hopefully it can be a good candidate for the "what's next" list

weifengpy avatar Oct 01 '24 00:10 weifengpy

IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?

I think we can try to enable some deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
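A sketch of forcing a deterministic SDPA backend via the API from the linked docs (the math backend is deterministic but slower; whether flash backward is deterministic depends on the kernel/version):

```python
# Sketch: restrict SDPA to the deterministic math backend inside a context manager.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel(SDPBackend.MATH):  # only the math backend is allowed to dispatch here
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```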

awgu avatar Oct 01 '24 03:10 awgu

> IIUC, the default SDPA backend for us is flash, and flash backward is non-deterministic?
>
> I think we can try to enable some deterministic SDPA: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

good call out!

weifengpy avatar Oct 01 '24 04:10 weifengpy