Hi,

Thank you for your incredible work. I was trying to test your method, specifically for cross-attention, but I get the error `save_for_backward can only save variables, but argument 5 is of type bool`. I am not sure what I am doing wrong; I tried your own examples too and get the same error.

Can you please help me out?
Code:

```python
import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

# out = sm_mod(inp1)

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask)  # (1, 65536, 512)
```
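For reference, here is a smaller sanity check that should avoid the checkpointed path entirely (assuming `memory_efficient = False` simply falls back to plain attention; the shorter sequence keeps the full attention matrix small):

```python
# hypothetical sanity check: same call on a short sequence, no checkpointing
plain_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = False   # fall back to the non-chunked attention path
).cuda()

x_small = torch.randn(1, 1024, 512).cuda()
context_small = torch.randn(1, 1024, 512).cuda()
mask_small = torch.ones(1, 1024).bool().cuda()

out_small = plain_attn(x_small, context = context_small, mask = mask_small)  # (1, 1024, 512)
```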
ERROR:
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/main.py", line 45, in
cli.main()
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 444, in main
run()
File "/home/abali/.vscode-server/extensions/ms-python.python-2022.8.1/pythonFiles/lib/python/debugpy/../debugpy/server/cli.py", line 285, in run_file
runpy.run_path(target_as_str, run_name=compat.force_str("main"))
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 265, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data/stars/user/abali/Phd_work/ISBI2023/X3D-Multigrid/CrossAttn_X3d_v2.py", line 872, in
out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512) print(out)
File "/home/abali/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 215, in forward
out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)
File "/home/abali/.conda/envs/py38_ydp5/lib/python3.8/site-packages/memory_efficient_attention_pytorch/memory_efficient_attention.py", line 127, in memory_efficient_attention
exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
File "/home/abali/.local/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
TypeError: save_for_backward can only save variables, but argument 5 is of type bool
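From the traceback, the bool never makes it past `torch.utils.checkpoint`: the chunked attention forwards its extra arguments (argument 5 is presumably the `causal` flag) through `checkpoint`, and older PyTorch releases handed every checkpointed argument straight to `save_for_backward`, which only accepts tensors. A minimal sketch that, I believe, reproduces the same `TypeError` on such a release:

```python
import torch
from torch.utils.checkpoint import checkpoint

def fn(x, flag):
    # flag is a plain Python bool, analogous to the causal argument above
    return x * 2 if flag else x

x = torch.randn(4, requires_grad = True)
out = checkpoint(fn, x, True)  # older PyTorch: TypeError from save_for_backward
out.sum().backward()
```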
@aliabid2243 Hi Abid! Which version of PyTorch are you on?
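Asking because newer releases of `torch.utils.checkpoint` keep non-tensor arguments out of `save_for_backward` (only the tensor inputs are saved), so upgrading PyTorch should make this go away. You can check what you have with:

```python
import torch
print(torch.__version__)    # older 1.x releases passed bools straight to save_for_backward
print(torch.version.cuda)
```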