How to use float8 for training?
Are there any examples for training the MLP blocks using float8 from torchao?
Thanks!
Hi @vgoklani, we don't currently support this, but you could modify a recipe to call torchao.float8.convert_to_float8_training on your model at the end of this function.
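Roughly something like this (an untested sketch; the Sequential below is just a stand-in for the model the recipe builds):

import torch.nn as nn
from torchao.float8 import convert_to_float8_training

# Stand-in model; in the recipe you would call convert_to_float8_training
# on the real model right before returning it from the setup function.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

# Swaps each nn.Linear for a Float8Linear that casts to float8 for its matmuls.
convert_to_float8_training(model)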
However, I recommend using QLoRA, where frozen base model params are quantized to a lower precision (NF4), while the trainable adapter params are kept in a higher precision.
Here's an example config. The QLoRA model builder replaces the model's linear layers with LoRALinear(..., quantize_base=True) layers. If you want to use float8 instead of NF4, you can modify the LoRALinear class.
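If it helps, here is a rough sketch of the kind of layer the QLoRA builder produces (the dims, rank, and alpha here are placeholder values, and the exact LoRALinear signature may vary across torchtune versions):

from torchtune.modules.peft import LoRALinear

# The frozen base weight is quantized to NF4; only the small low-rank
# adapter matrices are trained in higher precision.
qlora_layer = LoRALinear(
    in_dim=4096,
    out_dim=4096,
    rank=8,
    alpha=16,
    quantize_base=True,
)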
Let me know if you have any questions!
Thanks @calvinpelletier. We are using the full-finetune scripts, and since the hardware already supports FP8, we are just leaving a lot of performance on the table... We can add it to our internal version, but I would imagine that there are other groups that want this included too.
We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!
I was working on adding INT8 training to torchtune in #1552, and FP8 also came up in the discussion. Once the INT8 PR is merged, we can make another one for FP8 too, since it follows a similar design.
Thank you @calvinpelletier and @gau-nernst
Using dynamic scaling with the torchao API was trivial and gave a ~30% performance boost in tokens per second.
We're running on 4x NVIDIA A6000 Ada cards (SM89).
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# Dynamic scaling for inputs, weights, and grad outputs, with the
# FSDP float8 all-gather optimization enabled.
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# Swap the MLP's nn.Linear layers for Float8Linear
convert_to_float8_training(mlp, config=config)
Strangely enough, using DELAYED scaling crashed torch.compile... will need to dig into that further...
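For reference, a delayed-scaling version of the config above would look roughly like this (a sketch only; details may differ across torchao versions). If I remember the torchao docs correctly, delayed scaling also needs sync_float8_amax_and_scale_history(model) called between the backward pass and optimizer.step(), which may be part of the torch.compile interaction:

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
    sync_float8_amax_and_scale_history,
)

delayed_config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
convert_to_float8_training(mlp, config=delayed_config)

# In the training loop, after loss.backward() and before optimizer.step():
# sync_float8_amax_and_scale_history(mlp)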
@vgoklani Delayed scaling is not as well supported as dynamic scaling, I think, so it should be fine to stick with dynamic scaling.
Curious: do you observe any convergence issues?
@gau-nernst The loss was very close to bfloat16! I'm looking forward to int8 training :)
@vgoklani any trick to getting it working other than using the torchao convert on the model? I got an error that a certain operation (or I think attribute?) can't be sharded:
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_sharding_prop.py", line 450, in propagate_op_sharding_non_cached
[rank0]: raise NotImplementedError(
[rank0]: NotImplementedError: Operator aten.is_pinned.default does not have a sharding strategy registered.