[PROPOSAL]: FP8 with block-wise amax
Proposal
@kuozhang brought up in #6101 that FP8 with TP should all_reduce a global amax history.
However, based on my understanding of the code that creates the amax history, it seems to only create and update local (per-rank) scaling factors, with num_scale=1 meaning one factor over all features. This seems equivalent to computing block-wise amax using tp_size blocks, as in QLoRA, and should be more accurate than a single global amax (see the sketch below).
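For illustration, here is a minimal PyTorch sketch of what I mean by block-wise amax: one scaling factor per block of the feature dimension instead of a single factor over all features. The names `block_amax_scales`, `FP8_E4M3_MAX`, and `num_blocks` are my own for this sketch, not identifiers from the existing code.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in float8_e4m3fn (assumed format)

def block_amax_scales(x: torch.Tensor, num_blocks: int) -> torch.Tensor:
    """Return one FP8 scaling factor per block of the last (feature) dimension."""
    blocks = x.reshape(*x.shape[:-1], num_blocks, -1)  # [..., num_blocks, block_size]
    amax = blocks.abs().amax(dim=-1).clamp(min=1e-12)  # per-block amax
    return FP8_E4M3_MAX / amax                         # per-block scale

# Usage: e.g. num_blocks = tp_size, giving one scale per tensor-parallel shard.
x = torch.randn(4, 1024)
scales = block_amax_scales(x, num_blocks=8)
print(scales.shape)  # torch.Size([4, 8])
```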
In any case, I feel NVIDIA's method for tracking amax statistics is quite coarse: they only track the amax over all features within a history window, and they don't compare precision against other methods. In the future we could test:
- Computing block-wise amax instead of a single amax over all features.
- Other amax-history tracking methods, such as an exponential moving average (sketched below).

Feel free to submit PRs or correct my opinions if the community/team has time :)
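As a concrete starting point, here is a minimal sketch of EMA-based amax tracking; the class `EMAAmaxTracker` and the `decay` hyper-parameter are hypothetical and not part of the existing code.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest representable magnitude in float8_e4m3fn (assumed format)

class EMAAmaxTracker:
    """Hypothetical amax tracker using an exponential moving average."""

    def __init__(self, decay: float = 0.9):
        self.decay = decay
        self.ema_amax = None  # running EMA of the observed amax

    def update(self, x: torch.Tensor) -> torch.Tensor:
        amax = x.abs().amax()
        if self.ema_amax is None:
            self.ema_amax = amax
        else:
            self.ema_amax = self.decay * self.ema_amax + (1 - self.decay) * amax
        return self.ema_amax

# Usage: feed activations each step, then derive the FP8 scale from the smoothed amax.
tracker = EMAAmaxTracker(decay=0.9)
for _ in range(10):
    tracker.update(torch.randn(4, 1024))
scale = FP8_E4M3_MAX / tracker.ema_amax.clamp(min=1e-12)
```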
Self-service
- [ ] I'd be willing to do some initial work on this proposal myself.