adding DDP/FSDP transform after JITting does not work
🐛 Bug
The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: https://github.com/Lightning-AI/litgpt/pull/1204
The objective is that fsdp or ddp can be applied after the thunder.jit call.
It works with FSDP but not with DDP, where it fails with:
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]: out = tmodel(x)
[rank1]: File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]: res = self._forward_fn(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]: res = fn(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]: computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]: File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]: bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]: File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]: updated_bwd_trace = visitor_transform(
[rank1]: File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]: visit_type = visit(bsym)
[rank1]: File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]: self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]: File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]: self._maybe_allreduce(bucket, group)
[rank1]: File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]: self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]: File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]: result = self.meta(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]: result = fn(*args, **kwargs)
[rank1]: File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]: utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]: File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]: check(
[rank1]: File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]: raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>
To Reproduce
import os
import thunder
import torch
import torch.distributed as torch_dist
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)
model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)
tmodel = thunder.jit(model)
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)
out = tmodel(x)
if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())
torchrun --nproc-per-node 2 bug.py
cc @kshitij12345 since you fixed a similar issue in #23
I can work around this by setting
tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp
since thunder gets this information at jit() time: https://github.com/Lightning-AI/lightning-thunder/blob/94c94948b79875ba5247b5c986afa088d970a49d/thunder/common.py#L224-L226
So my question is: could we delay accessing this attribute until the function runs?
TBH, this is a very clear "don't do this, changing the fn is completely unsupported!".
That said, we can talk about distributed-after-jit. The obstacles are:
- Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.
- The inserted syncs can do funny things with tensor shapes, so applying this after the prologue is generated (i.e. after "jit") would require us to transform the prologue. This is on our roadmap for other transforms, but it is not entirely trivial.
I'll reach out to you to better understand the need.
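To make the second obstacle concrete, here is a toy model of it. It is not Thunder's actual prologue: it assumes only that a prologue validates the parameter shapes it saw at trace time. A sharding transform applied afterwards (FSDP-style, splitting the first dimension) then makes those recorded shapes stale. `make_prologue` and `shard_dim0` are hypothetical names.

```python
# Toy model, NOT Thunder's implementation: a "prologue" that checks the
# parameter shapes it recorded when the computation was traced.

def make_prologue(params):
    # Record each parameter's shape at trace ("jit") time.
    expected = {name: tuple(shape) for name, shape in params.items()}

    def prologue(current_params):
        for name, shape in expected.items():
            got = tuple(current_params[name])
            if got != shape:
                raise ValueError(f"{name}: expected shape {shape}, got {got}")
        return True

    return prologue


def shard_dim0(shape, world_size):
    # Hypothetical FSDP-style sharding: split the first dimension evenly.
    first, *rest = shape
    assert first % world_size == 0, "toy example assumes even sharding"
    return (first // world_size, *rest)


params = {"weight": (10, 5)}      # shape of Linear(5, 10).weight
prologue = make_prologue(params)  # prologue already generated

# Sharding applied *after* the prologue exists:
sharded = {name: shard_dim0(s, world_size=2) for name, s in params.items()}

try:
    prologue(sharded)
except ValueError as e:
    print("stale prologue:", e)
```

A transform applied after jit would therefore also have to rewrite these checks, which is the "transform the prologue" work mentioned above.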
triage review:
Let's look at our current transforms, when they have to be applied, and what they mean when composed in different orders (are all orders supported?). For example: ddp, jit, grad? jit, grad, ddp?
Do we need to support transforms that modify the original module, or should they produce a new module?
Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.
I would appreciate some pointers or examples showing this, because in my test the trace looks correct: it contains the appropriate collectives.
I'm probably misunderstanding how the interpreter works. How can the prologues be generated at jit time if we don't have any input tensors whose shapes to check? I thought this only happened on the first call.
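The traceback in the top post fits this reading: `get_computation_and_inputs` (and, inside it, `split_forward_backward`) runs from `fn_` when `tmodel(x)` is called, not inside `thunder.jit`. The pattern looks roughly like this toy sketch, where wrapping is cheap and tracing happens on a cache miss at call time. `ToyJit` and its cache key (list lengths standing in for tensor shapes) are assumptions, not Thunder's real cache.

```python
# Toy sketch of "trace on first call": jit() only wraps the function;
# tracing (and any prologue creation) happens when inputs are first seen.

class ToyJit:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}       # input signature -> "compiled" callable
        self.trace_count = 0  # how many times we had to (re)trace

    def _signature(self, args):
        # Hypothetical cache key: the length of each list input,
        # standing in for tensor shapes.
        return tuple(len(a) for a in args)

    def __call__(self, *args):
        key = self._signature(args)
        if key not in self.cache:
            # Cache miss: a real JIT would interpret fn here, apply its
            # transforms, and emit a prologue that checks `key`.
            self.trace_count += 1
            self.cache[key] = self.fn
        return self.cache[key](*args)


def add(a, b):
    return [x + y for x, y in zip(a, b)]


jadd = ToyJit(add)    # nothing traced yet
jadd([1, 2], [3, 4])  # first call: traces
jadd([5, 6], [7, 8])  # same "shapes": cache hit
jadd([1], [2])        # new "shapes": traces again
print(jadd.trace_count)
```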
FSDP and DDP calls are not trace transforms; they annotate the parameters of the original, to-be-jitted PyTorch module.
- We should raise an error when a ThunderModule is passed to the FSDP or DDP calls, suggesting the correct usage.
- Why do you expect ddp(jit(model)) to work, and is it more important to support than jit(ddp(model))?
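A minimal sketch of the first bullet, using plain-Python stand-ins rather than torch or Thunder classes: a DDP-style annotator that tags the original module with a process group and rejects an already-jitted wrapper with a message pointing to the supported order. `Module`, `ToyThunderModule`, and `toy_ddp` are hypothetical; only the attribute name `process_group_for_ddp` comes from this thread.

```python
# Toy sketch: a DDP-style "parameter annotator" that tags the *original*
# module and refuses an already-jitted wrapper. All classes are stand-ins.

class Module:  # stand-in for torch.nn.Module
    pass


class ToyThunderModule(Module):  # stand-in for the object jit() returns
    def __init__(self, original):
        self.original = original


def toy_ddp(module, process_group="default-pg"):
    if isinstance(module, ToyThunderModule):
        raise TypeError(
            "ddp() expects the original module; use jit(ddp(model)) "
            "instead of ddp(jit(model))"
        )
    # Annotate the module so a later jit() can read the process group.
    module.process_group_for_ddp = process_group
    return module


model = Module()
annotated = toy_ddp(model)
print(annotated.process_group_for_ddp)

try:
    toy_ddp(ToyThunderModule(Module()))
except TypeError as e:
    print("error:", e)
```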
We still need to support jit(ddp(model)), as this is basically what happens whenever you jit a function and not the model.
What I'm advocating for is something like jit(ddp(undo_jit(jit(model))))
where undo_jit is currently the hack that I describe in the top post.
Allowing this is convenient because the user can then control the innermost jit(model) call, while the framework (Fabric, Trainer) controls which transforms are applied to the model and how they interact with each other when there is more than one.
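One way such a round trip could look, as a toy sketch: `undo_jit` hands back both the original callable and whatever options the user passed to jit, so the framework can insert its transform and re-jit with the user's choices intact. This differs from the top-post hack, which mutates `tmodel._lc_cd.fn` in place. Every name here (`Jitted`, `jit`, `undo_jit`, `ddp`, `user_flag`) is a hypothetical stand-in, not Thunder API.

```python
# Toy sketch of jit(ddp(undo_jit(jit(model)))). None of these are Thunder APIs.

class Jitted:
    def __init__(self, fn, **options):
        self.fn = fn
        self.options = options  # stand-in for the user's jit options

    def __call__(self, *args):
        return self.fn(*args)


def jit(fn, **options):
    return Jitted(fn, **options)


def undo_jit(jitted):
    # Hypothetical: return the wrapped callable plus the user's jit options.
    return jitted.fn, jitted.options


def ddp(fn):
    # Stand-in annotator: attach a process group to the original callable.
    fn.process_group_for_ddp = "default-pg"
    return fn


def model(x):
    return 2 * x


user_jitted = jit(model, user_flag=True)  # the user controls this call
original, options = undo_jit(user_jitted)  # the framework takes over
final = jit(ddp(original), **options)      # ...and re-jits with the same options

print(final(3), final.options, final.fn.process_group_for_ddp)
```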
I know nothing about Lightning. Do you want to allow users to call jit(model) and then, inside Lightning, apply either the DDP or the FSDP call to that model? FSDP is working now, right? You need something that unwraps the jit call. Have you tried using __wrapped__? thunder.jit uses functools.wraps here: https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/__init__.py#L642
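For reference, this is what `functools.wraps` provides in plain Python: it copies the wrapped function's metadata and sets `wrapper.__wrapped__`, and `inspect.unwrap` follows that chain. `toy_jit` is a hypothetical decorator. Whether the object `thunder.jit` returns for a module exposes the original this way depends on where `wraps` is applied in Thunder, so treat that part as an assumption to verify.

```python
import functools
import inspect


def toy_jit(fn):
    # functools.wraps copies __name__, __doc__, etc. and also sets
    # wrapper.__wrapped__ = fn, which is what makes unwrapping possible.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper


def model(x):
    """A toy 'model'."""
    return x + 1


jitted = toy_jit(model)
print(jitted.__name__, jitted.__wrapped__ is model)

# inspect.unwrap follows the __wrapped__ chain to the innermost callable.
original = inspect.unwrap(jitted)
print(original is model)
```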
One thing (probably tangential) I was wondering: why is process_group_for_ddp an attribute of CompileData?
https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/common.py#L221-L223
I think it would make sense to make it a property, because if we ever have to update CompileData.fn, we might forget to update these derived attributes. (This change could also fix the issue.)
diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:
         self.is_module = isinstance(self.fn, torch.nn.Module)
-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:
         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"
+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+
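The effect of that diff can be seen with two toy classes that are not Thunder's `CompileData`: one snapshots `process_group_for_ddp` in `__init__`, the other reads it through a property. After swapping `fn` (as the top-post hack does), only the property-based version sees the new value. The eager version returns `None`, which is the value that reached `all_reduce_meta` in the traceback. `EagerData` and `LazyData` are hypothetical names.

```python
# Toy comparison (not Thunder's CompileData): snapshotting an attribute of fn
# at construction time vs. reading it lazily through a property.

class EagerData:
    def __init__(self, fn):
        self.fn = fn
        # Snapshot taken once; goes stale if self.fn is replaced later.
        self.process_group_for_ddp = getattr(fn, "process_group_for_ddp", None)


class LazyData:
    def __init__(self, fn):
        self.fn = fn

    @property
    def process_group_for_ddp(self):
        # Looked up on every access, so it follows the current self.fn.
        return getattr(self.fn, "process_group_for_ddp", None)


def fn(x):
    return x


def ddp_fn(x):
    return x


ddp_fn.process_group_for_ddp = "default-pg"  # what a ddp() annotator would attach

eager, lazy = EagerData(fn), LazyData(fn)
eager.fn = ddp_fn  # swap fn after construction, like the top-post hack
lazy.fn = ddp_fn

print(eager.process_group_for_ddp)  # stale snapshot
print(lazy.process_group_for_ddp)   # follows the new fn
```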
I think the new fsdp/ddp actually do this.