[BUG] Outputs of type NamedTuple cause crash in `_apply_to_tensors_only` (stage 3 + shard parameters)
Hi DeepSpeed team 👋 Thank you for this amazing library, it's been great to learn from and implement in our project 🙇 In my quest to start training large models with DeepSpeed, I stumbled upon an incompatibility with our existing modules.
I have since patched this behavior (monkeypatch) and am willing to commit this back to the repo, if you'll have me. 🤗
Describe the bug
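When the output of a module is a `NamedTuple`, `_apply_to_tensors_only` fails. It rebuilds tuple outputs with `outputs.__class__(touched_outputs)`, but a `NamedTuple` cannot be constructed from a single iterable because its constructor expects one positional argument per field. This occurs only when parameters are sharded in DeepSpeed (ZeRO) stage 3.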
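The constructor mismatch can be reproduced without PyTorch or DeepSpeed. In this standard-library-only sketch a float stands in for the `Tensor` field, and `touched_outputs` plays the role of the list the hook builds before rebuilding the container.

```python
from typing import NamedTuple


class MyOutputs(NamedTuple):
    linear_output: float  # a Tensor in the real module
    metadata: str


outputs = MyOutputs(linear_output=1.0, metadata="my_layer_output")
# Stand-in for the per-element results collected by the hook.
touched_outputs = list(outputs)

# Plain tuples and lists accept a single iterable argument.
rebuilt_tuple = tuple(touched_outputs)

# A NamedTuple's __new__ takes one positional argument per field, so the whole
# list is bound to `linear_output` and `metadata` is reported as missing.
try:
    outputs.__class__(touched_outputs)
    error = None
except TypeError as exc:
    error = str(exc)

# Unpacking fills each field in order and preserves the NamedTuple type.
rebuilt = outputs.__class__(*touched_outputs)
print(f"{type(outputs).__name__}(list) raised: {error}")
print(rebuilt)
```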
To Reproduce
Install: pip install deepspeed pytorch_lightning.
```python
from typing import NamedTuple

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import Tensor
from torch.nn import Linear
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset


class MyOutputs(NamedTuple):
    linear_output: Tensor
    metadata: str


class MyModule(Linear):
    # An example layer that returns more than just one Tensor.
    # Tuples/lists/dicts are already supported as outputs, but they are less strongly typed than a NamedTuple.
    def forward(self, input: Tensor) -> MyOutputs:
        linear_output = super().forward(input)
        return MyOutputs(linear_output=linear_output, metadata='my_layer_output')


class SimpleDataset(Dataset):
    """Minimalistic Dataset for debugging."""

    def __init__(self, size: int, data_dim: int = 32) -> None:
        self.size = size
        self.data = torch.randn(size, data_dim)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return self.size


class SimpleDataModule(LightningDataModule):
    """Minimalistic LightningDataModule for debugging."""

    def __init__(self, dataset: SimpleDataset, num_workers: int) -> None:
        super().__init__()
        self.val_dataset = self.train_dataset = self.predict_dataset = dataset
        self.num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.train_dataset, num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_dataset, num_workers=self.num_workers)

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.predict_dataset, num_workers=self.num_workers)


class SimpleModel(LightningModule):
    """Minimalistic LightningModule for debugging."""

    def __init__(self, *, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.layer = MyModule(input_dim, output_dim)

    def forward(self, x: Tensor) -> MyOutputs:
        output: MyOutputs = self.layer(x)
        return output

    def training_step(self, batch: Tensor, batch_idx: int) -> dict[str, Tensor]:
        loss: Tensor = self(batch).linear_output.mean()
        self.log('train_loss', loss)
        return {'loss': loss}

    def validation_step(self, batch: Tensor, batch_idx: int) -> None:
        loss: Tensor = self(batch).linear_output.mean()
        self.log('val_loss', loss)

    def test_step(self, batch: Tensor, batch_idx: int) -> None:
        loss: Tensor = self(batch).linear_output.mean()
        self.log('test_loss', loss)

    def configure_optimizers(self) -> Optimizer:
        return torch.optim.SGD(self.layer.parameters(), lr=1e-4)


def main() -> None:
    model = SimpleModel(input_dim=32, output_dim=2)
    datamodule = SimpleDataModule(dataset=SimpleDataset(size=64, data_dim=32), num_workers=0)
    trainer = Trainer(
        fast_dev_run=True,
        accelerator='gpu',
        strategy='deepspeed_stage_3',
        devices=-1,
    )
    trainer.fit(model=model, datamodule=datamodule)


if __name__ == '__main__':
    main()
```
Outputs
```
  File ".../torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File ".../test.py", line 64, in forward
    output: MyOutputs = self.layer(x)
  File ".../torch/nn/modules/module.py", line 1215, in _call_impl
    hook_result = hook(self, input, result)
  File ".../deepspeed/runtime/zero/parameter_offload.py", line 409, in _pre_backward_module_hook
    return _apply_to_tensors_only(module,
  File ".../deepspeed/runtime/zero/parameter_offload.py", line 39, in _apply_to_tensors_only
    return outputs.__class__(touched_outputs)
TypeError: MyOutputs.__new__() missing 1 required positional argument: 'metadata'
```
Expected behavior
No `TypeError` is thrown.
Potential solution
```python
def isinstance_namedtuple(obj: object) -> bool:
    # NamedTuples are tuple subclasses that also expose _asdict and _fields.
    return isinstance(obj, tuple) and hasattr(obj, '_asdict') and hasattr(obj, '_fields')


def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if isinstance(outputs, (tuple, list)):
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(
                module, functional, backward_function, output
            )
            touched_outputs.append(touched_output)
        # ----- BEGIN PATCH ----- #
        # NamedTuple.__new__ expects one positional argument per field.
        if isinstance_namedtuple(outputs):
            return outputs.__class__(*touched_outputs)
        # ----- END PATCH ----- #
        return outputs.__class__(touched_outputs)
    ...
```
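To check the patched branch in isolation, here is a stripped-down sketch of the same recursion. It is not DeepSpeed's code: `apply_to_leaves` is a hypothetical name, the `module` and `backward_function` arguments are dropped, the branches elided as `...` above are omitted, and a caller-supplied `is_leaf` predicate replaces the tensor check (something like `torch.is_tensor` in a real hook) so it runs without PyTorch.

```python
from typing import Any, Callable, NamedTuple


def isinstance_namedtuple(obj: object) -> bool:
    # Same duck-typing heuristic as the proposed patch.
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def apply_to_leaves(functional: Callable[[Any], Any], outputs: Any,
                    is_leaf: Callable[[Any], bool]) -> Any:
    """Hypothetical stand-in for _apply_to_tensors_only.

    Applies `functional` to every leaf selected by `is_leaf`, recursing through
    tuples and lists and rebuilding each container with its original type.
    Anything else is returned unchanged.
    """
    if isinstance(outputs, (tuple, list)):
        touched = [apply_to_leaves(functional, o, is_leaf) for o in outputs]
        if isinstance_namedtuple(outputs):
            # NamedTuple.__new__ takes one positional argument per field.
            return outputs.__class__(*touched)
        # tuple/list constructors take a single iterable.
        return outputs.__class__(touched)
    if is_leaf(outputs):
        return functional(outputs)
    return outputs


class MyOutputs(NamedTuple):
    linear_output: float  # a Tensor in the real module
    metadata: str


def is_float(x: Any) -> bool:
    return isinstance(x, float)  # stands in for a tensor check


result = apply_to_leaves(lambda x: x * 2, [MyOutputs(1.5, "meta"), (2.0, "x")], is_float)
print(result)
```

Checking `_fields` and `_asdict` is duck typing rather than an exact type check, so any tuple subclass exposing both attributes is rebuilt field by field, which is the same behaviour as the patch above.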
ds_report output
```
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.10/dist-packages/torch']
torch version .................... 1.13.1+cu116
deepspeed install path ........... ['.../env/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.8.2, unknown, unknown
torch cuda version ............... 11.6
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed wheel compiled w. ...... torch 1.13, cuda 11.6
```
System info (please complete the following information):
- OS: Ubuntu 20.04
- GPU count and types: N/A (happens on any configuration, tested on 1x-8x A100)
- Interconnects (if applicable): N/A
- Python version: 3.10
- Any other relevant info about your setup: N/A
Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?
Yes, but the issue can also be reproduced without it.
Docker context
Are you using a specific docker image that you can share?
Yes, a Docker image based on nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04.
@AlexanderVanEck, this is awesome!!! I think we would appreciate this PR.
@stas00, @jeffra FYI
Totally! I think NamedTuple is generic enough to warrant built-in support. I wonder how many other output types haven't been considered.
And Tunji, if you remember, we had this discussion a month ago. We were talking about giving users a way to register a special method for classes that aren't Tensors, so that this function could automatically extract the tensors it needs to process.
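The registration idea above could look roughly like the following. This is a hypothetical sketch, not DeepSpeed API: `register_output_type`, `_OUTPUT_HANDLERS`, `apply_to_leaves`, and `ModelOutput` are invented for illustration, and a float predicate stands in for a tensor check so it runs without PyTorch. Each registered class supplies a `flatten` that splits an instance into children plus opaque context, and an `unflatten` that rebuilds it from processed children.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical registry: container class -> (flatten, unflatten).
#   flatten(obj) -> (children, context)
#   unflatten(children, context) -> new instance of the same class
_OUTPUT_HANDLERS: Dict[type, Tuple[Callable, Callable]] = {}


def register_output_type(cls: type, flatten: Callable, unflatten: Callable) -> None:
    _OUTPUT_HANDLERS[cls] = (flatten, unflatten)


def apply_to_leaves(functional: Callable, outputs: Any, is_leaf: Callable[[Any], bool]) -> Any:
    # Registered user types are handled first, via their own flatten/unflatten.
    handler = _OUTPUT_HANDLERS.get(type(outputs))
    if handler is not None:
        flatten, unflatten = handler
        children, context = flatten(outputs)
        touched = [apply_to_leaves(functional, c, is_leaf) for c in children]
        return unflatten(touched, context)
    # Built-in containers; the NamedTuple branch from the patch is omitted for brevity.
    if isinstance(outputs, (tuple, list)):
        touched = [apply_to_leaves(functional, o, is_leaf) for o in outputs]
        return outputs.__class__(touched)
    if is_leaf(outputs):
        return functional(outputs)
    return outputs


@dataclass
class ModelOutput:  # hypothetical output type that is neither a tuple nor a Tensor
    logits: float
    hidden: List[float]
    name: str


register_output_type(
    ModelOutput,
    flatten=lambda o: ([o.logits, o.hidden], o.name),
    unflatten=lambda children, name: ModelOutput(children[0], children[1], name),
)


def is_float(x: Any) -> bool:
    return isinstance(x, float)  # stands in for a tensor check


result = apply_to_leaves(lambda x: x + 1.0, ModelOutput(1.0, [2.0, 3.0], "head"), is_float)
print(result)
```

The design is similar in spirit to pytree registration in JAX (`jax.tree_util.register_pytree_node`), where a type provides its own flatten and unflatten functions so generic traversal code never needs to know its structure.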
Yes, indeed. Seems to be a prophetic discussion in hindsight :).
@AlexanderVanEck, hope you don't mind me assigning this to you?
Thank you @tjruwase & @stas00 🙌
I've opened a PR for the suggested change.
@AlexanderVanEck, your PR is now merged. Can you please verify that this issue is resolved?
🎉 Thank you @tjruwase
I'll wait for a new release to go out before I report back. We've been running with this fix for a while so I don't expect any surprises.
@AlexanderVanEck, I wanted to check if releases >= 0.9.0 are working as expected. Thanks!
Closing since merged PR is available in new release. Please re-open if needed.