Question: How to use Float8InferenceLinear with FSDP1/2?
Hey Team,
I'm trying to use FSDP1/2 with Float8InferenceLinear but seem to be running into some issues (with torch 2.3.1+cu118). Would you suggest bumping to a higher version of torch and trying again, or maybe using the training setup without the inference layer? I also tried using the Float8Linear layer without calling the quantization function that converts to Float8InferenceLinear, but hit issues with FSDP1: when computing the amax, some input x tensors are empty (x.numel() == 0) and some are NaN.
Best regards, QQ
cc @drisspg @jainapurva
Unfortunately Float8InferenceLinear is being developed against the latest PyTorch nightly and is not well tested on older versions of PyTorch. If it is possible for you to update your PyTorch version, that is recommended. If the problem still persists after updating and you are able to create a minimal reproducer, we can look into this.
@drisspg Got it. Thank you! To confirm, torch==2.5.0dev should be the right one to use?
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall
Today this would indeed install a 2.5 dev build for today's date, but yeah, generally for any feature leveraging torch.compile you want to either be on the latest stable (today this is 2.4) or use nightlies.
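To double-check which build you actually end up with:

import torch

print(torch.__version__)           # e.g. a 2.5.0.dev<date>+cu121 string for a nightly build
print(torch.version.cuda)          # CUDA version the wheel was built against
print(torch.cuda.is_available())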
Thanks @msaroufim! Is the cu118 version also supported and tested (if I disable torch.compile and the FSDP2 DTensor path and just use FSDP1)? Let me do a quick test and check. Thank you!
A quick update: it turns out there might be some issues with torch==2.5.0.dev20240819+cu118 installed from https://download.pytorch.org/whl/nightly/cu118:
[rank3]: attn_output = torch.nn.functional.scaled_dot_product_attention(
[rank3]: RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.
Exploring other options now; I'll probably have to use the 12.1 runtime version instead.
Update: it seems the root cause is libnvrtc loading issues.
for 11.8: Could not load library libnvrtc.so.11.2. Error: libnvrtc.so.11.2: cannot open shared object file: No such file or directory Could not load library libnvrtc.so. Error: libnvrtc.so: cannot open shared object file: No such file or directory
for 12.1: Could not load library libnvrtc.so.12. Error: libnvrtc.so.12: cannot open shared object file: No such file or directory
Investigating now.
I'd try isolating things in a fresh conda environment. Also, if you're mucking around with CUDA versions, keep in mind that the torchao binaries on PyPI are built against CUDA 12.1, so I'd recommend installing ao from source or downloading it from the PyTorch index.
Thank you! Resolved the above issue by adding the path to LD_LIBRARY_PATH, and am currently testing fp8 with the latest ao build + torch 2.5.0 dev cu121 as suggested 🤞
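In case it helps anyone else hitting the same thing, a quick way to confirm whether the nvrtc libraries resolve from the current environment (the .so names depend on the CUDA version):

import ctypes

# dlopen respects LD_LIBRARY_PATH, so this checks whether the nvrtc libraries
# are discoverable from the current environment
for name in ("libnvrtc.so.12", "libnvrtc.so.11.2"):
    try:
        ctypes.CDLL(name)
        print(f"loaded {name}")
    except OSError as exc:
        print(f"could not load {name}: {exc}")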
Faced the same issue when testing the Mixtral 8x7B model (the gated routing layer has been excluded from conversion) with the layer-replacement + FSDP code below:
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 934, in forward
[rank2]: hidden_states, router_logits = self.block_sparse_moe(hidden_states)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 861, in forward
[rank2]: current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/transformers/models/mixtral/modeling_mixtral.py", line 797, in forward
[rank2]: current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 360, in forward
[rank2]: input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_linear.py", line 253, in cast_input_to_float8
[rank2]: _maybe_initialize_amaxes_scales_for_float8_cast(
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_scaling_utils.py", line 119, in _maybe_initialize_amaxes_scales_for_float8_cast
[rank2]: new_amax = tensor_to_amax(x, reduce_amax=reduce_amax)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_utils.py", line 102, in tensor_to_amax
[rank2]: amax = torch.max(torch.abs(x))
[rank2]: RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.
# Define the FSDP configuration
import functools

def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
    # Define the set of layers that you want to wrap
    layers_to_wrap = {MixtralDecoderLayer}
    # Check if the module is in the set of layers to wrap
    return type(module) in layers_to_wrap

if args.enable_fp8:
    # from train_utils import patch_torch
    # patch_torch()
    from torchao.float8 import (  # precompute_float8_dynamic_scale_for_fsdp, # specific to fsdp2 + dynamic scaling, apply after each training loop iter
        CastConfig,
        Float8LinearConfig,
        ScalingType,
        convert_to_float8_training,
    )

    config = Float8LinearConfig(
        # enable_amax_init=True,  # only needed for autocast + compile + FSDP + float8 delayed
        # enable_pre_and_post_forward=True,  # only needed for autocast + compile + FSDP + float8 delayed
        # enable_fsdp_float8_all_gather=True,
        cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
        cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
    )

    # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling type
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the output module
        if "lm_head" in fqn:
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if "block_sparse_moe.gate" in fqn:
                print(f"Ignore router layer replacement {fqn}")
                # if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    convert_to_float8_training(
        model,
        config=config,
        module_filter_fn=module_filter_fn,
    )

    from torchao.float8.inference import (
        ActivationCasting,
        Float8InferenceLinear,
        QuantConfig,
        quantize_to_float8,
    )

    quant_config = QuantConfig(ActivationCasting.DYNAMIC)
    # quantize_to_float8(model, quant_config)

print(model)

torch.distributed.constants.default_pg_timeout = timedelta(seconds=7200)
fsdp_config = FSDP(
    model,
    auto_wrap_policy=custom_auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    # state_dict_type="sharded",
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        # buffer_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
    use_orig_params=True,
)
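For context, the RuntimeError above is just `torch.max` over a zero-element tensor, i.e. the per-expert input when an expert receives no routed tokens. A minimal illustration outside of torchao (the guard at the end is hypothetical, not a torchao fix):

import torch

x = torch.empty(0)               # what an expert's input looks like when it gets no tokens
try:
    torch.max(torch.abs(x))      # same call pattern as tensor_to_amax in the trace above
except RuntimeError as exc:
    print(exc)                   # max(): Expected reduction dim to be specified for input.numel() == 0 ...

# a local guard of this kind would sidestep it (hypothetical, not a torchao fix):
amax = torch.abs(x).max() if x.numel() > 0 else torch.zeros((), dtype=x.dtype)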
Also tried uncommenting the line `quantize_to_float8(model, quant_config)` to replace layers with Float8InferenceLinear, and got an error when wrapping that layer with FSDP (I tried modifying the autocast_to_copy code a bit but still got some other errors):
[rank4]: File "/export/home/qsong/torch_fsdp_inference.py", line 222, in main
[rank4]: fsdp_config = FSDP(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 483, in __init__
[rank4]: _auto_wrap(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank4]: _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 545, in _recursive_wrap
[rank4]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 563, in _recursive_wrap
[rank4]: return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 492, in _wrap
[rank4]: return wrapper_cls(module, **kwargs)
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 509, in __init__
[rank4]: _init_param_handle_from_module(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 612, in _init_param_handle_from_module
[rank4]: _move_module_to_device(
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 1005, in _move_module_to_device
[rank4]: _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id)
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 1035, in _move_states_to_device
[rank4]: param.data = param.to(device_from_device_id)
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_tensor.py", line 359, in __torch_dispatch__
[rank4]: return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank4]: File "/export/home/qsong/miniconda3/envs/local_pytorch_env/lib/python3.10/site-packages/torchao/float8/float8_ops.py", line 244, in autocast_to_copy
[rank4]: len(kwargs) == 1 and "dtype" in kwargs
[rank4]: AssertionError: Only support dtype kwarg for autocast
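Looking at the trace, `_move_states_to_device` calls `param.to(device_from_device_id)`, which dispatches `aten._to_copy` with a `device` kwarg, while the `autocast_to_copy` override only accepts a `dtype` kwarg. The kwarg difference is easy to see on a plain tensor (a sketch, not torchao code):

import torch

x = torch.randn(4)
# param.to(dtype) style call -> only a `dtype` kwarg, which the override accepts
torch.ops.aten._to_copy.default(x, dtype=torch.bfloat16)
# param.to(device) style call (what FSDP does in _move_states_to_device) -> a `device`
# kwarg; on a Float8Tensor this is what hits the "Only support dtype kwarg" assertion
torch.ops.aten._to_copy.default(x, device="cpu")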
@qingquansong thanks! do you have a minimal repro so we can take a look?
Let me create a mini mixtral model with some synthetic data.
Update:
- For the first issue: it is caused by some NaN values, which result in some Mixtral experts receiving no tokens. Something weird is that FSDP loading itself seems to have issues with the loaded weights (even without the FP8 layers) and behaves differently with different wrapping policies; it is likely related to the reduction precision being set to bfloat16, plus `sync_module_states` needing to be set to `True`. I need to debug this a bit more, but I can confirm that if the weights are loaded correctly and each expert receives at least one token, this error is resolved. An extra problem is:
[rank7]: ValueError: The module has CPU parameters or buffers when sync_module_states=True, which requires them to be on GPU. Please specify the device_id argument or move the module to GPU before passing it to FSDP.
It seems the original model weight parameters are still allocated on CPU if we enable the FP8 `convert_to_float8_training` conversion, which causes this. I temporarily set `sync_module_states=True` and `enable_pre_and_post_forward=False` to work around it, but I'm not sure this is the correct way.
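For the weight-loading check, a rough per-rank scan like the sketch below (where `fsdp_model` stands in for the FSDP-wrapped module from the script further down) is enough to confirm whether any parameter comes back NaN after wrapping:

import torch
import torch.distributed as dist

# rough per-rank sanity check after FSDP wrapping / weight loading;
# `fsdp_model` is a placeholder for the wrapped module
for name, p in fsdp_model.named_parameters():
    if p.numel() > 0 and torch.isnan(p).any():
        print(f"[rank {dist.get_rank()}] NaN in {name}")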
Another thing I'm not sure about: if I just want to do inference, how should I set the following 6 args plus the FSDP args? Speed slows down with the FP8 layers in this case, and memory is also not reduced as much as expected. Setting the input cast config to DYNAMIC makes things faster, but still only comparable to bf16 for Mixtral 8x7B (an all-dynamic variant I'm considering is sketched after the config):
config = Float8LinearConfig(
    # enable_amax_init=True,  # only needed for autocast + compile + FSDP + float8 delayed
    # enable_pre_and_post_forward=False,  # only needed for autocast + compile + FSDP + float8 delayed
    # enable_fsdp_float8_all_gather=False,
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
)
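One variant I'm considering for the forward-only case (just a sketch, I'm not sure it's the recommended setup) is all-dynamic scaling, so no delayed-scaling amax history has to be maintained:

from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType

# all-dynamic scaling: no delayed-scaling amax history / buffers to maintain,
# which seems simpler for a forward-only run
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)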
- For the second Float8InferenceLinear issue, it can be reproduced with the script below, run with the command below (torch and ao are both the latest versions with cu121, tested on 8x H100; the model config can be made smaller to speed things up). I commented out the line `quantize_to_float8(model, quant_config)`, so the script should run smoothly as-is; uncommenting it will raise the issue with the Float8InferenceLinear layer.
ACCELERATE_USE_FSDP=1 FSDP_CPU_RAM_EFFICIENT_LOADING=1 torchrun --nnodes=1 --nproc-per-node=8 torch_fsdp_inference_mini.py \
--batch_size 16 \
--enable_fp8
**Save this script as torch_fsdp_inference_mini.py**
import os
from datetime import timedelta
import argparse
from dataclasses import _MISSING_TYPE, dataclass

import torch
import torch.distributed as dist
# from config import parse_args  # (leftover local import; parse_args is defined below)
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
import numpy as np
from torch.utils.data import Dataset, DataLoader


class SyntheticDataset(Dataset):
    def __init__(self, num_samples, max_length):
        self.num_samples = num_samples
        self.max_length = max_length
        self.input_ids = np.random.randint(0, num_samples, (num_samples, max_length))
        self.attention_mask = np.ones((num_samples, max_length), dtype=np.int32)
        self.labels = self.input_ids

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }


def get_distributed_dataloader(batch_size, shuffle=True):
    dataset = SyntheticDataset(num_samples=512, max_length=4096)
    sampler = DistributedSampler(
        dataset,
        shuffle=shuffle,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
    )
    return dataloader
def configure_model():
    from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM

    mini_model_config = MixtralConfig(
        attention_dropout=0.0,
        bos_token_id=1,
        eos_token_id=2,
        hidden_act="silu",
        hidden_size=4096,
        initializer_range=0.02,
        intermediate_size=14336,
        max_position_embeddings=32768,
        num_attention_heads=32,
        num_experts_per_tok=2,
        num_hidden_layers=1,
        num_key_value_heads=8,
        num_local_experts=8,
        output_router_logits=False,
        rms_norm_eps=1e-5,
        rope_theta=1000000.0,
        router_aux_loss_coef=0.02,
        sliding_window=None,
        tie_word_embeddings=False,
        use_cache=True,
        vocab_size=32000,
        # At rope backward
        # Eager produces incontiguous dq and dk
        # SDPA produces contiguous dq and incontiguous dk
        # Flash_attn produces contiguous dq and dk
        attn_implementation="sdpa",  # default value, pytorch native attention
    )
    return MixtralForCausalLM(mini_model_config).to(dtype=torch.float16)


def cleanup():
    dist.destroy_process_group()
def run_inference(model, dataloader, device):
    num_correct = 0
    num_total = 0
    with torch.no_grad():
        for batch in tqdm(
            dataloader, desc=f"Processing batches on rank {dist.get_rank()}"
        ):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
            )
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = batch["labels"][..., 1:].contiguous()
            mask = shift_labels != -100
            correct = (shift_logits.argmax(dim=-1) == shift_labels) & mask
            num_correct += correct.sum().item()
            num_total += mask.sum().item()
    accuracy = num_correct / num_total
    print(f"Final prediction accuracy: {accuracy}")
    return accuracy
@dataclass
class TrainingArgs:
    enable_fp8: bool = False
    batch_size: int = 8


def parse_args() -> TrainingArgs:
    parser = argparse.ArgumentParser()
    for k, v in TrainingArgs.__dataclass_fields__.items():
        if v.type != bool:
            parser.add_argument(f"--{k}", type=v.type, default=v.default)
        else:
            if not v.default:
                parser.add_argument(f"--{k}", action="store_true")
            else:
                parser.add_argument(f"--{k}", action="store_false")
    parsed = parser.parse_args()
    return TrainingArgs(
        **{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)}
    )
def main():
    args = parse_args()
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    torch.manual_seed(42)

    val_dataloader = get_distributed_dataloader(
        args.batch_size,
    )

    # Initialize and configure the model
    model = configure_model()

    # Set device and run inference
    torch.cuda.set_device(local_rank)
    torch.cuda.empty_cache()
    device = "cuda:" + str(local_rank)

    # Define the FSDP configuration
    def custom_auto_wrap_policy(module, recurse, nonwrapped_numel):
        # Define the set of layers that you want to wrap
        layers_to_wrap = {MixtralDecoderLayer}
        # Check if the module is in the set of layers to wrap
        return type(module) in layers_to_wrap

    if args.enable_fp8:
        # from train_utils import patch_torch  # local helper, commented out so the script is self-contained
        # patch_torch()
        from torchao.float8 import (  # precompute_float8_dynamic_scale_for_fsdp, # specific to fsdp2 + dynamic scaling, apply after each training loop iter
            CastConfig,
            Float8LinearConfig,
            ScalingType,
            convert_to_float8_training,
        )

        config = Float8LinearConfig(
            # enable_amax_init=True,  # only needed for autocast + compile + FSDP + float8 delayed
            # enable_pre_and_post_forward=True,  # only needed for autocast + compile + FSDP + float8 delayed
            # enable_fsdp_float8_all_gather=True,
            cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        )

        # convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling type
        def module_filter_fn(mod: torch.nn.Module, fqn: str):
            # don't convert the output module
            if "lm_head" in fqn:
                return False
            # don't convert linear modules with weight dimensions not divisible by 16
            if isinstance(mod, torch.nn.Linear):
                if "block_sparse_moe.gate" in fqn:
                    print(f"Ignore router layer replacement {fqn}")
                    # if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                    return False
            return True

        convert_to_float8_training(
            model,
            config=config,
            module_filter_fn=module_filter_fn,
        )

        from torchao.float8.inference import (
            ActivationCasting,
            Float8InferenceLinear,
            QuantConfig,
            quantize_to_float8,
        )

        quant_config = QuantConfig(ActivationCasting.DYNAMIC)
        # quantize_to_float8(model, quant_config)

    torch.distributed.constants.default_pg_timeout = timedelta(seconds=7200)
    fsdp_config = FSDP(
        model,
        auto_wrap_policy=custom_auto_wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        # backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        # state_dict_type="sharded",
        sync_module_states=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            # reduce_dtype=torch.bfloat16,
            # buffer_dtype=torch.bfloat16,
        ),
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )

    # inference and record the time
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)
    init_start_event.record()
    run_inference(fsdp_config, val_dataloader, device)
    init_end_event.record()
    torch.cuda.synchronize()
    if global_rank == 0:
        print(
            f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec"
        )
        print(f"{model}")

    # Clean up
    cleanup()


if __name__ == "__main__":
    main()
For the speed / memory issue, I guess it's related to not using torch.compile, based on the related ticket:
#685 [FP8] performance degradation in speed and memory without compile
I'll check if I can use torch.compile here. Thanks.
Currently it's a bit blocked on torch.compile + Mixtral. The context for using torch.compile is that it seems to be required in combination with the fp8 linear layers to actually improve speed, as discussed in these threads:
https://github.com/pytorch/ao/issues/685
https://github.com/pytorch/torchtitan/issues/462#issuecomment-2284567110
The Huggingface Mixtral model does not directly support torch.compile as stated here, mainly because the sparse MoE routing produces dynamic token counts per expert; supporting it is an ongoing effort: https://github.com/huggingface/transformers/pull/30793
I've tried the option in gpt-fast (similar to the above PR's change of converting to a fused MoE), but it's more suitable for the fast text-generation phase with small batch sizes and has high memory consumption for the large-batch prefill stage. It also breaks the structure of the linear layers, so they can't be replaced with Float8Linear directly.
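One narrower-scope option I may try (a sketch only, not verified to work around the routing issue): compile just the per-expert MLPs, which is where the Float8Linear matmuls live, with dynamic shapes to tolerate the varying per-expert token counts. Module paths below follow the HF Mixtral layout:

import torch

# compile only the per-expert MLPs (which hold the Float8Linear layers) and mark
# shapes dynamic, since each expert sees a varying number of routed tokens;
# `model` is the MixtralForCausalLM from the script above
for layer in model.model.layers:
    experts = layer.block_sparse_moe.experts
    for i in range(len(experts)):
        experts[i] = torch.compile(experts[i], dynamic=True)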
I put some of my raw test scripts here in case anyone is interested: https://github.com/qingquansong/fp8_fsdp_test. Sorry that I didn't change the model and data local paths.