
Moved logits `.float()` to loss and compiled it if compiling

Open awgu opened this issue 1 year ago • 2 comments

Stack from ghstack (oldest at bottom):

  • #567
  • -> #551

Compiling the loss improves performance. Moving the `.float()` upcast to inside this compiled loss further improves performance.
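For context, a minimal sketch of the idea (illustrative names, not the exact torchtitan code): the `.float()` upcast lives inside the loss function, and the whole loss is wrapped in `torch.compile` when compile is enabled, so the upcast and the cross-entropy kernels are generated in one compiled region instead of materializing a separate fp32 logits tensor in the model's forward.

```python
import torch
import torch.nn.functional as F

def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast the bf16 logits to fp32 inside the loss so that, when compiled,
    # the cast can be fused with the cross-entropy computation.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))

# When compile is enabled in the job config, wrap the loss itself so the
# upcast + loss are compiled together.
loss_fn = torch.compile(cross_entropy_loss)
```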

awgu avatar Aug 21 '24 16:08 awgu

@yifuwang I need to fix PP before this is landable 😢

awgu avatar Aug 21 '24 18:08 awgu

@H-Huang @wconstab do you have any idea if the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?

awgu avatar Aug 21 '24 20:08 awgu

Lost the local branch; going to reopen.

awgu avatar Oct 22 '24 14:10 awgu

> @H-Huang @wconstab do you have any idea if the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?

Sorry, didn't see this before.

This is just the output of the last layer? PP should be OK with that, with a couple of caveats:

1. If it's not the last layer, then we'll send/recv it, and IIRC we do not have NCCL support for 8-bit types, though 16-bit should be fine.
2. We also need to know the size when we allocate the send/recv buffers. In August this would have had to be configured at PipelineStage init time, when you pass the manual input_args/output_args values, but the recently landed shape inference runs at runtime based on the real tensors provided to the .step() API, so this should be totally transparent now.

Both of those points should not apply to the values produced as outputs of the last stage (or to the losses), so I'll need more details about what kind of error you're seeing.
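For illustration, a rough sketch of the two setups described above. Exact signatures vary across torch versions, and `stage_module`, `stage_idx`, `num_stages`, `device`, `loss_fn`, and `microbatch_inputs` are placeholders from the usual PP setup, not torchtitan code:

```python
import torch
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

# Assumes torch.distributed is already initialized and the PP setup provides
# stage_module, stage_idx, num_stages, device, loss_fn, microbatch_inputs.

# Older manual path (roughly, torch ~2.4): inter-stage buffer shapes/dtypes
# were fixed at construction via input_args/output_args, so bf16 activations
# had to be declared up front, e.g.:
#   PipelineStage(stage_module, stage_idx, num_stages, device,
#                 input_args=(torch.empty(mb, seq, dim,
#                                         dtype=torch.bfloat16, device=device),))

# Newer path: no shapes declared here; shape inference runs at runtime from
# the real tensors passed to .step(), so bf16 send/recv buffers are allocated
# transparently.
stage = PipelineStage(stage_module, stage_idx, num_stages, device)
schedule = ScheduleGPipe(stage, n_microbatches=8, loss_fn=loss_fn)
schedule.step(microbatch_inputs)  # later stages call schedule.step() with no inputs
```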

wconstab avatar Oct 22 '24 16:10 wconstab