
CudaGraph Capturing failures on `_to_copy()` operations

Open kevinstephano opened this issue 3 years ago • 10 comments

🐛 Describe the bug

This is our most frequent failure signature, so it is important to fix!

This is the error:

    ERROR:common:Failed for dynamo CUDA error: operation failed due to a previous error during capture
    CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    Traceback (most recent call last):
      File "/opt/pytorch/torchdynamo/torchinductor/compile_fx.py", line 249, in cudagraphify_impl
        static_outputs = model(*static_inputs)
      File "/opt/pytorch/pytorch/torch/_prims/executor.py", line 25, in execute
        return nvfuser_execute_partitioned(
      File "/opt/pytorch/pytorch/torch/_prims/nvfuser_executor.py", line 418, in nvfuser_execute_partitioned
        return gm(*args)
      File "/opt/pytorch/pytorch/torch/fx/graph_module.py", line 660, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/opt/pytorch/pytorch/torch/fx/graph_module.py", line 279, in __call__
        raise e
      File "/opt/pytorch/pytorch/torch/fx/graph_module.py", line 269, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "<eval_with_key>.3119", line 143, in forward
        _to_copy = torch.ops.aten._to_copy.default(broadcast_in_dim_4, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  broadcast_in_dim_4 = None
      File "/opt/pytorch/pytorch/torch/_ops.py", line 257, in __call__
        return self._op(*args, **kwargs or {})
    RuntimeError: CUDA error: operation not permitted when stream is capturing

An example command line that reproduces the failure:

    python -u benchmarks/huggingface.py --training -d cuda --fast --backend nvprims_nvfuser --skip-accuracy-check --performance --only BartForCausalLM

Versions

csarofeen/torchbenchPerf

kevinstephano, Oct 18 '22 08:10

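I think the problem with this particular model (BartForCausalLM) is that broadcast_in_dim_4 is, for some reason, a CPU tensor, and torch.ops.aten._to_copy.default is used to copy it to the GPU while the CUDA graph is being captured. That copy is what raises "operation not permitted when stream is capturing".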

IvanYashchuk, Oct 18 '22 12:10

The first lines in the graph create CPU tensors:

        full: f32[1024,1024] = torch.ops.aten.full.default([1024, 1024], -inf, device = device(type='cpu'), pin_memory = False)
        arange: i64[1024] = torch.ops.aten.arange.default(1024, device = device(type='cpu'), pin_memory = False)
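One way to find such nodes in a traced graph is to scan for call_function nodes whose device kwarg is the CPU, regardless of which op they are. This is a minimal sketch assuming a recent PyTorch; find_cpu_allocations is a hypothetical helper (not part of torchdynamo), and the graph is hand-built with the torch.fx.Graph API to mimic the two lines above rather than traced from Bart.

```python
import torch
import torch.fx


def find_cpu_allocations(gm: torch.fx.GraphModule) -> list:
    """Return call_function nodes whose `device=` kwarg is the CPU (hypothetical helper)."""
    hits = []
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        dev = node.kwargs.get("device")
        # torch.device() accepts both strings ("cpu") and torch.device objects.
        if dev is not None and torch.device(dev).type == "cpu":
            hits.append(node)
    return hits


# Hand-built graph mimicking the first two lines of the traced Bart graph.
g = torch.fx.Graph()
cpu = torch.device("cpu")
full = g.call_function(torch.ops.aten.full.default, ([1024, 1024], float("-inf")),
                       {"device": cpu, "pin_memory": False})
arange = g.call_function(torch.ops.aten.arange.default, (1024,),
                         {"device": cpu, "pin_memory": False})
g.output((full, arange))
gm = torch.fx.GraphModule(torch.nn.Module(), g)

print([n.target for n in find_cpu_allocations(gm)])
```

Because it matches on the device kwarg rather than a fixed op list, a scan like this would also surface overloads that a hard-coded list misses.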

IvanYashchuk, Oct 18 '22 13:10

Replacing device(type='cpu') with torch.device("cuda:0") in the graph fixes the problem.

Update: add the following to the top of the prims_executor function in torchdynamo/optimizations/training.py:

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in [
            torch.ops.aten.arange.default,
            torch.ops.aten.full.default,
        ]:
            new_kwargs = dict(node.kwargs)
            if new_kwargs.get("device", False) and new_kwargs["device"].type == "cpu":
                new_kwargs["device"] = torch.device("cuda")
            node.kwargs = new_kwargs
    gm.recompile()
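The snippet above edits the FX graph in place before it is traced and captured. Below is a minimal, self-contained sketch of the same loop applied to a hand-built graph that mimics the Bart lines, assuming a recent PyTorch. The function name move_cpu_factories and the demo graph are hypothetical; torch.device(...) normalization is an added assumption so string devices also work; and the "meta" device stands in for "cuda" so the sketch runs without a GPU.

```python
import torch
import torch.fx

aten = torch.ops.aten


def move_cpu_factories(gm: torch.fx.GraphModule, device="cuda") -> None:
    """Retarget CPU-allocating arange/full nodes to `device` (hypothetical name)."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in [
            aten.arange.default,
            aten.full.default,
        ]:
            new_kwargs = dict(node.kwargs)
            dev = new_kwargs.get("device")
            if dev is not None and torch.device(dev).type == "cpu":
                new_kwargs["device"] = torch.device(device)
            node.kwargs = new_kwargs
    gm.recompile()  # regenerate forward() from the edited graph


# Hand-built stand-in for the Bart graph: out = x + full(...), plus arange(...).
g = torch.fx.Graph()
x = g.placeholder("x")
cpu = torch.device("cpu")
mask = g.call_function(aten.full.default, ([4, 4], float("-inf")),
                       {"device": cpu, "pin_memory": False})
idx = g.call_function(aten.arange.default, (4,), {"device": cpu, "pin_memory": False})
g.output((g.call_function(aten.add.Tensor, (x, mask)), idx))
gm = torch.fx.GraphModule(torch.nn.Module(), g)

# "meta" stands in for "cuda" so this runs on machines without a GPU.
move_cpu_factories(gm, device="meta")
out, ar = gm(torch.zeros(4, 4, device="meta"))
print(out.device, ar.device)
```

With device="cuda", the same pass means the graph allocates these tensors on the GPU directly, so no host-to-device copy is left to run during capture.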

IvanYashchuk, Oct 18 '22 13:10

This looks like it may be a case where the mask isn't being handled properly as an input:

https://github.com/huggingface/transformers/blob/ebee0a27940adfbb30444d83387b9ea0f1173f40/src/transformers/models/bart/modeling_bart.py#L96-L99

kevinstephano, Oct 18 '22 18:10

It looks like it! There are no device arguments, so the mask is created on the default device, the CPU. It is only sent to CUDA later: https://github.com/huggingface/transformers/blob/ebee0a27940adfbb30444d83387b9ea0f1173f40/src/transformers/models/bart/modeling_bart.py#L914-L916
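On the model side, the CPU round trip could be avoided by allocating the mask on the target device from the start. This is a hedged sketch: make_causal_mask here is a hypothetical standalone helper, not the transformers function linked above, and it omits batch expansion and past_key_values handling.

```python
import torch


def make_causal_mask(tgt_len: int, dtype: torch.dtype, device) -> torch.Tensor:
    """(tgt_len, tgt_len) additive mask: 0 where j <= i, dtype min elsewhere."""
    # Allocate directly on `device`, so no later CPU -> GPU copy is needed.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min,
                      dtype=dtype, device=device)
    cond = torch.arange(tgt_len, device=device)
    # Entry (i, j) tests cond[j] < cond[i] + 1, i.e. j <= i: unmask past positions.
    mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)
    return mask


m = make_causal_mask(3, torch.float32, device="cpu")
print(m)
```

Passing the input's device (for example "cuda") means the traced graph's full/arange nodes carry that device, so no CPU-to-GPU copy ends up inside the captured region.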

IvanYashchuk, Oct 18 '22 18:10

Ivan's fix as a diff:

    diff --git a/torchdynamo/optimizations/training.py b/torchdynamo/optimizations/training.py
    index dca9202..1f2178b 100644
    --- a/torchdynamo/optimizations/training.py
    +++ b/torchdynamo/optimizations/training.py
    @@ -356,6 +356,17 @@ def prims_executor(gm, inputs, *, executor, num_fixed=0):
         from torch.fx.experimental.proxy_tensor import make_fx
         from torchinductor.compile_fx import align_inputs, cudagraphify
     
    +    for node in gm.graph.nodes:
    +        if node.op == "call_function" and node.target in [
    +            torch.ops.aten.arange.default,
    +            torch.ops.aten.full.default,
    +        ]:
    +            new_kwargs = dict(node.kwargs)
    +            if new_kwargs.get("device", False) and new_kwargs["device"].type == "cpu":
    +                new_kwargs["device"] = torch.device("cuda")
    +            node.kwargs = new_kwargs
    +    gm.recompile()
    +
         # First we trace the graph conditionally decomposing nodes
         # that can be sent to the nvfuser executor
         with TorchRefsNvfuserCapabilityMode():

kevinstephano, Oct 19 '22 00:10

One more op needs to be added, torch.ops.aten.arange.start_step:

    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in [
            torch.ops.aten.arange.default,
            torch.ops.aten.arange.start_step,
            torch.ops.aten.full.default,
        ]:
            new_kwargs = dict(node.kwargs)
            if new_kwargs.get("device", False) and new_kwargs["device"].type == "cpu":
                new_kwargs["device"] = torch.device("cuda")
            node.kwargs = new_kwargs
    gm.recompile()
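Each aten op has several overloads (arange alone has default, start, and start_step), so listing overloads one at a time is easy to get wrong; the list above still omits torch.ops.aten.arange.start, for example. A hedged alternative is to match on each OpOverload's overloadpacket so every overload of arange and full is covered. This is a sketch assuming a recent PyTorch; move_cpu_factories and the demo graph are hypothetical, and "meta" stands in for "cuda" so it runs without a GPU.

```python
import torch
import torch.fx

aten = torch.ops.aten
# Match whole op families (all overloads) rather than individual overloads.
FACTORY_PACKETS = (aten.arange, aten.full)


def move_cpu_factories(gm: torch.fx.GraphModule, device="cuda") -> int:
    """Retarget every CPU arange/full overload to `device`; return how many moved."""
    moved = 0
    for node in gm.graph.nodes:
        # Only OpOverload targets have `overloadpacket`; anything else gives None.
        packet = getattr(node.target, "overloadpacket", None)
        if node.op != "call_function" or packet not in FACTORY_PACKETS:
            continue
        dev = node.kwargs.get("device")
        if dev is not None and torch.device(dev).type == "cpu":
            node.kwargs = {**node.kwargs, "device": torch.device(device)}
            moved += 1
    gm.recompile()
    return moved


# Demo graph using overloads, including arange.start, which a fixed list can miss.
g = torch.fx.Graph()
cpu = torch.device("cpu")
a = g.call_function(aten.arange.start, (0, 4), {"device": cpu})
b = g.call_function(aten.arange.start_step, (0, 8, 2), {"device": cpu})
c = g.call_function(aten.full.default, ([2, 2], 0.0), {"device": cpu})
g.output((a, b, c))
gm = torch.fx.GraphModule(torch.nn.Module(), g)

n = move_cpu_factories(gm, device="meta")  # "meta" stands in for "cuda"
print(n, [t.device.type for t in gm()])
```

The trade-off is breadth: matching on the packet also catches overloads you may not have meant to move, so it is worth pairing with a check that the graph's consumers really expect those tensors on the GPU.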

kevinstephano, Oct 19 '22 09:10

Can we get a PR up with these changes to torchbenchPerf?

jjsjann123, Oct 19 '22 18:10

I have added it to the nvfuser-cudagraphify branch.

IvanYashchuk, Oct 19 '22 19:10

> I have added it to the nvfuser-cudagraphify branch.

I actually saw that after I asked the question. If it's going into dynamo, then don't worry about a PR. 🙇

jjsjann123, Oct 19 '22 19:10