xformers
efficient_attention_forward_cutlass op is incompatible with Torch JIT
🐛 Bug
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass
Command
To Reproduce
Use a StableDiffusionXL model with torch.jit after enabling xformers memory-efficient attention:

m.enable_xformers_memory_efficient_attention()
torch.jit.trace(m, example_inputs)  # example inputs matching m's forward
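For context, a fuller repro sketch (the checkpoint name, latent shape, and dtype here are placeholders; any SDXL pipeline with xformers attention enabled should hit the same error):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

# Tracing any submodule whose forward goes through
# xformers.ops.memory_efficient_attention triggers the failure; the VAE
# decoder is the smallest such piece in the stack trace below.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")
traced = torch.jit.trace(pipe.vae.decoder, latents)  # RuntimeError here
```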
Resulting stack trace:
File "/home/philipwan/philipwan/torch-engine/tests/inference/sdxl_chn_pipeline.py", line 847, in __call__
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
File "/opt/conda/lib/python3.8/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 318, in decode
decoded = self._decode(z).sample
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 289, in _decode
dec = self.decoder(z)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 172, in wrapper
traced_m, call_helper = trace_with_kwargs(
File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 147, in trace_with_kwargs
traced_module = better_trace(TraceablePosArgOnlyModuleWrapper(func),
File "/home/philipwan/philipwan/torch-engine/agie/jit/utils.py", line 28, in better_trace
script_module = torch.jit.trace(func, *args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 806, in trace
return trace_module(
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1074, in trace_module
module._c._create_method_from_trace(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 306, in forward
return self.module(*orig_args, **orig_kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 220, in forward
return self.func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/vae.py", line 318, in forward
sample = self.mid_block(sample, latent_embeds)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 624, in forward
hidden_states = attn(hidden_states, temb=temb)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 540, in forward
return self.processor(
File "/opt/conda/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 1214, in __call__
hidden_states = xformers.ops.memory_efficient_attention(
File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 247, in memory_efficient_attention
return _memory_efficient_attention(
File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 368, in _memory_efficient_attention
return _memory_efficient_attention_forward(
File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 388, in _memory_efficient_attention_forward
out, *_ = op.apply(inp, needs_gradient=False)
File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 202, in apply
return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 266, in apply_bmhk
out, lse, rng_seed, rng_offset = cls.OPERATOR(
File "/opt/conda/lib/python3.8/site-packages/torch/_ops.py", line 755, in __call__
return self._op(*args, **(kwargs or {}))
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass
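As far as I can tell, this is a JIT tracer limitation rather than anything attention-specific: the `out, lse, rng_seed, rng_offset = cls.OPERATOR(...)` frame above shows that efficient_attention_forward_cutlass returns the RNG seed and offset as plain ints alongside its tensors, and the tracer apparently cannot record non-Tensor outputs from a custom operator. A minimal sketch that reproduces the same error with a made-up op (the `demo` namespace and `returns_int` op are hypothetical):

```python
import torch

# Register a toy custom op that, like the cutlass kernel, returns a
# Tensor together with a plain Python int.
lib = torch.library.Library("demo", "DEF")
lib.define("returns_int(Tensor x) -> (Tensor, int)")

def returns_int_impl(x):
    return x * 2, 42

lib.impl("returns_int", returns_int_impl, "CompositeExplicitAutograd")

def f(x):
    y, _ = torch.ops.demo.returns_int(x)
    return y

torch.jit.trace(f, (torch.randn(2),))
# RuntimeError: unsupported output type: int, from operator: demo::returns_int
```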
Environment
Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).
- PyTorch Version (e.g., 1.0): 2.2.1
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source): pip
- Build command you used (if compiling from source):
- Python version: 3.8
- CUDA/cuDNN version: CUDA 12.1
- GPU models and configuration:
- Any other relevant information: xformers 0.0.25
Additional context
I saw the relevant issue https://github.com/facebookresearch/xformers/issues/406 and I'd really like to know how to solve it. Thank you!
How did you solve it? Cheers!
I didn't solve it, but I have a workaround: matching a custom pattern rather than xformers::efficient_attention_forward_cutlass.
@philipwan Could you maybe elaborate on your workaround? I'm currently having the same problem :(
I have the same problem when using TensorRT and xformers.
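For anyone else landing here: one way such a pattern-matching workaround can look is to hide the xformers call behind a custom op whose schema exposes only a Tensor output, so the tracer never sees the int seed/offset. This is only a sketch of the general idea (the `trace_safe` namespace and `mea` op are made up, and it ignores attn_bias/dropout/scale); it is not necessarily what @philipwan did:

```python
import torch
import xformers.ops

lib = torch.library.Library("trace_safe", "DEF")
lib.define("mea(Tensor q, Tensor k, Tensor v) -> Tensor")

def mea_impl(q, k, v):
    # The int rng seed/offset stay inside this Python call; the traced
    # graph only records a single Tensor-valued trace_safe::mea node.
    return xformers.ops.memory_efficient_attention(q, k, v)

lib.impl("mea", mea_impl, "CompositeExplicitAutograd")
```

A custom attention processor can then call torch.ops.trace_safe.mea(q, k, v) in place of xformers.ops.memory_efficient_attention before tracing.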