MinkowskiEngine
Gradient Backpropagation through Tensorfield
Hi,
I am trying to implement a self-attention block for sparse tensors. To do that, I need to multiply the features coming from two sparse tensors, but since features from two SparseTensors cannot be multiplied directly, I extract the features from the sparse tensors, do the matrix multiplication with plain Torch, and build a new sparse tensor via TensorField. My implementation is as follows:
class attn_block(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # Construct the conv layers
        self.query_conv = ME.MinkowskiConvolution(in_dim, in_dim, kernel_size=1, dimension=3)
        self.key_conv = ME.MinkowskiConvolution(in_dim, in_dim, kernel_size=1, dimension=3)
        self.value_conv = ME.MinkowskiConvolution(in_dim, in_dim, kernel_size=1, dimension=3)
        # Initialize gamma as 0
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
        self.union = ME.MinkowskiUnion()

    def forward(self, x):
        # x is the input sparse tensor
        query = self.query_conv(x)
        key = self.key_conv(x)
        value = self.value_conv(x)
        energy = torch.mul(query.F, key.F)
        attention = self.softmax(energy)
        out_f = torch.mul(attention, value.F)
        out = ME.TensorField(self.gamma * out_f + x.F, x.C, tensor_stride=x.tensor_stride)
        # out will be sent to the next conv block
        return out.sparse(tensor_stride=x.tensor_stride)
My question is: would the gradients backpropagate through the new SparseTensor that is created via TensorField?
Thanks!
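For what it's worth, the gradient-flow part can be checked in plain PyTorch without MinkowskiEngine: rebuilding a tensor container from raw features keeps the autograd graph, as long as the features themselves pass through differentiable ops. Below is a hedged, ME-free sketch; `AttnLike` is a hypothetical stand-in for the block above, with `nn.Linear` replacing the 1x1 `MinkowskiConvolution` layers and raw feature matrices replacing `SparseTensor.F`.

```python
import torch
import torch.nn as nn

# Hypothetical plain-tensor stand-in for the attn_block above
# (illustration only; Linear plays the role of a 1x1 sparse conv).
class AttnLike(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.query = nn.Linear(in_dim, in_dim)
        self.key = nn.Linear(in_dim, in_dim)
        self.value = nn.Linear(in_dim, in_dim)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, feats):
        energy = torch.mul(self.query(feats), self.key(feats))
        attention = self.softmax(energy)
        out = torch.mul(attention, self.value(feats))
        # Recombining raw features keeps the autograd graph, just as
        # building a TensorField from query.F / value.F would.
        return self.gamma * out + feats

feats = torch.randn(10, 8, requires_grad=True)
block = AttnLike(8)
block(feats).sum().backward()
print(feats.grad is not None)               # True: gradients reach the input
print(block.query.weight.grad is not None)  # True: ...and the layer weights
```

Note that with `gamma` initialized to zero, the gradients flowing into the query/key/value weights are zero-valued at the first step (the attention branch is scaled by zero), but they are still populated, so training can move `gamma` away from zero and then the branch learns normally.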
This will get the gradient, but it would only work with batch size 1. For larger batch sizes, you would be computing attention over all points across the whole batch.
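One way around this is to split the features on the batch index before computing attention, so points from different samples never attend to each other. In MinkowskiEngine the batch index is the first column of `SparseTensor.C`, and `SparseTensor.decomposed_features` performs exactly this split. Here is a hedged plain-torch sketch of the idea, using the full Q·Kᵀ attention the original post intended; `per_item_self_attention` and `batch_idx` are illustrative names, with `batch_idx` playing the role of `coords[:, 0]`.

```python
import torch

def per_item_self_attention(q, k, v, batch_idx):
    # Run self-attention (softmax(Q K^T) V) separately per batch item,
    # so the softmax never mixes points from different samples.
    # batch_idx mirrors the batch column coords[:, 0] in MinkowskiEngine.
    out = torch.empty_like(v)
    for b in batch_idx.unique():
        m = batch_idx == b
        energy = q[m] @ k[m].t()                  # (n_b, n_b) scores within one item
        attention = torch.softmax(energy, dim=-1)
        out[m] = attention @ v[m]                 # (n_b, C) attended features
    return out

q, k, v = (torch.randn(6, 4) for _ in range(3))
batch_idx = torch.tensor([0, 0, 0, 1, 1, 1])
out = per_item_self_attention(q, k, v, batch_idx)
print(out.shape)  # torch.Size([6, 4])
```

In the sparse-tensor setting, `q`, `k`, `v` would come from `query.F`, `key.F`, `value.F` (or their per-item chunks from `decomposed_features`), and the per-item outputs would be concatenated back before rebuilding the TensorField. The loop costs one matmul per batch item, which is usually acceptable since the per-item point counts differ anyway.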
Oh, okay, thank you! It seems that because there is no explicit batch dimension, implementing this for a batch size greater than 1 would be difficult this way. Is that a correct understanding? Any suggestions on how to approach it? I would really appreciate it. Thanks!
Hello, have you solved this problem? How can I implement a self-attention block with sparse convolutions?
Also interested in a solution to this; has it been solved?