
How to use float8 for training?

Open vgoklani opened this issue 1 year ago • 7 comments

Are there any examples for training the MLP blocks using float8 from torchao?

Thanks!

vgoklani avatar Dec 23 '24 17:12 vgoklani

Hi @vgoklani, we don't currently support this, but you could modify a recipe to call torchao.float8.convert_to_float8_training on your model at the end of this function.
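Roughly, that would look like the sketch below (not a supported torchtune path, just an illustration; the model variable and the "output" name for the final projection are assumptions):

from torchao.float8 import convert_to_float8_training

def module_filter_fn(mod, fqn):
    # Skip the final output projection: float8 kernels need dims divisible by 16.
    # "output" is an assumption about that layer's name in the recipe's model.
    return fqn != "output"

convert_to_float8_training(model, module_filter_fn=module_filter_fn)

The conversion happens in place and should run before torch.compile and before sharding the model.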

However, I recommend using QLoRA, where frozen base model params are quantized to a lower precision (NF4), while the trainable adapter params are kept in a higher precision.

Here's an example config. The QLoRA model builder replaces the model's linear layers with LoRALinear(..., quantize_base=True) layers. If you want to use float8 instead of NF4, you can modify the LoRALinear class.
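For reference, what the QLoRA builder does under the hood looks roughly like this (an illustrative sketch, not the builder's exact code; the dimensions here are made up):

from torchtune.modules.peft import LoRALinear

# With quantize_base=True, the frozen base weight is stored in NF4 and
# dequantized on the fly; only the small LoRA matrices train in higher precision.
layer = LoRALinear(
    in_dim=4096,
    out_dim=4096,
    rank=8,
    alpha=16,
    quantize_base=True,
)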

Let me know if you have any questions!

calvinpelletier avatar Dec 23 '24 18:12 calvinpelletier

Thanks @calvinpelletier. We are using the full-finetune scripts, and since the hardware already supports FP8, we are just leaving a lot of performance on the table... We can add it to our internal version, but I would imagine that there are other groups that want this included too.

vgoklani avatar Dec 23 '24 22:12 vgoklani

We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!

calvinpelletier avatar Dec 23 '24 23:12 calvinpelletier

I was working on adding INT8 training to torchtune in #1552, and FP8 also came up in that discussion. Once the INT8 PR is merged, we can make another one for FP8, since it follows a similar design.

gau-nernst avatar Dec 24 '24 05:12 gau-nernst

Thank you @calvinpelletier and @gau-nernst

Using dynamic scaling with the torchao API was trivial and gave a ~30% performance boost in tokens per second.

We're running on 4x NVIDIA A6000 Ada cards (SM89)


from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

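# Dynamic (per-tensor) scaling for inputs, weights, and grad outputs;
# float8 all-gather lets FSDP communicate sharded weights in float8.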
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

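# Replaces the MLP's nn.Linear modules with float8 training linears, in place.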
convert_to_float8_training(mlp, config=config)

Strangely enough, using DELAYED scaling crashed torch.compile... will need to dig into that further...

vgoklani avatar Dec 24 '24 16:12 vgoklani

@vgoklani Delayed scaling is not as well-supported as dynamic scaling I think. Should be fine to stick to dynamic scaling.

Curious: do you observe any convergence issues?

gau-nernst avatar Dec 25 '24 01:12 gau-nernst

@gau-nernst The loss was very close to bfloat16! I'm looking forward to int8 training :)

vgoklani avatar Dec 25 '24 02:12 vgoklani

@vgoklani any trick to getting it working other than using the torchao convert on the model? I got an error that a certain operation (or I think attribute?) can't be sharded:

[rank0]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 450, in propagate_op_sharding_non_cached
[rank0]:     raise NotImplementedError(
[rank0]: NotImplementedError: Operator aten.is_pinned.default does not have a sharding strategy registered.

nazrak-atlassian avatar Feb 12 '25 12:02 nazrak-atlassian