torch-mlir

Lowering `torch.scatter` op

Open Shukla-Gaurav opened this issue 2 years ago • 3 comments

This issue was created to figure out the best approach to lowering the `torch.scatter` op: https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

Shukla-Gaurav • Oct 11 '22 14:10

I think one possible approach is to go through the `tm_tensor` dialect, as follows:

  1. `torch.scatter(self, dim, index, src)` writes the values of `src` into `self` at the positions given by `index` along dimension `dim`. For a 3-D tensor, `self` is updated as:
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
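To make the update rule above concrete, here is a minimal pure-Python model of `torch.scatter` for 3-D inputs, with nested lists standing in for tensors. `scatter_3d_reference` is a hypothetical helper written for this sketch, not part of PyTorch or torch-mlir. It iterates over the shape of `index` (which PyTorch allows to be smaller than `src` and `self`) and replaces only the `dim` coordinate with the value read from `index`.

```python
import copy


def scatter_3d_reference(self_t, dim, index, src):
    """Hypothetical pure-Python model of torch.scatter(self, dim, index, src) for 3-D inputs.

    Tensors are nested lists. Iteration runs over the shape of `index`.
    """
    out = copy.deepcopy(self_t)  # model the out-of-place torch.scatter; self_t is untouched
    for i in range(len(index)):
        for j in range(len(index[i])):
            for k in range(len(index[i][j])):
                pos = [i, j, k]
                pos[dim] = index[i][j][k]  # only the `dim` coordinate comes from index
                out[pos[0]][pos[1]][pos[2]] = src[i][j][k]
    return out


# dim == 0: self[index[i][j][k]][j][k] = src[i][j][k]
print(scatter_3d_reference([[[0], [0]], [[0], [0]]], 0, [[[1], [0]]], [[[5], [7]]]))
# -> [[[0], [7]], [[5], [0]]]
```

If two positions in `index` collide, this model simply keeps the last write; do not assume PyTorch picks the same value in that case.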

The idea is to map `torch.scatter(self, dim, index, src)` to `tm_tensor.scatter(self, indices, updates)` in two steps:
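The mapping relies on `tm_tensor.scatter` consuming full coordinates. Below is a hypothetical pure-Python model of the semantics assumed here, not the dialect's actual definition: `indices` has one row per update, each row being a full coordinate into `self`; `updates` holds one scalar per row; and a `combine` callback (an illustrative stand-in for the op's region) decides the new value, overwriting by default.

```python
import copy


def tm_tensor_scatter_model(original, indices, updates, combine=lambda update, old: update):
    """Hypothetical model of a tm_tensor.scatter-style op (assumed semantics only).

    indices: N rows, each a full coordinate into `original`
    updates: N scalars, one per row of `indices`
    combine: stands in for the op's region; the default overwrites
    """
    out = copy.deepcopy(original)
    for coord, value in zip(indices, updates):
        # Walk down to the innermost list that holds the target element.
        target = out
        for c in coord[:-1]:
            target = target[c]
        target[coord[-1]] = combine(value, target[coord[-1]])
    return out


print(tm_tensor_scatter_model([[0, 0], [0, 0]], [[0, 1], [1, 0]], [5, 7]))
# -> [[0, 5], [7, 0]]
```

For plain `torch.scatter` the default overwrite is what is needed; a different `combine` would be one way to model accumulating variants.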

  2. Scan the `index` and `src` tensors with a `linalg.generic` op. (Let's assume `index` is a 3-D tensor, which means `self` and `src` must also be 3-D, as per https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_.) Assuming dim = 0, the op computes:
for i=0; i<index_shape[0]; i++
    for j=0; j<index_shape[1]; j++
        for k=0; k<index_shape[2]; k++
            n = (i * index_shape[1] + j) * index_shape[2] + k
            outs[0][n], outs[1][n], outs[2][n], outs[3][n] = index[i][j][k], j, k, src[i][j][k]

Here each of outs[0], outs[1], outs[2] and outs[3] is 1-D with size index_shape[0] * index_shape[1] * index_shape[2]. When creating the linalg.generic op, the indexing map for the outs is ((d0 * index_shape[1] + d1) * index_shape[2]) + d2, i.e. the row-major flattening of (d0, d1, d2).

  3. Then prepare `indices` and `updates` for the `tm_tensor.scatter` op:
  indices = concat(outs[0], outs[1], outs[2])   # each reshaped to N x 1 and concatenated along dim 1, giving an N x 3 tensor
  updates = outs[3]
  tm_tensor.scatter(self, indices, updates)
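The scan and preparation steps can be tied together in a small pure-Python sketch, with nested lists standing in for tensors and a plain loop standing in for an overwriting `tm_tensor.scatter`. `lower_scatter_3d` is a hypothetical name. The sketch generalizes the dim = 0 pseudocode to any `dim`, and it assumes `indices` must be an N x 3 table of coordinates, which is why the three coordinate vectors are stacked column-wise rather than joined end to end.

```python
import copy


def lower_scatter_3d(self_t, dim, index, src):
    """Hypothetical sketch of the torch.scatter -> tm_tensor.scatter mapping (3-D only)."""
    d0, d1, d2 = len(index), len(index[0]), len(index[0][0])
    n = d0 * d1 * d2
    outs = [[0] * n for _ in range(4)]  # outs[0..2]: coordinates, outs[3]: values

    # Step 2 (linalg.generic stand-in): element (i, j, k) lands at its row-major flat slot.
    for i in range(d0):
        for j in range(d1):
            for k in range(d2):
                flat = (i * d1 + j) * d2 + k
                coord = [i, j, k]
                coord[dim] = index[i][j][k]  # only the `dim` coordinate comes from index
                outs[0][flat], outs[1][flat], outs[2][flat] = coord
                outs[3][flat] = src[i][j][k]

    # Step 3: stack the coordinate vectors column-wise into N x 3 indices.
    indices = [list(row) for row in zip(outs[0], outs[1], outs[2])]
    updates = outs[3]

    # Overwriting scatter stand-in: self[indices[r]] = updates[r].
    out = copy.deepcopy(self_t)
    for (a, b, c), value in zip(indices, updates):
        out[a][b][c] = value
    return out, indices, updates


out, indices, updates = lower_scatter_3d([[[0], [0]], [[0], [0]]], 0, [[[1], [0]]], [[[5], [7]]])
print(indices, updates, out)
# -> [[1, 0, 0], [0, 1, 0]] [5, 7] [[[0], [7]], [[5], [0]]]
```

The result matches the dim = 0 rule self[index[i][j][k]][j][k] = src[i][j][k] stated at the top of the issue.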

@silvasean @ramiro050 Could you please take a look and give me some guidance on this? cc: @powderluv

Shukla-Gaurav • Oct 11 '22 16:10

Hi @Shukla-Gaurav, thanks for the nice explanation! This seems like a good approach to me.

ramiro050 • Oct 11 '22 18:10

That seems reasonable.

silvasean • Oct 12 '22 09:10

Is this lowering complete, so the issue can be closed? I can see that [TM_TENSOR] Add aten.scatter.[src|value] op #1499 is open.

xgupta • Dec 21 '22 11:12

@xgupta Let me address all the comments on the PR; I will wrap it up soon and close this one.

Shukla-Gaurav • Dec 21 '22 12:12