
adding DDP/FSDP transform after JITting does not work

Open carmocca opened this issue 1 year ago • 2 comments

🐛 Bug

The snippet below looks hacky, but it's how I'm approaching support for having the user control the thunder.jit call outside of Fabric: https://github.com/Lightning-AI/litgpt/pull/1204

The objective is that fsdp|ddp can be applied after the thunder.jit call.

It works with FSDP, but not with DDP where it fails with:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/carlos/lightning-thunder/kk.py", line 21, in <module>
[rank1]:     out = tmodel(x)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 194, in forward
[rank1]:     res = self._forward_fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 629, in fn_
[rank1]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 262, in cache_info_wrapper
[rank1]:     res = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/__init__.py", line 571, in get_computation_and_inputs
[rank1]:     computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/executors/torch_autograd.py", line 283, in split_forward_backward
[rank1]:     bw_trace = optimize_allreduce_in_ddp_backward(bw_trace, compile_data)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 198, in optimize_allreduce_in_ddp_backward
[rank1]:     updated_bwd_trace = visitor_transform(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/transforms.py", line 368, in visitor_transform
[rank1]:     visit_type = visit(bsym)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/transforms/ddp.py", line 133, in __call__
[rank1]:     self.gradient_buckets.tell(grads_of_bsym[0], self.process_group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 150, in tell
[rank1]:     self._maybe_allreduce(bucket, group)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/bucketing.py", line 138, in _maybe_allreduce
[rank1]:     self.bucket_to_future[bucket] = dist_prims.all_reduce(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/symbol.py", line 246, in __call__
[rank1]:     result = self.meta(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/langctxs.py", line 124, in _fn
[rank1]:     result = fn(*args, **kwargs)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/distributed/prims.py", line 87, in all_reduce_meta
[rank1]:     utils.check_type(group, torch.distributed.ProcessGroup)
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 107, in check_type
[rank1]:     check(
[rank1]:   File "/home/carlos/lightning-thunder/thunder/core/baseutils.py", line 103, in check
[rank1]:     raise exception_type(s())
[rank1]: ValueError: None had an unexpected type <class 'NoneType'>. Supported types are <class 'torch.distributed.distributed_c10d.ProcessGroup'>

To Reproduce

import os
import thunder
import torch
import torch.distributed as torch_dist

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
    torch_dist.init_process_group(backend="nccl")
    pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
torch.cuda.set_device(device)

model = torch.nn.Linear(5, 10, bias=False, device=device)
x = torch.randn(2, 5, device=device)

tmodel = thunder.jit(model)
tmodel._lc_cd.fn = thunder.distributed.ddp(tmodel._lc_cd.fn)

out = tmodel(x)

if local_rank == 0:
    print(thunder.last_backward_traces(tmodel)[-1].python())

torchrun --nproc-per-node 2 bug.py

cc @kshitij12345 since you fixed a similar issue in #23

carmocca avatar Mar 27 '24 23:03 carmocca

I can work around this by setting

tmodel._lc_cd.process_group_for_ddp = tmodel._lc_cd.fn.process_group_for_ddp

since thunder gets this information at jit() time: https://github.com/Lightning-AI/lightning-thunder/blob/94c94948b79875ba5247b5c986afa088d970a49d/thunder/common.py#L224-L226

So my question is: could we delay accessing this attribute until the function runs?
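To make the question concrete, here is a minimal sketch in plain Python, with no thunder import. `EagerCompileData`, `LazyCompileData`, `ToyModule`, and `toy_ddp` are hypothetical stand-ins rather than thunder's classes; the only thing taken from the linked code is the `getattr(self.fn, "process_group_for_ddp", None)` lookup, and the string `"default-group"` stands in for a real process group.

```python
class EagerCompileData:
    """Toy stand-in: reads the attribute once, when it is constructed."""

    def __init__(self, fn):
        self.fn = fn
        # Snapshot taken at "jit time"; later changes to fn are not seen.
        self.process_group_for_ddp = getattr(fn, "process_group_for_ddp", None)


class LazyCompileData:
    """Toy stand-in: looks the attribute up every time it is accessed."""

    def __init__(self, fn):
        self.fn = fn

    @property
    def process_group_for_ddp(self):
        return getattr(self.fn, "process_group_for_ddp", None)


class ToyModule:
    """Placeholder for the user's torch.nn.Module."""


def toy_ddp(module):
    # Assumption: like thunder.distributed.ddp, annotate the module in place.
    module.process_group_for_ddp = "default-group"
    return module


eager = EagerCompileData(ToyModule())
lazy = LazyCompileData(ToyModule())

# Apply the "ddp" annotation after construction, as in the repro above.
eager.fn = toy_ddp(eager.fn)
lazy.fn = toy_ddp(lazy.fn)

print(eager.process_group_for_ddp)  # None
print(lazy.process_group_for_ddp)   # default-group
```

The eager variant keeps the stale `None`, which matches the `ValueError` in the traceback; the lazy variant picks up the annotation because the lookup happens at access time.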

carmocca avatar Mar 27 '24 23:03 carmocca

TBH, this is a very clear "don't do this, changing the fn is completely unsupported!".

That said, we can talk about distributed-after-jit. The obstacles are:

  • Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.
  • The syncs inserted can do funny things with tensor shapes, so applying this after the prologue is generated (i.e. after "jit") would require us to transform the prologue. This is on our roadmap for other transforms, but it is not entirely trivial.

I'll chat you up for understanding the need better.

t-vi avatar Mar 28 '24 16:03 t-vi

triage review:

let's look at our current transforms, when they have to be applied, and what they mean when ordered after each other (are all orders supported?). For example, ddp jit grad? jit grad ddp?

do we need to support transforms that change the original module, or maybe produce a new module?

mruberry avatar Apr 01 '24 19:04 mruberry

> Currently the ddp transformation is applied during the JITing (i.e. while the interpreter runs). This is fundamentally incompatible with what you're trying to do.

I would appreciate some pointers or examples here that show this, because in my test, the trace does look correct as it contains the appropriate collectives added.

I'm probably misunderstanding how the interpreter works. How can the prologues be generated at jit time if we don't have any input tensors for which to check shapes? I thought this would only happen on the first call.

carmocca avatar Apr 03 '24 17:04 carmocca

FSDP and DDP calls are not trace transforms, they are parameter annotators of the original to-be-jitted PyTorch module.

  • We should raise an error when ThunderModule is passed to FSDP or DDP calls suggesting the correct thing
  • Why do you expect ddp(jit(model)) to work and is it more important to support than jit(ddp(model))?
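A rough sketch of what the first bullet could look like, using toy classes only. `ToyJittedModule`, `toy_ddp`, the `_model` attribute, and the error message are all hypothetical; the real check would be against ThunderModule inside thunder.distributed, whose internals are not shown in this thread.

```python
class ToyJittedModule:
    """Hypothetical stand-in for the wrapper that jit returns for a module."""

    def __init__(self, module):
        self._model = module


class ToyModule:
    """Placeholder for the user's original PyTorch module."""


def toy_ddp(module):
    # Reject already-jitted modules up front and point to the supported order.
    if isinstance(module, ToyJittedModule):
        raise TypeError(
            "ddp() received an already-jitted module; "
            "call ddp(model) first and then jit the result"
        )
    # Assumption: ddp only annotates the original module and returns it.
    module.use_ddp = True
    return module


# Supported order: annotate first, then jit.
ok = ToyJittedModule(toy_ddp(ToyModule()))

# Unsupported order: jit first, then annotate.
try:
    toy_ddp(ToyJittedModule(ToyModule()))
    raised = False
except TypeError as err:
    raised = True
    message = str(err)
```

Failing early like this would replace the deep `ValueError` from `all_reduce_meta` with a message that names the supported call order.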

IvanYashchuk avatar Apr 10 '24 19:04 IvanYashchuk

We still need to support jit(ddp(model)), as this is basically what happens whenever you jit a function and not the model.

What I'm advocating for is something like jit(ddp(undo_jit(jit(model))))

Where undo_jit is currently the hack that I describe in the top-post.

Allowing this is convenient because then the user can control the innermost jit(model) call but the framework (fabric, trainer) can control the transforms applied to the model and how they interact with each other if there are more than one.

carmocca avatar Apr 12 '24 11:04 carmocca

I know nothing about Lightning. Do you want to allow users to call jit(model) and then, inside Lightning, apply either the DDP or the FSDP call to the given model? FSDP is now working, right? You need something that unwraps the jit call. Have you tried using __wrapped__? thunder.jit uses functools.wraps here: https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/__init__.py#L642
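For reference, this is how `__wrapped__` behaves with the standard library's `functools.wraps`. `toy_jit` and `forward` are hypothetical names for illustration; whether the object that thunder.jit returns for a module exposes `__wrapped__` in the same way is an assumption that would need checking.

```python
import functools


def toy_jit(fn):
    """Hypothetical decorator standing in for a jit-like wrapper."""

    @functools.wraps(fn)
    def fn_(*args, **kwargs):
        # A real jit would trace and compile here; this just forwards the call.
        return fn(*args, **kwargs)

    return fn_


def forward(x):
    return x * 2


jitted = toy_jit(forward)

# functools.wraps stores the original callable on the wrapper.
original = jitted.__wrapped__

print(original is forward)  # True
print(jitted.__name__)      # forward
```

With this, an `undo_jit` helper could simply return `jitted.__wrapped__`, provided the wrapper was built with `functools.wraps`.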

IvanYashchuk avatar Apr 16 '24 12:04 IvanYashchuk

One thing (probably tangential) I was wondering, why is process_group_for_ddp an attribute for CompileData?

https://github.com/Lightning-AI/lightning-thunder/blob/6c64fb93b04672b731180afc4b63d5df55dae92f/thunder/common.py#L221-L223

I think it would make sense to make it a property. If we ever have to update CompileData.fn, we might forget to update the attributes derived from it. (This change could potentially also fix the issue.)

diff --git a/thunder/common.py b/thunder/common.py
index 85775ff..24cabcb 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -218,10 +218,6 @@ class CompileData:
 
         self.is_module = isinstance(self.fn, torch.nn.Module)
 
-        # We set the process_group_for_ddp attribute on the module when
-        # thunder.distributed.ddp(module) is called.
-        self.process_group_for_ddp = getattr(self.fn, "process_group_for_ddp", None)
-
         #
         # Possibly processes the function
         #
@@ -232,6 +228,12 @@ class CompileData:
 
         assert disable_preprocessing, "please use thunder.compile if you need preprocessing"
 
+    @property
+    def process_group_for_ddp(self):
+        # We set the process_group_for_ddp attribute on the module when
+        # thunder.distributed.ddp(module) is called.
+        return getattr(self.fn, "process_group_for_ddp", None)
+

kshitij12345 avatar Apr 16 '24 13:04 kshitij12345

I think the new fsdp/ddp actually do this.

t-vi avatar Apr 17 '25 16:04 t-vi