
Torch JIT breaks when using memory_efficient_attention

Open Dango233 opened this issue 1 year ago • 13 comments

🐛 Bug

torch.jit.trace breaks with the following error:
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic
The output of the op contains an int that can't be traced by JIT.


To Reproduce

torch.jit.trace the module mentioned in
huggingface/diffusers#532
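A minimal self-contained sketch that hits the same code path (assuming a CUDA build of xformers is installed; the shapes are illustrative and not taken from the diffusers module):

import torch
import xformers.ops as xops

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return xops.memory_efficient_attention(q, k, v)

# q/k/v layout for xformers: [batch, seq_len, num_heads, head_dim]
q = k = v = torch.randn(1, 128, 8, 64, dtype=torch.half, device="cuda")
# On affected versions this fails with:
# RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic
traced = torch.jit.trace(Attn(), (q, k, v))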

Expected behavior

No int output, so the module can be JIT traced.

Dango233 avatar Sep 20 '22 09:09 Dango233

Thanks for reporting :) Should be fixed in https://github.com/facebookresearch/xformers/pull/438

danthe3rd avatar Sep 27 '22 14:09 danthe3rd

> Thanks for reporting :) Should be fixed in #438

Hello, has this been fixed now?

geekinglcq avatar Oct 11 '22 07:10 geekinglcq

Hi, the PR was merged, so it should be. Please let us know if you have other issues.

danthe3rd avatar Oct 11 '22 07:10 danthe3rd

Thank you. I have tried the newest commit of xformers, and the RuntimeError (unsupported output type: int, from operator: xformers::efficient_attention_forward_generic) is resolved.
However, another problem appears. When I run the following code:

inputs = (
    torch.randn(2, 4, 64, 64, dtype=torch.half, device='cuda:6'),   # latents
    torch.randn(1, dtype=torch.half, device='cuda:6'),              # timestep
    torch.randn(2, 77, 768, dtype=torch.half, device='cuda:6'),     # text embeddings
)
# Here `pipeline` is a `diffusers.LDMTextToImagePipeline`
with torch.no_grad():
    with torch.autocast("cuda"):
        jit_unet = torch.jit.trace(pipeline.unet, inputs, strict=False)

(error screenshot omitted) But if I run the code above twice, the error disappears by itself 😂 and the pipeline works fine in the later parts.

geekinglcq avatar Oct 11 '22 13:10 geekinglcq

I'm getting this error too:
return self._op(*args, **kwargs or {})
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass

gigadeplex avatar Sep 13 '23 10:09 gigadeplex

Got this error too.

roninjiang avatar Oct 19 '23 11:10 roninjiang

Got the same errors when I use torch.jit.trace. Any update?

xinlin-xiao avatar Nov 23 '23 03:11 xinlin-xiao

I think the original fix (https://github.com/facebookresearch/xformers/pull/438) did work, but the issue was re-introduced later in https://github.com/facebookresearch/xformers/pull/587

Question to @danthe3rd: what's the purpose of the two int output values rng_seed, rng_offset? Is it possible to re-apply the fix from #438?

ShijunK avatar Mar 29 '24 17:03 ShijunK

Oh, this is a regression - right. The purpose of rng_seed, rng_offset is to keep the RNG state for the backward pass. This is useful when there is dropout in the FW pass and we need to mask the exact same values in the BW pass (and we don't want to save a "dropout" mask, which would be too expensive). There are also complications due to replaying CUDA Graphs (in which case we want the RNG to be different). I believe we should be able to store these values in a torch.Tensor, or maybe there is a best practice for this sort of issue? @drisspg or @fmassa maybe?
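A rough sketch of that idea (illustrative only, not the actual xformers implementation): pack the philox seed/offset into int64 tensors so the op only returns Tensors, which the tracer can handle, and read the values back in the backward pass.

import torch

def pack_rng_state(seed: int, offset: int):
    # int64 CPU tensors instead of raw Python ints, so every op output is a Tensor
    return torch.tensor(seed, dtype=torch.int64), torch.tensor(offset, dtype=torch.int64)

def unpack_rng_state(seed_t: torch.Tensor, offset_t: torch.Tensor):
    # recover the values in the backward pass to re-seed the dropout RNG
    return int(seed_t.item()), int(offset_t.item())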

danthe3rd avatar Mar 29 '24 17:03 danthe3rd

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L1018-L1055 cc @danthe3rd

drisspg avatar Mar 29 '24 17:03 drisspg

Does JIT support SymInt? The version in PT outputs SymInt; I'm not exactly sure why. Anyway, we want to rely on the PyTorch version moving forward (with the C++ code moving to the PyTorch repo), so hopefully this can be fixed at the same time.

danthe3rd avatar Apr 04 '24 14:04 danthe3rd

@danthe3rd, which version of torch are you referring to? For torch 2.2.0, I see the type is Tensor for both seed and offset:

func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
https://github.com/pytorch/pytorch/blob/4cb7dd0fc99981062cebf8d5a94e62b48bf78446/aten/src/ATen/native/native_functions.yaml#L14484-L14488

and here is where they are initialized: https://github.com/pytorch/pytorch/blob/d47f715d29d05e28b94c280f15dce097ef3dc7cb/aten/src/ATen/native/transformers/cuda/attention.cu#L978-L982
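For reference, a quick sketch (assuming torch >= 2.0 and a CUDA device) that forces the in-tree memory-efficient SDPA backend; since its philox seed/offset outputs are Tensors, torch.jit.trace should go through there:

import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # restrict SDPA to the memory-efficient backend only
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
            return F.scaled_dot_product_attention(q, k, v)

# q/k/v layout for SDPA: [batch, num_heads, seq_len, head_dim]
q = k = v = torch.randn(2, 8, 128, 64, dtype=torch.half, device="cuda")
traced = torch.jit.trace(Attn(), (q, k, v))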

ShijunK avatar Apr 16 '24 16:04 ShijunK

> Anyway we want to rely on the PyTorch version moving forward (with the C++ code moving to PyTorch repo)

@danthe3rd, are you referring to at::_scaled_dot_product_efficient_attention?

ShijunK avatar Apr 16 '24 17:04 ShijunK