axial-attention
Hi, I have a question about the following usage:
```python
import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img)  # (1, 3, 256, 256)
```
Thanks for your great project! I want to ask: if my image has only one channel, does that influence the `num_dimensions` value?
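My current understanding (which I haven't verified, so please correct me if wrong) is that `num_dimensions` counts the spatial axes, so it stays 2 for an image regardless of the channel count, and the channels instead go into `dim`. A minimal sketch of what I think the one-channel case would look like:

```python
import torch
from axial_attention import AxialAttention

# Hypothetical single-channel image: channel count goes into `dim`,
# while `num_dimensions` still reflects the two spatial axes (H, W).
img = torch.randn(1, 1, 256, 256)

attn = AxialAttention(
    dim = 1,             # embedding dimension = number of channels (1 here)
    dim_index = 1,       # the channel dimension sits at index 1
    heads = 1,           # with dim = 1, a single head
    num_dimensions = 2,  # unchanged: the image still has two spatial axes
)

attn(img)  # (1, 1, 256, 256)
```

Is that the intended way to handle single-channel inputs?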