TransformerEngine

[JAX] Rewrite the Format of FP8 Meta and Remove unused ShardingTypes.

Open mingxu1067 opened this issue 1 year ago • 8 comments

Description

Reformatted FP8 meta to one set per tensor, removed fp8_max and scale_inv from the set of FP8 meta, and deleted unused functions and types.

Fixes # (issue): Avoids unnecessary slicing of the FP8 meta, which unblocks the pipeliner so that it can reschedule the collectives.

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [x] Breaking change (fix or feature that would cause existing functionality to not work as expected): existing FP8 checkpoints will not be compatible.

Changes

Please list the changes introduced in this PR:

  1. Reformat FP8 meta to be one-set-per-tensor.
  2. Remove fp8_max and scale_inv from FP8 meta set.
  3. Remove unused functions in fp8.py, like update_fp8_metas.
  4. Remove unused ShardingType and MajorShardingType.
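
To illustrate changes 1 and 2, here is a minimal sketch of the difference between a stacked FP8 meta layout and a one-set-per-tensor layout. It is not the code in this PR: the names `make_meta`, `scale_inv`, `fp8_max`, the dictionary keys, and the shapes are all hypothetical, and deriving `fp8_max` from the dtype and `scale_inv` from `scale` is an assumption about how the removed values could be recovered on demand.

```python
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
NUM_TENSORS = 3        # e.g. input, kernel, gradient
AMAX_HISTORY_LEN = 4

# Stacked layout (assumed "before"): one array holds the meta of every
# tensor, so each consumer has to slice out its own row.
stacked_amax = jnp.zeros((NUM_TENSORS, AMAX_HISTORY_LEN), dtype=jnp.float32)
stacked_scale = jnp.ones((NUM_TENSORS, 1), dtype=jnp.float32)
kernel_scale_from_stack = stacked_scale[1]  # slice needed on every use


def make_meta():
    """One FP8 meta set for a single tensor: amax history and scale only."""
    return {
        "amax": jnp.zeros((AMAX_HISTORY_LEN,), dtype=jnp.float32),
        "scale": jnp.ones((1,), dtype=jnp.float32),
    }


# One-set-per-tensor layout (assumed "after"): each tensor owns its meta,
# so no slicing is needed to reach it.
per_tensor_meta = {name: make_meta() for name in ("x", "kernel", "grad")}
kernel_scale = per_tensor_meta["kernel"]["scale"]


def scale_inv(meta):
    """Recompute the inverse scale on demand instead of storing it."""
    return 1.0 / meta["scale"]


def fp8_max(dtype):
    """Look up the largest finite value of an FP8 dtype instead of storing it."""
    return float(jnp.finfo(dtype).max)
```

In this sketch, `per_tensor_meta["kernel"]` can be passed directly to whatever consumes the kernel's meta, whereas the stacked layout needs an index into a shared array first.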

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

mingxu1067 avatar May 13 '24 15:05 mingxu1067

/te-ci jax

mingxu1067 avatar May 13 '24 16:05 mingxu1067

/te-ci jax

mingxu1067 avatar May 13 '24 18:05 mingxu1067

Can you tell what speed up you see with this PR?

nouiz avatar May 13 '24 21:05 nouiz

/te-ci jax

mingxu1067 avatar May 13 '24 22:05 mingxu1067

@denera As this PR removes some old, unused APIs, we need to have that documented in the next release. Is there a place where we should add them to make sure they are included in the next release?

nouiz avatar May 14 '24 17:05 nouiz

/te-ci jax

mingxu1067 avatar May 14 '24 18:05 mingxu1067

/te-ci jax

mingxu1067 avatar May 14 '24 20:05 mingxu1067

LGTM 👍. Would be interesting to see the diff in performance if any.

phu0ngng avatar May 15 '24 18:05 phu0ngng

Changed to Draft while waiting for internal verification.

mingxu1067 avatar May 21 '24 23:05 mingxu1067

/te-ci jax

mingxu1067 avatar Jun 11 '24 19:06 mingxu1067