Moved the logits `.float()` upcast into the loss and compiled the loss when compile is enabled
Stack from ghstack (oldest at bottom):
- #567
- -> #551
Compiling the loss improves performance, and moving the `.float()` upcast inside the compiled loss improves it further.
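For illustration, a minimal sketch of the idea (not necessarily the exact torchtitan code; `compile_model` stands in for the job config compile flag):

```python
import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast inside the loss so the compiler can fuse the fp32 cast with the
    # cross-entropy computation instead of materializing fp32 logits.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


compile_model = True  # stand-in for the job config compile flag
loss_fn = torch.compile(cross_entropy_loss) if compile_model else cross_entropy_loss

# Example: bf16 logits of shape [batch, seq, vocab] and integer labels.
logits = torch.randn(2, 8, 32, dtype=torch.bfloat16)
labels = torch.randint(0, 32, (2, 8))
loss = loss_fn(logits, labels)
```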
@yifuwang I need to fix PP before this is landable 😢
@H-Huang @wconstab do you have any idea if the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?
Lost the local branch; going to reopen.
> @H-Huang @wconstab do you have any idea if the output logits being fp32 is a hard requirement for PP? Is there any way we can leave them as bf16?
Sorry, didn't see this before.
This is just the outputs of the last layer? PP should be OK with this, with a few caveats:

1. If it's not the last layer, then we'll send/recv it, and IIRC we do not have NCCL support for 8-bit types, though 16-bit should be fine.
2. We also need to know the size when we allocate send/recv buffers. In August this would have had to be configured at `PipelineStage` init time by passing the manual `input_args`/`output_args` values, but recently landed shape inference runs at runtime based on the real tensors provided to the `.step()` API, so this should be totally transparent now (sketched below).
Both of those points should not apply to the values produced as outputs of the last stage (or to the losses), so I'll need more details about what kind of error you're seeing.
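For reference, a minimal sketch of the runtime shape-inference setup described above, assuming two ranks launched with torchrun; the toy module and loss are placeholders, and the `PipelineStage`/`ScheduleGPipe` arguments follow `torch.distributed.pipelining` at the time of writing and may differ across versions:

```python
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])  # assume 2 ranks / 2 stages
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")

# Each rank owns one toy stage; activations crossing the stage boundary are
# bf16, which NCCL can send/recv (unlike 8-bit types).
stage_mod = nn.Linear(1024, 1024, device=device, dtype=torch.bfloat16)
stage = PipelineStage(stage_mod, stage_index=rank, num_stages=world_size, device=device)

# No manual input_args/output_args: shape inference sizes the send/recv
# buffers from the real tensors passed to .step().
def toy_loss(output, target):
    return output.float().pow(2).mean()  # placeholder loss on the last stage

schedule = ScheduleGPipe(stage, n_microbatches=4, loss_fn=toy_loss)

if rank == 0:
    schedule.step(torch.randn(8, 1024, device=device, dtype=torch.bfloat16))
else:
    losses = []
    schedule.step(target=torch.zeros(8, 1024, device=device), losses=losses)

dist.destroy_process_group()
```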