TransformerEngine

Add high_precision_init_val to model params when using fp8_model_init

Open • kunlunl opened this issue 1 year ago • 8 comments

Description

When using fp8_model_init to create a model, the weights are cast to Float8Tensor. However, when high-precision (FP32) master weights are needed, initializing them from these FP8 weights can hurt loss convergence compared with initializing them from the bf16/fp16 values (especially in the early stages of training). This PR stores the original bf16/fp16 params as CPU tensors within the FP8 weights, so that frameworks such as MCore (Megatron-Core) can use them to initialize the master weights.
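The convergence concern can be illustrated with a small pure-Python simulation. It is a sketch under stated assumptions, not TransformerEngine code: the E4M3 cast is emulated in software (round half to even, saturating at 448, subnormal spacing 2**-9), Python floats stand in for the bf16 init values, the scale is left at 1, and `SimFP8Param` with its `high_precision_init_val` attribute is a hypothetical stand-in modelled on the PR title. It compares master weights built from the FP8 values with master weights built from a preserved high-precision copy.

```python
import math
import random

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
MANT_BITS = 3     # explicit mantissa bits in E4M3
MIN_EXP = -6      # smallest normal E4M3 value is 2**-6


def fp8_roundtrip(x, scale=1.0):
    """Simulate casting x * scale to FP8 E4M3 and back, then undoing the scale."""
    a = abs(x) * scale
    if a == 0.0:
        return 0.0
    e = max(math.frexp(a)[1] - 1, MIN_EXP)    # binade of a; subnormals reuse the 2**-6 spacing
    step = 2.0 ** (e - MANT_BITS)              # gap between neighbouring E4M3 values there
    q = min(round(a / step) * step, E4M3_MAX)  # round half to even, then saturate
    return math.copysign(q, x) / scale


class SimFP8Param:
    """Hypothetical stand-in for an FP8 weight that can carry its high-precision init value."""

    def __init__(self, values, scale=1.0, keep_init_val=False):
        self.data = [fp8_roundtrip(v, scale) for v in values]
        # Hypothetical attribute; the PR keeps this copy on the CPU.
        self.high_precision_init_val = list(values) if keep_init_val else None


rng = random.Random(0)
init_vals = [rng.gauss(0.0, 0.02) for _ in range(4096)]  # stand-in for bf16 init values
param = SimFP8Param(init_vals, keep_init_val=True)        # scale left at 1 (assumption)

master_from_fp8 = list(param.data)
master_from_copy = list(param.high_precision_init_val)

underflowed = sum(1 for o, q in zip(init_vals, master_from_fp8) if o != 0.0 and q == 0.0)
max_rel_err = max(abs(q - o) / abs(o) for o, q in zip(init_vals, master_from_fp8) if o != 0.0)
print(f"underflowed to zero: {underflowed}/{len(init_vals)}, max relative error: {max_rel_err:.2f}")
```

With a scale of 1, small weights fall into the E4M3 subnormal range or flush to zero, so master weights rebuilt from the FP8 values lose information that the preserved copy still has.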

Fixes # (issue)

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactor

Changes

Please list the changes introduced in this PR:

  • Stores the original bf16/fp16 params as CPU tensors within the FP8 weights

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

kunlunl • Aug 19 '24

Hmmm, I see the problem you are trying to solve, although I don't think I like the approach (don't think I have an alternative ready yet though :-( ). Did you consider any other ways to solve this issue?

ptrendx • Aug 19 '24

I also feel it is a very ugly approach, but I can't think of a better way to do it. I asked @timmoon10 whether he had any insight, and I quote his ideas below; @timmoon10, feel free to add more comments.

  • Storing on CPU is a pretty hacky approach, so we could modify te.fp8_model_init so that the CPU copy is optional.
  • I wonder how much of the problem is caused by the initial scaling factor being 1, which is likely too low and results in many underflows. One approach is to do the FP8 cast twice: once to find the amax and again with an optimal scaling factor. One problem is that this doesn't handle tensor parallelism well, since we want the amaxes to be synchronized over the TP group. We would either need many max all-reduces over the TP group (one per param) or structural changes in how we initialize FP8 params.
  • If we have master weights in FP32, how about perturbing them to approximate the original distribution:
    • fp32_params = fp8_params.from_float8()
    • fp32_params += torch.abs(fp32_params) * (torch.rand_like(fp32_params) - 0.5) * fp8_eps
    • This is simple, but the numerics are subtly different, and I'm not sure whether it would also affect convergence.
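The two-pass cast in the second idea can be sketched in pure Python. This is an illustration under assumptions, not TE's implementation: the E4M3 cast is emulated in software, the scale is chosen simply as `448 / amax` (TE's actual recipe may add a margin), and the tensor-parallel group is simulated by splitting one weight into two shards, where taking the max of the local amaxes plays the role of a MAX all-reduce.

```python
import math
import random

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value
MANT_BITS = 3     # explicit mantissa bits in E4M3
MIN_EXP = -6      # smallest normal E4M3 value is 2**-6


def fp8_roundtrip(x, scale=1.0):
    """Simulate casting x * scale to FP8 E4M3 and back, then undoing the scale."""
    a = abs(x) * scale
    if a == 0.0:
        return 0.0
    e = max(math.frexp(a)[1] - 1, MIN_EXP)    # binade of a; subnormals reuse the 2**-6 spacing
    step = 2.0 ** (e - MANT_BITS)              # gap between neighbouring E4M3 values there
    q = min(round(a / step) * step, E4M3_MAX)  # round half to even, then saturate
    return math.copysign(q, x) / scale


def cast_stats(values, scale):
    """Return (#nonzero values that underflow to 0, mean absolute round-trip error)."""
    out = [fp8_roundtrip(v, scale) for v in values]
    underflow = sum(1 for v, q in zip(values, out) if v != 0.0 and q == 0.0)
    mean_err = sum(abs(q - v) for v, q in zip(values, out)) / len(values)
    return underflow, mean_err


rng = random.Random(0)
weight = [rng.gauss(0.0, 0.02) for _ in range(4096)]

# Pass 1: each simulated tensor-parallel rank computes the amax of its own shard.
shards = [weight[:2048], weight[2048:]]
local_amax = [max(abs(v) for v in shard) for shard in shards]
# A MAX all-reduce over the TP group would hand every rank this same value.
global_amax = max(local_amax)

# Pass 2: cast again with a scale that maps the amax onto the FP8 maximum.
scale = E4M3_MAX / global_amax

stats_unit = cast_stats(weight, 1.0)
stats_amax = cast_stats(weight, scale)
print("scale=1:", stats_unit, "amax-based scale:", stats_amax)
```

In a real TP setting the all-reduce would run once per parameter unless the initialization is restructured to batch them, which is the trade-off raised above.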

I feel that method 1 (make the CPU copy optional) and method 3 (add a random perturbation to the master weights) are the more feasible ones. You can decide whether to adopt method 1, and I will test whether method 3 helps convergence.

However, for method 3, my concerns are:

  • The distributions of the master weights and the FP8 weights may become inconsistent after adding random perturbations. Even if I make them consistent with some hard-coded method, their distributions will diverge again if the initialization parameters of the FP8 weights change in the future.
  • Even if my tests show that it helps convergence, it may not work for models I haven't tested; after all, I can't test every model.
  • (In addition, I'm not sure whether the introduction of such random perturbations can be accepted by MCore.)
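The first concern can be probed with a small pure-Python simulation of method 3. It reads the perturbation as relative, |w| * (U - 0.5) * fp8_eps with U uniform in [0, 1), and assumes fp8_eps = 2**-3, the E4M3 spacing just above 1.0. The E4M3 cast is emulated in software (round half to even, saturating at 448), which approximates but is not TE's cast kernel. The sketch counts how many perturbed master weights no longer cast back to their own FP8 weight.

```python
import math
import random

E4M3_MAX = 448.0           # largest finite FP8 E4M3 value
MANT_BITS = 3              # explicit mantissa bits in E4M3
MIN_EXP = -6               # smallest normal E4M3 value is 2**-6
FP8_EPS = 2.0 ** -MANT_BITS  # assumed fp8_eps: E4M3 spacing just above 1.0 (0.125)


def fp8_roundtrip(x, scale=1.0):
    """Simulate casting x * scale to FP8 E4M3 and back, then undoing the scale."""
    a = abs(x) * scale
    if a == 0.0:
        return 0.0
    e = max(math.frexp(a)[1] - 1, MIN_EXP)    # binade of a; subnormals reuse the 2**-6 spacing
    step = 2.0 ** (e - MANT_BITS)              # gap between neighbouring E4M3 values there
    q = min(round(a / step) * step, E4M3_MAX)  # round half to even, then saturate
    return math.copysign(q, x) / scale


rng = random.Random(0)
init_vals = [rng.gauss(0.0, 0.02) for _ in range(4096)]
scale = E4M3_MAX / max(abs(v) for v in init_vals)
fp8_vals = [fp8_roundtrip(v, scale) for v in init_vals]  # what from_float8() would return

# Method 3: start from the FP8 values and add a relative random perturbation.
master = [q + abs(q) * (rng.random() - 0.5) * FP8_EPS for q in fp8_vals]

# Does each perturbed master weight still cast back to the FP8 weight it came from?
mismatch = sum(1 for m, q in zip(master, fp8_vals) if fp8_roundtrip(m, scale) != q)
print(f"{mismatch}/{len(master)} perturbed master weights no longer match their FP8 weight")
```

A nonzero count arises because a perturbation of up to half of fp8_eps relative to |w| can exceed half the gap to the neighbouring FP8 value. Bounding each perturbation below that half-gap would keep master and FP8 weights consistent, at the cost of a narrower spread.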

kunlunl • Aug 20 '24

Right... The best option would have been to create the master weights first, but that is not really possible given PyTorch's API.

Ok, so let's maybe do this:

  • add an option preserve_high_precision_initialization to fp8_model_init that triggers saving the copy on the CPU. We should document it properly.
  • add a method to the FP8 parameters that clears the high-precision weights, so that they can be freed once they have been copied into the master weights.

Then pretraining users can enable this option, while inference, LoRA, etc., where the weights come pretrained, will not incur the CPU memory overhead.

ptrendx • Aug 20 '24

Ok, I'll do this.

kunlunl • Aug 21 '24

@ptrendx I've finished the revision. Could you help find someone to review it?

kunlunl • Aug 27 '24

Added a unit test and signed off my commit.

kunlunl • Aug 28 '24

@timmoon10 Can you take a look and merge this?

kunlunl • Sep 03 '24

Quoting @ksivaman: "Are we reliant on the user to clear the additional memory? An alternative could be to free this memory during the forward pass or something, if the only use here is to initialize master weights."

@ksivaman Yes. I originally had two ways in mind to clear this variable: 1. clear it inside the get_xxx method, so that the memory is reclaimed automatically after the user accesses it once; 2. have the user clear it manually by calling a clear_xxx method. I chose the latter because, when I used it in MCore, I needed to access it more than once.
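The two clearing strategies can be contrasted in a few lines. The class and method names below (`FreeOnGet`, `ManualClear`, `get_init_val`, `clear_init_val`) are hypothetical stand-ins for the elided get_xxx / clear_xxx methods, not TE code.

```python
class FreeOnGet:
    """Strategy 1 (hypothetical): the getter hands the copy over and drops its own reference."""

    def __init__(self, values):
        self._init_val = list(values)

    def get_init_val(self):
        value, self._init_val = self._init_val, None  # freed after the first access
        return value


class ManualClear:
    """Strategy 2 (hypothetical): the getter can be called repeatedly; the caller frees explicitly."""

    def __init__(self, values):
        self._init_val = list(values)

    def get_init_val(self):
        return self._init_val

    def clear_init_val(self):
        self._init_val = None


vals = [0.01, -0.02]

a = FreeOnGet(vals)
first_a, second_a = a.get_init_val(), a.get_init_val()  # second read already sees None

b = ManualClear(vals)
first_b, second_b = b.get_init_val(), b.get_init_val()  # both reads succeed
b.clear_init_val()                                       # caller frees when done
```

A caller that needs the value more than once breaks under strategy 1 but works under strategy 2, which is the trade-off behind choosing the manual method.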

I think what you said makes a lot of sense, but besides automatically reclaiming this resource after the first forward pass, should we still keep the manual delete method? That way, users can reclaim this resource early when needed and avoid some corner cases (for example, running out of CPU memory during the first forward pass, although I'm not sure whether this can happen...).

Also, do you have any suggestions on where to put the code for "automatically reclaiming this resource after the first forward"? Should I put it in the forward() of each module? (This would require modifying the forward() code of every module.)
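One way to avoid editing every forward() is a single post-forward hook registered on the top-level module. In PyTorch the corresponding mechanism is `torch.nn.Module.register_forward_hook`, which calls `hook(module, args, output)` after `forward()` and returns a handle whose `remove()` detaches the hook. To stay runnable without PyTorch, the sketch below mocks that behaviour with a tiny `ModuleSketch` class. All names are hypothetical, and whether TE should free the copy this way is an assumption, not something decided here. The manual `clear_init_val()` is kept, so users can still free memory before the first forward.

```python
class ParamSketch:
    """Hypothetical parameter holding weights plus an optional high-precision init copy."""

    def __init__(self, values):
        self.values = list(values)
        self._init_val = list(values)

    def get_init_val(self):
        return self._init_val

    def clear_init_val(self):
        self._init_val = None  # idempotent, so manual and automatic clearing can coexist


class ModuleSketch:
    """Mock of a module whose __call__ runs forward() and then any forward hooks."""

    def __init__(self, params):
        self.params = params
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def remove_forward_hook(self, hook):
        self._forward_hooks.remove(hook)

    def forward(self, x):
        return sum(xi * wi for xi, wi in zip(x, self.params[0].values))

    def __call__(self, x):
        out = self.forward(x)
        for hook in list(self._forward_hooks):  # iterate over a copy: hooks may remove themselves
            hook(self, (x,), out)
        return out


def free_init_vals_after_first_forward(module):
    """Register one hook that clears every param's init copy after the first forward."""

    def hook(mod, args, output):
        for p in mod.params:
            p.clear_init_val()
        mod.remove_forward_hook(hook)  # one-shot: later forwards pay nothing

    module.register_forward_hook(hook)


model = ModuleSketch([ParamSketch([0.5, -0.25])])
free_init_vals_after_first_forward(model)
before = model.params[0].get_init_val()
y = model([1.0, 2.0])
after = model.params[0].get_init_val()

# A user can still free the copy early; the hook then has nothing left to free.
early = ModuleSketch([ParamSketch([1.0])])
free_init_vals_after_first_forward(early)
early.params[0].clear_init_val()
early([3.0])
```

Because the hook lives outside the modules, no module's forward() has to change, and the one-shot removal means steady-state training carries no extra cost.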

Regarding your comment "but maybe we could think of something clearer since this is a documented arg": do you have any suggestions for this as well?

kunlunl • Sep 03 '24