TransformerEnginePrecision _convert_layers(module) fails for FSDP zero2/zero3
Bug description
The TransformerEnginePrecision.convert_module function does not seem to work for an FSDP-wrapped model (ZeRO-2/ZeRO-3 sharding): _convert_layers fails while copying weights into the replacement layers.
What version are you seeing the problem on?
master
How to reproduce the bug
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

# sharding_strategy, custom_wrap_policy, local_rank and mesh are defined elsewhere
model = FSDP(
    model,
    sharding_strategy=sharding_strategy,
    auto_wrap_policy=custom_wrap_policy,
    device_id=local_rank,
    use_orig_params=True,
    device_mesh=mesh,
)
te_precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
model = te_precision.convert_module(model)  # fails inside _convert_layers
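For comparison, here is a sketch of the ordering I would expect to avoid the error: converting the layers before the model is sharded. This is an assumption on my part, not a verified workaround, and build_model is a hypothetical helper standing in for however the un-wrapped module is constructed.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision

# Assumption: converting before sharding lets _convert_layers see plain nn.Linear
# modules with ordinary dense parameters, so the .data copy should succeed.
model = build_model()  # hypothetical helper returning the un-wrapped nn.Module
te_precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
model = te_precision.convert_module(model)  # swap layers while params are still dense

model = FSDP(
    model,
    sharding_strategy=sharding_strategy,
    auto_wrap_policy=custom_wrap_policy,
    device_id=local_rank,
    use_orig_params=True,
    device_mesh=mesh,
)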
Error messages and logs
[rank1]: self.model = te_precision.convert_module(self.model)
[rank1]: _convert_layers(module)
[rank1]: File "/usr/local/lib/python3.10/dist-packages/lightning/fabric/plugins/precision/transformer_engine.py", line 165, in _convert_layers
[rank1]: replacement.weight.data = child.weight.data.clone()
[rank1]: RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.
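The RuntimeError comes from PyTorch's check that a tensor assigned to .data has a compatible tensor type with the existing parameter. A minimal, standalone sketch that trips the same check is below; it uses a sparse tensor as a stand-in for the incompatible type, whereas in the issue above the incompatible type is presumably the sharded parameter produced by FSDP with a device_mesh.

import torch

# A plain dense Parameter, like the freshly created replacement layer's weight.
param = torch.nn.Parameter(torch.randn(4, 4))

# Assigning .data with a tensor of a different TensorImpl type (here: sparse,
# standing in for the sharded parameter of the FSDP-wrapped child module)
# raises the same error:
# RuntimeError: Attempted to call `variable.set_data(tensor)`, but `variable`
# and `tensor` have incompatible tensor type.
param.data = torch.randn(4, 4).to_sparse()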
More info
I actually see this with pytorch-lightning==2.3.0 as well.