Lowering `torch.scatter` op
This issue is created to figure out the best possible approach to lower the `torch.scatter` op.
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
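For concreteness, here is a small runnable PyTorch example of the `scatter_` update rule (2-D case, `dim == 0`):

```python
import torch

self_t = torch.zeros(3, 4)
src = torch.arange(8, dtype=torch.float32).reshape(2, 4)
index = torch.tensor([[0, 1, 2, 0],
                      [2, 0, 1, 1]])
# For dim == 0: self[index[i][j]][j] = src[i][j]
self_t.scatter_(0, index, src)
print(self_t)
# tensor([[0., 5., 0., 3.],
#         [0., 1., 6., 7.],
#         [4., 0., 2., 0.]])
```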
I think one possible approach could be through the `tm_tensor` dialect, like this:
- The `torch.scatter(self, dim, index, src)` op updates `self` as follows. For a 3-D tensor, `self` is updated as:

  ```python
  self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
  self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
  self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
  ```
One possible approach: map `torch.scatter(self, dim, index, src)` -> `tm_tensor.scatter(self, indices, updates)`:
- Scan the `index` and `src` tensors with the help of a `linalg.generic` op. (Let's assume `index` is a 3-D tensor, which means `self` and `src` will also be 3-D tensors as per https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_.) The scan computes the following (a numpy sketch of both steps follows this list):

  ```
  outs[0], outs[1], outs[2], outs[3] =
      for i = 0; i < index_shape[0]; i++
        for j = 0; j < index_shape[1]; j++
          for k = 0; k < index_shape[2]; k++
            return index[i][j][k], j, k, src[i][j][k]  # (assuming dim == 0)
  ```
  In the above case the `outs` are 1-D, and the shape of each of `outs[0..3]` is `index_shape[0] * index_shape[1] * index_shape[2]`. The indexing map used for the `outs` while creating the `linalg.generic` op is `((d0 * index_shape[1] + d1) * index_shape[2]) + d2`.
- Now prepare `indices` and `updates` for the `tm_tensor.scatter` op as:

  ```
  indices = concat(outs[0], outs[1], outs[2])
  updates = outs[3]
  tm_tensor.scatter(self, indices, updates)
  ```
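To sanity-check the mapping end to end, here is a minimal numpy sketch of the two steps, assuming `dim == 0` and 3-D tensors. This is not the actual torch-mlir lowering: `scan_index_and_src` and `scatter_from_outs` are hypothetical helpers, and the assumed `tm_tensor.scatter` semantics here is simply "write `updates[n]` at the coordinate given by row `n` of `indices`".

```python
import numpy as np
import torch

def scan_index_and_src(index, src):
    # Emulates the proposed linalg.generic scan: flatten (i, j, k) into
    # four 1-D buffers of length s0 * s1 * s2 (dim == 0 assumed).
    s0, s1, s2 = index.shape
    n = s0 * s1 * s2
    outs = [np.empty(n, dtype=np.int64) for _ in range(3)]
    outs.append(np.empty(n, dtype=src.dtype))
    for i in range(s0):
        for j in range(s1):
            for k in range(s2):
                pos = (i * s1 + j) * s2 + k  # ((d0 * s1 + d1) * s2) + d2
                outs[0][pos] = index[i, j, k]  # coordinate along the scattered dim
                outs[1][pos] = j
                outs[2][pos] = k
                outs[3][pos] = src[i, j, k]    # update value
    return outs

def scatter_from_outs(self_arr, outs):
    # Assumed tm_tensor.scatter semantics: updates[n] lands at the
    # coordinate tuple in row n of indices.
    indices = np.stack(outs[:3], axis=1)  # (N, 3)
    updates = outs[3]                     # (N,)
    for coord, val in zip(indices, updates):
        self_arr[tuple(coord)] = val
    return self_arr

# Cross-check against torch.scatter (dim == 0). The index is built to be
# unique along dim 0, since scatter order is unspecified for duplicates.
flip = torch.randint(0, 2, (3, 4))
index = torch.stack([flip, 1 - flip], dim=0)  # shape (2, 3, 4)
src = torch.randn(2, 3, 4)
self_t = torch.zeros(2, 3, 4)
expected = torch.scatter(self_t, 0, index, src)
got = scatter_from_outs(self_t.numpy().copy(),
                        scan_index_and_src(index.numpy(), src.numpy()))
assert np.allclose(expected.numpy(), got)
```

The cross-check deliberately avoids duplicate targets, since with duplicates it is unspecified which update wins in `torch.scatter`.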
@silvasean @ramiro050 Could you please take a look and guide me on this? cc: @powderluv
Hi @Shukla-Gaurav, thanks for the nice explanation! This seems like a good approach to me.
That seems reasonable.
Is this lowering completed so that the issue can be closed? I can see [TM_TENSOR] Add aten.scatter.[src|value] op #1499 is there.
@xgupta Let me address all the comments on the PR; I will wrap it up soon and close this one.