xformers
Torch JIT breaks when using memory_efficient_attention
🐛 Bug
torch.jit.trace breaks with the following error:
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic
The output of the op contains an int that can't be traced by JIT.
To Reproduce
torch.jit.trace the module mentioned in huggingface/diffusers#532.
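A minimal repro sketch of the failure mode (the module and shapes are illustrative assumptions, not the exact module from the diffusers issue):

import torch
import xformers.ops

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        # memory_efficient_attention dispatches to the custom op whose
        # extra int outputs (RNG state) break torch.jit.trace
        return xformers.ops.memory_efficient_attention(q, k, v)

# xformers expects inputs shaped [batch, seq_len, num_heads, head_dim]
q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.half)
traced = torch.jit.trace(Attn(), (q, k, v))  # raises the RuntimeError above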
Expected behavior
No int output, so the module can be JIT traced.
Thanks for reporting :) Should be fixed in https://github.com/facebookresearch/xformers/pull/438
Hello, has it been fixed now?
Hi, the PR was merged, so yes, it should be fixed. Please let us know if you run into other issues.
Thank you. I have tried the newest commit of xformers, and the RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_generic is solved.
However, another problem appears when I run the following code:

import torch

# latents, timestep, and text embeddings for the UNet
inputs = (
    torch.randn(2, 4, 64, 64, dtype=torch.half, device='cuda:6'),
    torch.randn(1, dtype=torch.half, device='cuda:6'),
    torch.randn(2, 77, 768, dtype=torch.half, device='cuda:6'),
)
# Here `pipeline` is a diffusers.LDMTextToImagePipeline
with torch.no_grad():
    with torch.autocast("cuda"):
        jit_unet = torch.jit.trace(pipeline.unet, inputs, strict=False)
But if I run the code above twice, the error disappears by itself 😂 and the pipeline works fine in the later parts.
I'm getting this error too:
return self._op(*args, **kwargs or {})
RuntimeError: unsupported output type: int, from operator: xformers::efficient_attention_forward_cutlass
Got this error too.
Got the same error when I use torch.jit.trace. Any update?
I think the original fix (https://github.com/facebookresearch/xformers/pull/438) did work, but the issue was re-introduced later in https://github.com/facebookresearch/xformers/pull/587.
Question to @danthe3rd: what's the purpose of the two int output values rng_seed and rng_offset? Is it possible to re-apply the fix from #438?
Oh, this is a regression - right.
The purpose of rng_seed and rng_offset is to keep the RNG state for the backward pass. This is useful when there is dropout in the FW pass and we need to mask the exact same values in the BW pass (and we don't want to save a "dropout" mask, which would be too expensive).
There are also complications due to replaying CUDA Graphs (in which case we want the RNG to be different).
I believe we should be able to store these values in a torch.Tensor, or maybe there is a best practice for this sort of issue? @drisspg or @fmassa maybe?
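To illustrate the pattern (a minimal sketch, not the actual xformers kernel): the forward pass saves only a seed, and the backward pass re-seeds the RNG to regenerate the identical dropout mask, instead of materializing and storing the mask tensor.

import torch

class DropoutSavedSeed(torch.autograd.Function):
    # Sketch: keep the RNG seed for backward instead of the dropout mask.
    @staticmethod
    def forward(ctx, x, p):
        seed = torch.seed()  # re-seed the RNG and record the seed used
        mask = torch.rand_like(x).ge(p).to(x.dtype) / (1 - p)
        ctx.seed, ctx.p = seed, p
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        # Replaying the same seed regenerates the exact same mask, so
        # nothing large has to be saved between forward and backward.
        torch.manual_seed(ctx.seed)
        mask = torch.rand_like(grad_out).ge(ctx.p).to(grad_out.dtype) / (1 - ctx.p)
        return grad_out * mask, None

# Usage: y = DropoutSavedSeed.apply(x, 0.1); y.sum().backward()

Returning the seed/offset from the op as 0-dim Tensors rather than Python ints is what would keep it traceable, since the error above is specifically about an int output type.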
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/attention.cu#L1018-L1055 cc @danthe3rd
Does JIT support SymInt? Because the version in PT outputs SymInt; not exactly sure why.
Anyway, we want to rely on the PyTorch version moving forward (with the C++ code moving to the PyTorch repo), so hopefully this can be fixed at the same time.
@danthe3rd, which version of torch are you referring to? For torch 2.2.0, I see the type is Tensor for both seed and offset:
func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
https://github.com/pytorch/pytorch/blob/4cb7dd0fc99981062cebf8d5a94e62b48bf78446/aten/src/ATen/native/native_functions.yaml#L14484-L14488
and here is where they are initialized: https://github.com/pytorch/pytorch/blob/d47f715d29d05e28b94c280f15dce097ef3dc7cb/aten/src/ATen/native/transformers/cuda/attention.cu#L978-L982
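A quick way to check (a sketch assuming torch >= 2.2 with the mem-efficient attention backend available; shapes are illustrative):

import torch

# SDPA expects inputs shaped [batch, num_heads, seq_len, head_dim]
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.half)
k = torch.randn_like(q)
v = torch.randn_like(q)

out, lse, philox_seed, philox_offset = (
    torch.ops.aten._scaled_dot_product_efficient_attention(
        q, k, v, None, False  # attn_bias=None, compute_log_sumexp=False
    )
)
# philox_seed / philox_offset come back as Tensors here, not Python ints,
# which is an output type torch.jit.trace can handle.
print(type(philox_seed), type(philox_offset))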
"Anyway we want to rely on the PyTorch version moving forward (with the C++ code moving to PyTorch repo)"
@danthe3rd, are you referring to at::_scaled_dot_product_efficient_attention?