
Cross Attention with attention_step1_v2 using differently sized key and query sets

Open · JonasSchult opened this issue 2 years ago · 4 comments

Hi!

Thanks for this amazing project! I'd like to use your memory-efficient implementation of sparse attention in another project, where I'd like to perform cross-attention between a query set (N_query=100) and a key set (N_key=35000) of different sizes.

I adapted this test file: https://github.com/dvlab-research/Stratified-Transformer/blob/main/lib/pointops2/functions/test_attention_op_step1_v2.py to allow different sizes for the query and key sets. Note the changed N_query, N_key line in the snippet below.

For the first version, pointops.attention_step1, I get no errors and can perform cross-attention. However, for the second (more efficient?) version, pointops.attention_step1_v2, I get a CUDA runtime error, RuntimeError: CUDA error: an illegal memory access was encountered, as soon as I access attn_flat_v2, e.g., print(attn_flat_v2).

Do I use the function correctly?

Thanks a lot for your help and keep up the great work!

Best, Jonas

    import torch

    # pointops: the CUDA ops built from lib/pointops2 in the Stratified-Transformer repo
    # (the exact import path depends on how the extension is built/installed)
    import pointops

    M = 8000
    N_query, N_key = 100, 35000  # HAVE A DIFFERENT TOKEN LENGTH FOR KEYS AND QUERIES
    C = 96
    h = 6
    query = torch.rand(N_query, h, C // h).cuda()
    key = torch.rand(N_key, h, C // h).cuda()

    # sample M random (query, key) index pairs
    index_0 = torch.rand(M)
    index_0[index_0 < 0] = 0
    index_0 = (index_0 * N_query).long().cuda()

    index_1 = torch.rand(M)
    index_1[index_1 < 0] = 0
    index_1 = (index_1 * N_key).long().cuda()

    query.requires_grad = True
    key.requires_grad = True

    # rearrange index for acceleration
    index_0, indices = torch.sort(index_0)  # [M,]
    index_1 = index_1[indices]  # [M,]
    index_0_counts = index_0.bincount()
    n_max = index_0_counts.max()
    index_0_offsets = index_0_counts.cumsum(dim=-1)  # [N_query], assuming every query index occurs
    index_0_offsets = torch.cat([torch.zeros(1, dtype=torch.long).cuda(), index_0_offsets], 0)  # [N_query + 1]

    attn_flat = pointops.attention_step1(query.float(), key.float(), index_0.int(), index_1.int())
    attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(),
                                               n_max)
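
A side note on why the error only shows up at print(attn_flat_v2): CUDA kernels launch asynchronously, so an illegal access inside attention_step1_v2 is typically reported at the next call that synchronizes with the GPU. A minimal, generic way to surface it at the offending call (standard PyTorch/CUDA debugging, not specific to this repo):

    # Either run the script with synchronous kernel launches, e.g.
    #   CUDA_LAUNCH_BLOCKING=1 python test_attention_op_step1_v2.py
    # or force a synchronization right after the suspect call:
    attn_flat_v2 = pointops.attention_step1_v2(query.float(), key.float(), index_1.int(), index_0_offsets.int(),
                                               n_max)
    torch.cuda.synchronize()  # raises here if the kernel performed an illegal memory access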

JonasSchult · Apr 22 '22 16:04

Thanks for your interest in our work. Yes, pointops.attention_step1_v2 is much faster than pointops.attention_step1. To use pointops.attention_step1_v2, you need to make sure that index_0 is in ascending order and that every integer in [0, N_query) appears at least once in index_0.

Besides, you may need to modify this function, because I have assumed that the number of queries is equal to the number of keys.
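
A minimal sketch (not part of the pointops API, just illustrating the two preconditions described above) for checking the indices before calling pointops.attention_step1_v2:

    import torch

    # Hypothetical helper: verifies the two preconditions on index_0
    def check_v2_preconditions(index_0: torch.Tensor, n_query: int) -> None:
        # 1) index_0 must be sorted in ascending order
        assert bool((index_0[1:] >= index_0[:-1]).all()), "index_0 must be in ascending order"
        # 2) every query index in [0, n_query) must appear at least once,
        #    otherwise bincount/cumsum produce offsets the kernel does not expect
        counts = index_0.bincount(minlength=n_query)
        assert bool((counts > 0).all()), "every query index in [0, n_query) must appear in index_0"

    # with the tensors from the snippet above:
    check_v2_preconditions(index_0, N_query)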

X-Lai · Apr 22 '22 18:04

Thanks for your hints! :) As it turned out, the training loss with pointops.attention_step1 didn't decrease as well as with my initial implementation using padding. So it is possible that I need to adapt your functions to work with differently sized key and value sets, as you pointed out. I'll come back when I find out something useful.
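
For context, a minimal sketch of what such a padded alternative might look like (my reading of the comment above, not Jonas's actual code): gather each query's pairwise scores into a fixed-size [N_query, n_max, h] buffer and mask the padded slots before the softmax.

    import torch

    # Hypothetical padded counterpart over the same (index_0, index_1) pairs;
    # shapes follow the earlier snippet, the Python loop is for clarity, not speed,
    # and every query is assumed to have at least one pair (as required above).
    def padded_attention_weights(query, key, index_0, index_1, n_max):
        n_query, h, d = query.shape
        scores = torch.full((n_query, int(n_max), h), float("-inf"), device=query.device)
        slot = [0] * n_query
        for q, k in zip(index_0.tolist(), index_1.tolist()):
            scores[q, slot[q]] = (query[q] * key[k]).sum(-1)  # per-head dot product
            slot[q] += 1
        return scores.softmax(dim=1)  # padded (-inf) slots get zero weight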

JonasSchult · Apr 23 '22 09:04

Very interested as well. I'd like to follow the comments here, thank you all!

RokiaAbdeen · Jun 16 '22 13:06

I have met the same problem; it can also occur when your input tensors are on different devices, e.g., CPU vs. cuda:0.
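
A quick, generic check (plain PyTorch, not specific to pointops) to rule out such a device mismatch before calling the kernels:

    # All tensors passed to the CUDA kernels should live on the same GPU;
    # a CPU tensor slipping in can also trigger an illegal memory access.
    def assert_same_cuda_device(*tensors):
        devices = {t.device for t in tensors}
        assert len(devices) == 1 and next(iter(devices)).type == "cuda", \
            f"expected all inputs on one CUDA device, got {devices}"

    # e.g., with the tensors from the earlier snippet:
    assert_same_cuda_device(query, key, index_1, index_0_offsets)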

JUNJIE99 · Jul 25 '22 08:07