TransformerEngine
[JAX] Rewrite the Format of FP8 Meta and Remove unused ShardingTypes.
Description
Reformatted FP8 meta to one set per tensor, removed fp8_max and scale_inv from the set of FP8 meta, and deleted unused functions and types.
Fixes # (issue)
This avoids unnecessary slicing of the FP8 meta, which unblocks the pipeliner so it can re-schedule the collectives.
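To make the layout change concrete, here is a minimal JAX sketch contrasting a stacked meta array, where every consumer must slice out its own row, with a one-set-per-tensor layout. The names (`stacked_amax_history`, `per_tensor_meta`, the `"input"`/`"kernel"`/`"grad"` keys) and the history length are hypothetical illustrations, not TransformerEngine's actual data structures.

```python
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
NUM_TENSORS = 3   # e.g. input, kernel, grad (assumed)
HISTORY_LEN = 4   # amax history length (assumed)

# Stacked layout (hypothetical): one array shared by all tensors of a layer.
stacked_amax_history = jnp.zeros((NUM_TENSORS, HISTORY_LEN), jnp.float32)
stacked_scale = jnp.ones((NUM_TENSORS, 1), jnp.float32)

def kernel_meta_from_stacked(amax_history, scale, idx=1):
    # Each consumer slices its row out of the shared buffer.
    return amax_history[idx], scale[idx]

# One-set-per-tensor layout (hypothetical): independent meta per tensor.
per_tensor_meta = {
    name: {
        "amax_history": jnp.zeros((HISTORY_LEN,), jnp.float32),
        "scale": jnp.ones((1,), jnp.float32),
    }
    for name in ("input", "kernel", "grad")
}

def kernel_meta_from_per_tensor(meta):
    # No slicing: the tensor's meta is already its own set of arrays.
    return meta["kernel"]["amax_history"], meta["kernel"]["scale"]
```

With the per-tensor layout each tensor's meta is its own pytree leaf, so reading it does not touch a buffer shared with other tensors.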
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to not work as expected) Existing FP8 checkpoints will not be compatible.
Changes
Please list the changes introduced in this PR:
- Reformat FP8 meta to be one-set-per-tensor.
- Remove `fp8_max` and `scale_inv` from the FP8 meta set.
- Remove unused functions in `fp8.py`, such as `update_fp8_metas`.
- Remove unused `ShardingType` and `MajorShardingType`.
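Dropping `fp8_max` and `scale_inv` from the meta set is possible because both can be derived: `fp8_max` is a constant of the FP8 format, and `scale_inv` is the reciprocal of `scale`. The sketch below assumes a common delayed-scaling recipe, `scale = fp8_max / max(amax_history) / 2**margin`; the function names `compute_scale`, `quantize`, and `dequantize` are hypothetical, and this is not necessarily how TransformerEngine computes these values.

```python
import jax.numpy as jnp

def compute_scale(amax_history, prev_scale, fp8_dtype=jnp.float8_e4m3fn, margin=0):
    # fp8_max is a property of the FP8 format (448.0 for E4M3FN),
    # so it can be looked up instead of stored in the meta set.
    fp8_max = float(jnp.finfo(fp8_dtype).max)
    amax = jnp.max(amax_history)
    # Assumed delayed-scaling formula; margin adds headroom in powers of two.
    new_scale = (fp8_max / amax) / (2.0 ** margin)
    # Keep the previous scale if amax is zero or non-finite.
    return jnp.where((amax > 0) & jnp.isfinite(amax), new_scale, prev_scale)

def quantize(x, scale, fp8_dtype=jnp.float8_e4m3fn):
    fp8_max = float(jnp.finfo(fp8_dtype).max)
    # Scale into the FP8 range, clip, then cast.
    return jnp.clip(x * scale, -fp8_max, fp8_max).astype(fp8_dtype)

def dequantize(x_fp8, scale):
    # scale_inv is recomputed on the fly rather than kept in the meta set.
    scale_inv = 1.0 / scale
    return x_fp8.astype(jnp.float32) * scale_inv
```

Recomputing these values is a few scalar operations per tensor, which keeps the stored meta down to the amax history and scale.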
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci jax
/te-ci jax
Can you share what speedup you see with this PR?
/te-ci jax
@denera Since this PR removes some old, unused APIs, we need to document that in the next release. Is there a place where we should add them to make sure they are included in the next release notes?
/te-ci jax
/te-ci jax
LGTM 👍. Would be interesting to see the diff in performance if any.
Changed to Draft while waiting for internal verification.
/te-ci jax