Shibo Jie
> The problem may come from the DDP settings. The PyTorch DDP [notes on the **Backward Pass**](https://pytorch.org/docs/master/notes/ddp.html) state: “so after the backward pass, the grad field on the same corresponding parameter across different...
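
In other words, DDP all-reduces (averages) gradients during `backward()`, so after the backward pass every rank should hold identical `.grad` tensors. A minimal sketch to verify this, assuming a standard `torchrun` launch; `check_grads_synced` is a hypothetical helper, not part of PyTorch:

```python3
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def check_grads_synced(ddp_model):
    # After backward(), DDP has already all-reduced (averaged) the gradients,
    # so every rank should hold the same .grad tensors.
    for name, param in ddp_model.module.named_parameters():
        if param.grad is None:
            continue
        ref = param.grad.detach().clone()
        dist.broadcast(ref, src=0)          # everyone receives rank 0's grad
        if not torch.allclose(param.grad, ref):
            raise RuntimeError(f"rank {dist.get_rank()}: grad mismatch on {name}")

if __name__ == "__main__":
    dist.init_process_group("gloo")         # use "nccl" for GPU training
    model = DDP(torch.nn.Linear(8, 4))      # DDP broadcasts rank 0's weights at init
    torch.manual_seed(dist.get_rank())      # deliberately different data per rank
    x = torch.randn(2, 8)
    model(x).sum().backward()               # triggers DDP's all-reduce hooks
    check_grads_synced(model)
    print(f"rank {dist.get_rank()}: grads are synced")
    dist.destroy_process_group()
```

Launched with e.g. `torchrun --nproc_per_node=2 check_sync.py`, each rank feeds different data, yet the check passes because the averaged gradients are written back to every rank's `param.grad`.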
@nbasyl PyTorch 2.0.1

```python3
from torch.distributed.algorithms.join import (
    Join,
    Joinable,
    JoinHook,
)
from torch.distributed.utils import (
    _verify_param_shape_across_processes,
    _sync_module_states,
    _to_kwargs,
)
from torch.nn.parallel.distributed import _find_tensors, _tree_flatten_with_rref, _DDPSink, _tree_unflatten_with_rref

def ddp_forward(self, *inputs, ...
```
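
For anyone trying the snippet (the paste above is truncated): such a patched forward is typically bound onto DDP at runtime. A minimal sketch, assuming `ddp_forward` keeps DDP's original `(self, *inputs, **kwargs)` signature; `my_model` is a placeholder for your own module:

```python3
import types
from torch.nn.parallel import DistributedDataParallel as DDP

# Option 1: patch every DDP instance by replacing the class method.
DDP.forward = ddp_forward

# Option 2: patch a single instance, leaving the class untouched.
ddp_model = DDP(my_model)  # my_model: a placeholder for your nn.Module
ddp_model.forward = types.MethodType(ddp_forward, ddp_model)
```

The instance-level patch is the safer choice if other DDP-wrapped models in the same process should keep the stock forward.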