
efficient_attention_forward_cutlass op is incompatible with Torch JIT

Open philipwan opened this issue 3 months ago • 2 comments

🐛 Bug

RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass

To Reproduce

Trace a StableDiffusionXL model with torch.jit after enabling xFormers memory-efficient attention:

    m.enable_xformers_memory_efficient_attention()
    torch.jit.trace(m, example_inputs)

Steps to reproduce the behavior:
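A minimal sketch of the failing flow (assumptions: a CUDA GPU with diffusers and xformers installed; the checkpoint name and latent shape below are illustrative, not taken from the original report):

```python
# Sketch: enable xFormers attention on an SDXL pipeline, then try to
# torch.jit.trace the VAE decode path. Requires a CUDA GPU plus the
# diffusers and xformers packages; checkpoint name is illustrative.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.enable_xformers_memory_efficient_attention()

# Tracing reaches xformers::efficient_attention_forward_cutlass, whose
# outputs include plain ints (the rng_seed/rng_offset seen in the
# traceback below), which the JIT tracer cannot record as node outputs.
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")
traced = torch.jit.trace(pipe.vae.decode, latents)  # raises RuntimeError
```

The tracer only supports tensor (and tensor-container) outputs from custom operators, which is why the int return values trigger the "unsupported output type: int" error.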

File "/home/philipwan/philipwan/torch-engine/tests/inference/sdxl_chn_pipeline.py", line 847, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  File "/opt/conda/lib/python3.8/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 318, in decode
    decoded = self._decode(z).sample
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/autoencoder_kl.py", line 289, in _decode
    dec = self.decoder(z)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 172, in wrapper
    traced_m, call_helper = trace_with_kwargs(
  File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 147, in trace_with_kwargs
    traced_module = better_trace(TraceablePosArgOnlyModuleWrapper(func),
  File "/home/philipwan/philipwan/torch-engine/agie/jit/utils.py", line 28, in better_trace
    script_module = torch.jit.trace(func, *args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 806, in trace
    return trace_module(
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1074, in trace_module
    module._c._create_method_from_trace(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 306, in forward
    return self.module(*orig_args, **orig_kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/philipwan/philipwan/torch-engine/agie/jit/trace_helper.py", line 220, in forward
    return self.func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/vae.py", line 318, in forward
    sample = self.mid_block(sample, latent_embeds)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 624, in forward
    hidden_states = attn(hidden_states, temb=temb)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 540, in forward
    return self.processor(
  File "/opt/conda/lib/python3.8/site-packages/diffusers/models/attention_processor.py", line 1214, in __call__
    hidden_states = xformers.ops.memory_efficient_attention(
  File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 247, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 368, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
  File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/__init__.py", line 388, in _memory_efficient_attention_forward
    out, *_ = op.apply(inp, needs_gradient=False)
  File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 202, in apply
    return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
  File "/opt/conda/lib/python3.8/site-packages/xformers/ops/fmha/cutlass.py", line 266, in apply_bmhk
    out, lse, rng_seed, rng_offset = cls.OPERATOR(
  File "/opt/conda/lib/python3.8/site-packages/torch/_ops.py", line 755, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass

Environment

Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).

  • PyTorch Version (e.g., 1.0): 2.2.1
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8
  • CUDA/cuDNN version: CUDA 12.1
  • GPU models and configuration:
  • Any other relevant information: xformers 0.0.25

Additional context

I saw the related issue https://github.com/facebookresearch/xformers/issues/406, and I would really like to know how to solve this. Thank you!

philipwan avatar Mar 28 '24 03:03 philipwan