Fix dynamo issue
Dynamo uses FakeTensor to trace tensor ops. In some cases, this mechanism breaks compiling with DeepSpeed.
An example can be found at https://gist.github.com/oraluben/9b8240c2fe482eb4382453d6c97a5f76; to see the issues, install deepspeed==0.14.4 instead of my fork.
Without this PR, llama cannot be compiled.
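As a minimal illustration of the tracing mechanism (purely illustrative, not taken from the gist): during compilation dynamo runs the function with FakeTensor stand-ins, so attributes that only exist on real, DeepSpeed-managed tensors are not available while tracing.

```python
import torch

@torch.compile
def double(x):
    # While dynamo traces this function, `x` is represented by a FakeTensor
    # that carries shape/dtype metadata but not DeepSpeed attributes such as
    # `ds_status`; touching those attributes during tracing fails.
    return x * 2

double(torch.randn(4))  # the first call triggers compilation/tracing
```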
Detailed explanation:
- `ZeROOrderedDict`: dynamo uses deepcopy to copy tensors, which calls `object.__reduce__`. When copying a `ZeROOrderedDict`, the default implementation does not copy its `_parent_module`, which leads to a failure (see the sketch below).
- `param` may be a FakeTensor and may not have `ds_status` yet, but during tracing it is fine to simply skip `register_external_parameter`; that registration should have happened long before.
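A self-contained sketch of the first point (simplified stand-in, not the real DeepSpeed class; `parent_module` is just a string here):

```python
import copy
from collections import OrderedDict


class ZeROOrderedDict(OrderedDict):
    """Simplified stand-in for DeepSpeed's ZeROOrderedDict."""

    def __init__(self, parent_module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module

    def __reduce__(self):
        # OrderedDict's default __reduce__ reconstructs the object by calling
        # type(self)() with no arguments, so the required parent_module
        # argument is lost and reconstruction fails. Re-inserting it into the
        # constructor args keeps deepcopy, which dynamo uses while tracing,
        # working.
        r0, _, *r2 = super().__reduce__()
        return (r0, (self._parent_module,), *r2)


d = ZeROOrderedDict("fake_parent_module")
d["weight"] = 1.0
copied = copy.deepcopy(d)
assert copied._parent_module == "fake_parent_module"
# Without the __reduce__ override, the deepcopy above raises roughly:
# TypeError: __init__() missing 1 required positional argument: 'parent_module'
```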
`torch.compiler.is_compiling()` should be a better fit for this case; however, there is still an issue, presumably on the dynamo side (since we see FakeTensor, we are definitely tracing). So keep it for now. The failure looks like this:
```
[rank1]:   File "/home/yyc/accelerate-demo/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1720, in __getattr__
[rank1]:     return _parameters[name]
[rank1]:   File "/home/yyc/repo/DeepSpeed/deepspeed/runtime/zero/parameter_offload.py", line 67, in __getitem__
[rank1]:     if not is_compiling() and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
[rank1]: torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___self_attn_q_proj(*(FakeTensor(..., device='cuda:1', size=(1, s0, 4096), dtype=torch.float16,
[rank1]:   grad_fn=<MulBackward0>),), **{}):
[rank1]: 'FakeTensor' object has no attribute 'ds_status'
```
My patch in deepspeed.runtime:

```diff
diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index 879c0a1a..3994c1f5 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -10,6 +10,15 @@ def is_compile_supported():
     return hasattr(torch, "compiler") and hasattr(torch.nn.Module, "compile")
 
 
+def is_compiling():
+    if not is_compile_supported():
+        return False
+    elif hasattr(torch.compiler, 'is_compiling'):  # torch >= 2.3
+        return torch.compiler.is_compiling()
+    else:
+        return torch._dynamo.is_compiling()
+
+
 def disable(func):
     if is_compile_supported():
         return torch.compiler.disable(func)
```
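For context, here is a simplified sketch of how such a helper guards the `ds_status` access shown in the traceback above. It is not the verbatim code from parameter_offload.py; the import paths and the `register_external_parameter` call are approximations for illustration.

```python
from collections import OrderedDict

from deepspeed.runtime.compiler import is_compiling  # helper added by the patch above
# Import paths below are assumed from the DeepSpeed layout, for illustration only.
from deepspeed.runtime.zero.partition_parameters import (
    ZeroParamStatus,
    register_external_parameter,
)


class ZeROOrderedDict(OrderedDict):

    def __init__(self, parent_module, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parent_module = parent_module

    def __getitem__(self, key):
        param = super().__getitem__(key)
        # Params can legitimately be registered as None (e.g. a missing bias).
        if param is None:
            return param
        # While dynamo traces, `param` is a FakeTensor without `ds_status`, and
        # external parameters should already have been registered long before
        # compilation, so the whole branch is skipped under tracing.
        if not is_compiling() and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            register_external_parameter(self._parent_module, param)  # simplified call
        return param
```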
@oraluben - sorry this PR has taken so long to be merged, I think it just needed to have master merged again to get the XPU fixes.