lightning-thunder
RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.
🐛 Bug
There is an error when running fp8 with FSDP in benchmark_litgpt.py, which was also reported in the pytorch-lightning repo.
To Reproduce
Steps to reproduce the behavior:

1. Start the container:

   ```bash
   docker run --pull=always --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -it INTERNAL_IMAGE:pjnl-20240724
   ```

2. Execute:

   ```bash
   torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) thunder/benchmarks/benchmark_litgpt.py --max_iters 20 --warmup_iters 5 --model_name dolly-v2-3b --distributed_mode fsdp --shard_mode zero3 --compile eager --checkpoint_activations False --low_precision_mode fp8-delayed-te
   ```

3. Notice the error:
```
[rank0]: Traceback (most recent call last):
...
[rank0]:   File "/mnt/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 647, in benchmark_main
[rank0]:     benchmark = Benchmark_litGPT(**kwargs)
[rank0]:   File "/mnt/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py", line 281, in __init__
[rank0]:     self.model = te_precision.convert_module(self.model)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/fabric/plugins/precision/transformer_engine.py", line 104, in convert_module
[rank0]:     _convert_layers(module)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/lightning/fabric/plugins/precision/transformer_engine.py", line 165, in _convert_layers
[rank0]:     replacement.weight.data = child.weight.data.clone()
[rank0]: RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.
```
This comes from the `_convert_layers` function in Fabric's transformer_engine.py precision plugin (see the traceback above).
After some investigation, I believe this happens because the model is initialized on the meta device.
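A minimal sketch of what I believe is happening (my own repro attempt, not the benchmark itself; it assumes only standard PyTorch behavior):

```python
import torch
import torch.nn as nn

# The original layer lives on the meta device (no storage), while the
# replacement created by the precision plugin is materialized normally.
with torch.device("meta"):
    child = nn.Linear(4, 4)
replacement = nn.Linear(4, 4)

# Mirrors Fabric's `replacement.weight.data = child.weight.data.clone()` and
# raises: RuntimeError: Attempted to call `variable.set_data(tensor)`, but
# `variable` and `tensor` have incompatible tensor type.
replacement.weight.data = child.weight.data.clone()
```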
I created a function similar to the Fabric `TransformerEnginePrecision` wrapper currently in use (with an additional flag to control whether `LayerNorm` layers are swapped, which is handy for benchmarking with and without `LayerNorm` replacement):
```python
import torch
import torch.nn as nn
import transformer_engine.pytorch as te


def swap_linear_layers_for_te(model: nn.Module, device: torch.device, swap_layernorm: bool = True) -> None:
    def parameters_cnt(model: nn.Module) -> int:
        return sum(p.numel() for p in model.parameters())

    def _recursively_swap_linear_layers_for_te(module: nn.Module) -> None:
        for n, m in module.named_children():
            if len(list(m.children())) > 0:
                _recursively_swap_linear_layers_for_te(m)

            if isinstance(m, nn.Linear):
                bias_flag = m.bias is not None
                new_linear = te.Linear(
                    m.in_features, m.out_features, bias=bias_flag
                )
                new_linear.weight.data = m.weight.data.clone()
                if bias_flag:
                    new_linear.bias.data = m.bias.data.clone()
                setattr(module, n, new_linear)

            if swap_layernorm and isinstance(m, nn.LayerNorm):
                new_layernorm = te.LayerNorm(
                    m.normalized_shape[0], eps=m.eps
                )
                new_layernorm.weight.data = m.weight.data.clone()
                new_layernorm.bias.data = m.bias.data.clone()
                setattr(module, n, new_layernorm)

    initial_params_cnt = parameters_cnt(model)

    # Check if the model's parameters are meta tensors; if so, materialize them
    # on the target device first so the clones below are not meta tensors.
    if any(p.is_meta for p in model.parameters()):
        model.to_empty(device=device)

    _recursively_swap_linear_layers_for_te(model)

    # Sanity checks: parameter count is unchanged and no plain nn.Linear/nn.LayerNorm remains.
    assert initial_params_cnt == parameters_cnt(model)
    for m in model.modules():
        assert not isinstance(m, nn.Linear)
        if swap_layernorm:
            assert not isinstance(m, nn.LayerNorm)
```
Note this snippet in the function:

```python
# Check if the model's parameters are meta tensors
if any(p.is_meta for p in model.parameters()):
    model.to_empty(device=device)
```

This seems to be the spot that fixes the issue. I am not sure whether it is the right approach, so I am happy to learn from others. Then the following lines in benchmark_litgpt.py:
```python
te_precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
self.model = te_precision.convert_module(self.model)
```
are replaced with:
```python
swap_linear_layers_for_te(self.model, device, swap_layernorm=(not self.low_precision_mode == 'fp8-delayed-te-wo_layernorm'))
self.model.to(torch.bfloat16)
```
Happy to share more details if necessary.
Expected behavior
The code should run under FSDP without errors.
Environment
As in the container (but should be reproducible in any case)
Additional context
Feel free to reach me here or ping me on Slack. I am also happy to hear what @IvanYashchuk thinks about this. Maybe we could replace the wrapper entirely with the above function, if it is a correct implementation?
Another way would be to simply skip cloning the weights when the module is on the meta device:
```python
def swap_linear_layers_for_te(model: nn.Module, swap_layernorm: bool = True) -> None:
    def parameters_cnt(model: nn.Module) -> int:
        return sum(p.numel() for p in model.parameters())

    def is_meta_module(module: nn.Module) -> bool:
        for param in module.parameters():
            if param.is_meta:
                return True
        return False

    def _recursively_swap_linear_layers_for_te(module: nn.Module) -> None:
        for n, m in module.named_children():
            if len(list(m.children())) > 0:
                _recursively_swap_linear_layers_for_te(m)

            if isinstance(m, nn.Linear):
                bias_flag = m.bias is not None
                new_linear = te.Linear(
                    m.in_features, m.out_features, bias=bias_flag
                )
                # Only clone weights when the source module is materialized.
                if not is_meta_module(m):
                    new_linear.weight.data = m.weight.data.clone()
                    if bias_flag:
                        new_linear.bias.data = m.bias.data.clone()
                setattr(module, n, new_linear)

            if swap_layernorm and isinstance(m, nn.LayerNorm):
                new_layernorm = te.LayerNorm(
                    m.normalized_shape[0], eps=m.eps
                )
                if not is_meta_module(m):
                    new_layernorm.weight.data = m.weight.data.clone()
                    new_layernorm.bias.data = m.bias.data.clone()
                setattr(module, n, new_layernorm)

    initial_params_cnt = parameters_cnt(model)
    _recursively_swap_linear_layers_for_te(model)

    assert initial_params_cnt == parameters_cnt(model)
    for m in model.modules():
        assert not isinstance(m, nn.Linear)
        if swap_layernorm:
            assert not isinstance(m, nn.LayerNorm)
```
@lantiga could you please assign someone to look into this issue?
The error sounds like either the device or the dtype doesn't match between the two. Maybe you could insert a debug print to check what's going on. Note that this is not in the Thunder branch, so we need a Fabric expert; maybe @awaelchli has an immediate idea?
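For example, a hypothetical debug print (my sketch) placed just before the failing assignment in `_convert_layers` could show which attribute mismatches:

```python
# `replacement` and `child` are the local variables from _convert_layers
# (see the traceback above); print their device/dtype before the assignment.
print(
    f"replacement: device={replacement.weight.device}, dtype={replacement.weight.dtype} | "
    f"child: device={child.weight.device}, dtype={child.weight.dtype}, is_meta={child.weight.is_meta}"
)
```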
@csarofeen yes thanks for flagging
Hello @t-vi! Yes, please note my comment:

> After some investigation, I believe this happens because the model is initialized on the meta device.

This is the case in the benchmark_litgpt.py script. Because the meta device holds no actual weights yet, the cloning lines break, i.e.:

```python
new_linear.weight.data = m.weight.data.clone()
```
Therefore, I proposed two ways to solve this.
- The first way:

  ```python
  if any(p.is_meta for p in model.parameters()):
      model.to_empty(device=device)
  ```

  We call `to_empty` when the module is on the meta device, so the subsequent cloning no longer breaks (see the short sketch at the end of this comment).
- The second way: skip the cloning step altogether:

  ```python
  if isinstance(m, nn.Linear):
      bias_flag = m.bias is not None
      new_linear = te.Linear(
          m.in_features, m.out_features, bias=bias_flag
      )
      if not is_meta_module(m):
          new_linear.weight.data = m.weight.data.clone()
          if bias_flag:
              new_linear.bias.data = m.bias.data.clone()
      setattr(module, n, new_linear)
  ```
The third way I see is to move the model off the meta device before this step in benchmark_litgpt.py, but I don't know the implications of such a move. Note that the swap is performed before distributed training (FSDP/DDP) is set up.
In my opinion, the TransformerEnginePrecision wrapper should handle the case where the module is on the meta device, but I am happy to hear what others think.
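For reference, here is a minimal sketch (my addition, assuming only standard `torch.nn.Module.to_empty` semantics) of why the first approach unblocks the cloning:

```python
import torch
import torch.nn as nn

# Build a module on the meta device: parameters have shape/dtype but no storage.
with torch.device("meta"):
    m = nn.Linear(8, 8)
assert m.weight.is_meta

# `to_empty` re-allocates every parameter/buffer on the target device with
# *uninitialized* memory, so the parameters are no longer meta tensors.
m.to_empty(device="cpu")
assert not m.weight.is_meta

# `.data` assignments now involve only materialized tensors, so `set_data`
# does not complain about incompatible tensor types.
m.weight.data = torch.randn(8, 8)
```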
triage review:
- this is a lightning fabric thing, not a thunder thing
- @t-vi to talk to @awaelchli about what's going on here
Here is my complete workaround that we used in our last runs:
```python
def swap_linear_layers_for_te(model: nn.Module, swap_layernorm: bool = True, device: str = "meta") -> None:
    def parameters_cnt(model: nn.Module) -> int:
        return sum(p.numel() for p in model.parameters())

    def _recursively_swap_linear_layers_for_te(module: nn.Module) -> None:
        for n, m in module.named_children():
            if len(list(m.children())) > 0:
                _recursively_swap_linear_layers_for_te(m)

            if isinstance(m, nn.Linear):
                bias_flag = m.bias is not None
                # Create the TE layer directly on the target device; no cloning.
                new_linear = te.Linear(
                    m.in_features, m.out_features, bias=bias_flag, device=device
                )
                setattr(module, n, new_linear)

            if swap_layernorm and isinstance(m, nn.LayerNorm):
                new_layernorm = te.LayerNorm(
                    m.normalized_shape[0], eps=m.eps, device=device
                )
                setattr(module, n, new_layernorm)

    initial_params_cnt = parameters_cnt(model)
    _recursively_swap_linear_layers_for_te(model)

    assert initial_params_cnt == parameters_cnt(model)
    for m in model.modules():
        assert not isinstance(m, nn.Linear)
        if swap_layernorm:
            assert not isinstance(m, nn.LayerNorm)
```
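For context, a usage sketch (my illustration; the call site mirrors the earlier benchmark_litgpt.py snippet, and the flag handling is an assumption):

```python
# Replaces the TransformerEnginePrecision lines from the issue description.
# TE layers are created directly on the meta device and materialized later,
# together with the rest of the meta-initialized model, during FSDP setup.
swap_linear_layers_for_te(
    self.model,
    swap_layernorm=(self.low_precision_mode != "fp8-delayed-te-wo_layernorm"),
    device="meta",
)
self.model.to(torch.bfloat16)
```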
Can I make a PR for this instead of the current fabric solution?
> Can I make a PR for this instead of the current fabric solution?
Yes, that would be excellent, thank you, Wojciech!
To chime in more: I think that it would be good to go with the approach @wprazuch suggested. (Checked with Adrian offline.)
Thanks @t-vi ! The PR is ready for review.