Stratified-Transformer
Cross Attention with attention_step1_v2 using differently sized key and query sets
Hi!
Thanks for this amazing project! I'd like to use your memory-efficient implementation of sparse attention in another project.
In my project, I'd like to perform cross-attention between a query set (N_query=100) and a key set (N_key=35000) which have different sizes.
I adapted this test file: https://github.com/dvlab-research/Stratified-Transformer/blob/main/lib/pointops2/functions/test_attention_op_step1_v2.py to allow different sizes for the query and key sets. Note the change in the second line of code in the snippet below.
With the first version, pointops.attention_step1, I get no errors and can perform cross-attention.
However, with the second (more efficient?) version, pointops.attention_step1_v2, I get a CUDA runtime error (RuntimeError: CUDA error: an illegal memory access was encountered) as soon as I try to access attn_flat_v2, e.g., via print(attn_flat_v2).
Am I using the function correctly?
Thanks a lot for your help and keep up the great work!
Best, Jonas
import torch
import pointops  # the CUDA ops built from lib/pointops2 in this repo

M = 8000
N_query, N_key = 100, 35000 # HAVE A DIFFERENT TOKEN LENGTH FOR KEYS AND QUERIES
C = 96
h = 6
query = torch.rand(N_query, h, C // h).cuda()
key = torch.rand(N_key, h, C // h).cuda()

# random (query, key) index pairs for the M sparse attention entries
index_0 = torch.rand(M)
index_0[index_0 < 0] = 0
index_0 = (index_0 * N_query).long().cuda()
index_1 = torch.rand(M)
index_1[index_1 < 0] = 0
index_1 = (index_1 * N_key).long().cuda()
query.requires_grad = True
key.requires_grad = True

# rearrange indices for acceleration: sort the pairs by query index
index_0, indices = torch.sort(index_0) # [M,]
index_1 = index_1[indices] # [M,]
index_0_counts = index_0.bincount()
n_max = index_0_counts.max()
index_0_offsets = index_0_counts.cumsum(dim=-1) # [N_query]
index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0) # [N_query + 1]

attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int())
attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(),
                                           n_max)
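For clarity, here is what the index rearrangement above computes on a tiny toy example (illustrative values only, not the shapes from my snippet):

import torch

index_0 = torch.tensor([2, 0, 1, 0, 2, 1])  # query index of each of M = 6 pairs
index_0, indices = torch.sort(index_0)       # tensor([0, 0, 1, 1, 2, 2])
counts = index_0.bincount()                  # tensor([2, 2, 2])
offsets = counts.cumsum(dim=-1)              # tensor([2, 4, 6])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), offsets])  # tensor([0, 2, 4, 6])
# Pairs offsets[i]:offsets[i + 1] belong to query i, so the kernel can
# process each query's neighborhood as one contiguous slice.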
Thanks for your interest in our work.
Yes, pointops.attention_step1_v2 is much faster than pointops.attention_step1. To use pointops.attention_step1_v2, you need to make sure that index_0 is in ascending order, and that every integer in [0, N_query) appears at least once in index_0.
Besides, you may need to modify this function, because I assumed that the number of queries equals the number of keys.
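Following up on those two requirements, here is a minimal sanity check one could run before calling pointops.attention_step1_v2 (a sketch; the helper name and choice of assertions are my own, not part of the library):

import torch

def check_v2_indices(index_0, n_query):
    # Requirement 1: index_0 must be sorted in ascending order.
    assert (index_0[1:] >= index_0[:-1]).all(), "index_0 must be in ascending order"
    # Requirement 2: every query index in [0, n_query) must appear at least
    # once; otherwise bincount/cumsum produce offsets that are misaligned
    # with what the kernel expects.
    counts = index_0.bincount(minlength=n_query)
    assert bool((counts > 0).all()), "every query index in [0, n_query) must appear in index_0"

check_v2_indices(index_0, N_query)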
Thanks for your hints! :)
As it turned out, the training loss with pointops.attention_step1 didn't decrease as well as with my initial padding-based implementation. So it is possible that I need to adapt your functions to work with differently sized key and value sets, as you pointed out.
I'll come back when I find out something useful.
Very interested! I'd like to follow the comments here. Thank you all.
I have met the same problem. It can also occur when your input tensors are on different devices, e.g. CPU vs. CUDA:0.
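For anyone hitting that variant of the error, a quick check one could add before the call (the helper below is hypothetical, not part of pointops):

def assert_same_cuda_device(*tensors):
    # All inputs to the CUDA kernel must live on the same CUDA device;
    # a CPU tensor slipping in can surface as an illegal memory access.
    devices = {t.device for t in tensors}
    assert len(devices) == 1 and next(iter(devices)).type == "cuda", \
        f"expected all tensors on one CUDA device, got {devices}"

assert_same_cuda_device(query, key, index_1, index_0_offsets)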