Fix dynamo issue

oraluben opened this issue

Dynamo uses FakeTensor to trace tensor ops. In some cases, this mechanism breaks compilation with DeepSpeed.

An example can be found at https://gist.github.com/oraluben/9b8240c2fe482eb4382453d6c97a5f76. To reproduce the issues, install deepspeed==0.14.4 instead of my fork.

Without this PR, Llama cannot be compiled.

Detailed explanation:

  1. ZeROOrderedDict: Dynamo uses deepcopy to copy tensors, which calls object.__reduce__. When a ZeROOrderedDict is copied, the default implementation does not copy its _parent_module, which leads to a failure.
  2. param may be a FakeTensor that does not have ds_status yet. During tracing it is fine to simply skip register_external_parameter, since that registration should already have happened well before.
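The first failure can be reproduced without DeepSpeed or torch. The sketch below assumes the dict takes its parent module as a required constructor argument; `BrokenDict` and `FixedDict` are hypothetical stand-ins for `ZeROOrderedDict` without and with a `__reduce__` override, not DeepSpeed's actual code.

```python
import copy
from collections import OrderedDict


class BrokenDict(OrderedDict):
    """Hypothetical stand-in: requires a parent reference at construction."""

    def __init__(self, parent_module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module


class FixedDict(BrokenDict):
    """Same class, but tells copy/pickle how to call the constructor."""

    def __reduce__(self):
        # In CPython, OrderedDict.__reduce__ returns
        # (cls, (), state, None, items_iterator). The empty args tuple means
        # deepcopy would call cls() with no parent, so pass the parent back in.
        cls, _args, *rest = super().__reduce__()
        return (cls, (self._parent_module,), *rest)


def deepcopy_succeeds(obj):
    """Return True if copy.deepcopy(obj) completes without a TypeError."""
    try:
        copy.deepcopy(obj)
        return True
    except TypeError:
        return False


# Usage: BrokenDict() is called without its parent during reconstruction
# and raises TypeError; FixedDict round-trips with its parent and items.
original = FixedDict("q_proj", [("weight", 1)])
clone = copy.deepcopy(original)
```

Overriding only the constructor arguments keeps the rest of OrderedDict's reduction (instance state and item iterator) intact.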

oraluben, Sep 12 '24

torch.compiler.is_compiling() would be a better fit for this case. However, there is still an issue, presumably on the Dynamo side: we receive a FakeTensor, so we are definitely tracing, yet the check below still fails. So I'm keeping the current approach for now.

[rank1]:   File "/home/yyc/accelerate-demo/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1720, in __getattr__
[rank1]:     return _parameters[name]
[rank1]:   File "/home/yyc/repo/DeepSpeed/deepspeed/runtime/zero/parameter_offload.py", line 67, in __getitem__
[rank1]:     if not is_compiling() and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
[rank1]: torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___self_attn_q_proj(*(FakeTensor(..., device='cuda:1', size=(1, s0, 4096), dtype=torch.float16,
[rank1]:            grad_fn=<MulBackward0>),), **{}):
[rank1]: 'FakeTensor' object has no attribute 'ds_status'
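Because `param.ds_status` is only evaluated when `not is_compiling()` is true, the traceback implies `is_compiling()` returned False while Dynamo was running the module on FakeTensors. A minimal sketch of a guard that does not rely on `is_compiling()`, assuming an attribute check is acceptable: `ParamStatus` and `maybe_register` are hypothetical stand-ins, not DeepSpeed's `ZeroParamStatus` or `register_external_parameter`.

```python
from enum import Enum
from types import SimpleNamespace


class ParamStatus(Enum):
    """Hypothetical stand-in for DeepSpeed's ZeroParamStatus."""
    AVAILABLE = 1
    NOT_AVAILABLE = 2


def maybe_register(param, register_fn):
    """Hypothetical guard: call register_fn only for ZeRO params not yet gathered.

    Objects without a ds_status attribute (e.g. a FakeTensor seen during
    tracing) are returned unchanged instead of raising AttributeError.
    """
    if param is None:  # params such as a missing bias may be registered as None
        return param
    if getattr(param, "ds_status", None) is ParamStatus.NOT_AVAILABLE:
        register_fn(param)
    return param


# Usage: the real param is registered, the attribute-less one is skipped.
registered = []
zero_param = SimpleNamespace(ds_status=ParamStatus.NOT_AVAILABLE)
fake_param = SimpleNamespace()  # no ds_status, standing in for a FakeTensor
maybe_register(zero_param, registered.append)
maybe_register(fake_param, registered.append)
```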

My patch to deepspeed.runtime:

diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 879c0a1a..3994c1f5 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -10,6 +10,15 @@ def is_compile_supported():
     return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
 
 
+def is_compiling():
+    if not is_compile_supported():
+        return False
+    elif hasattr(torch.compiler, 'is_compiling'):  # torch >= 2.3
+        return torch.compiler.is_compiling()
+    else:
+        return torch._dynamo.is_compiling()
+
+
 def disable(func):
     if is_compile_supported():
         return torch.compiler.disable(func)

oraluben, Sep 13 '24

@oraluben, sorry this PR has taken so long to merge. I think it just needed master merged in again to pick up the XPU fixes.

loadams, Oct 23 '24