[RFC] AtenEmbeddingBagPaddingIdxOp
Hello, I implemented the logic for the "sum" mode, which required using linalg::GenericOp. After discussing this with @ramiro050 in office hours, we decided to open an RFC to make sure this is the right way to implement the op. Here is what a high-level Python implementation of the sum mode looks like:
import numpy as np

weight = [[1, 3, 5, 3],
          [3, 4, 2, 1],
          [2, 2, 3, 2],
          [0, 4, 2, 1]]
indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
offsets = [0, 3, 5]

embedding_size = len(weight[0])
indices_shape = len(indices)
offsets_shape = len(offsets)

# Append the number of indices as a final offset so that bag i always
# spans [offsets[i], offsets[i + 1]). This avoids additional control flow
# inside the loop.
offsets.append(indices_shape)

# Convert to numpy arrays.
weight = np.array(weight)
indices = np.array(indices)
offsets = np.array(offsets)
output_tensor = np.zeros((offsets_shape, embedding_size))

for i in range(offsets_shape):
    for j in range(indices_shape):
        # The bag-membership test does not depend on k, so check it once per (i, j).
        if offsets[i] <= j < offsets[i + 1]:
            for k in range(embedding_size):
                output_tensor[i][k] += weight[indices[j]][k]
I also have a WIP branch for AtenEmbeddingBagPaddingIdxOp, with the sum mode implemented using the GenericOp: https://github.com/llvm/torch-mlir/pull/1066
@silvasean, the main concern here is that for every element in the offsets tensor, the entire output_tensor gets iterated over. Do you know if there is a different approach, using something like the TMTensor dialect, that would give a more efficient implementation?
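To make the cost concern concrete, here is a minimal sketch of a reference that walks only the slice of indices belonging to each bag, so each index is visited once instead of once per offset. The function name `embedding_bag_sum` is hypothetical, and this is plain Python for illustration only; it is not the linalg lowering from the PR, and it ignores padding_idx and per-sample weights.

```python
def embedding_bag_sum(weight, indices, offsets):
    """Sum-mode embedding bag that visits each index once (illustrative sketch)."""
    embedding_size = len(weight[0])
    # Bag i covers indices[bounds[i]:bounds[i + 1]]; the appended end marker
    # closes the last bag, as in the loop nest above.
    bounds = list(offsets) + [len(indices)]
    output = []
    for i in range(len(offsets)):
        row = [0] * embedding_size
        for j in range(bounds[i], bounds[i + 1]):
            for k in range(embedding_size):
                row[k] += weight[indices[j]][k]
        output.append(row)
    return output


weight = [[1, 3, 5, 3],
          [3, 4, 2, 1],
          [2, 2, 3, 2],
          [0, 4, 2, 1]]
indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
offsets = [0, 3, 5]
result = embedding_bag_sum(weight, indices, offsets)
print(result)
```

The inner work is proportional to len(indices) * embedding_size, whereas the full loop nest is proportional to len(offsets) * len(indices) * embedding_size. The catch for a linalg-style lowering is that the inner loop bounds here depend on data (the offsets), which a fixed rectangular iteration space cannot express directly.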
> @silvasean, the main concern here is that for every element in the offsets tensor, the entire output_tensor gets iterated over. Do you know if there is a different approach, using something like the TMTensor dialect, that would give a more efficient implementation?
Not off the top of my head (we whiteboarded some stuff, but putting it into working code is beyond what I have in cache right now). Can you look into what ATen does to implement this op? Happy to dig in further if that doesn't yield anything.
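One possible restructuring, sketched here as an assumption rather than a description of ATen's actual kernel or of the final torch-mlir lowering, is to first expand the offsets into a per-index bag id and then do a single gather-and-accumulate pass over the indices. Both function names below are hypothetical.

```python
def offsets_to_bag_ids(offsets, num_indices):
    """Expand bag start offsets into one bag id per index position."""
    bounds = list(offsets) + [num_indices]
    bag_ids = [0] * num_indices
    for bag in range(len(offsets)):
        for j in range(bounds[bag], bounds[bag + 1]):
            bag_ids[j] = bag
    return bag_ids


def embedding_bag_sum_scatter(weight, indices, offsets):
    """Sum mode as gather (weight[indices[j]]) plus scatter-add into bag_ids[j]."""
    bag_ids = offsets_to_bag_ids(offsets, len(indices))
    output = [[0] * len(weight[0]) for _ in offsets]
    for j, row_index in enumerate(indices):
        for k, value in enumerate(weight[row_index]):
            output[bag_ids[j]][k] += value
    return output


weight = [[1, 3, 5, 3],
          [3, 4, 2, 1],
          [2, 2, 3, 2],
          [0, 4, 2, 1]]
indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
offsets = [0, 3, 5]
bag_ids = offsets_to_bag_ids(offsets, len(indices))
scatter_result = embedding_bag_sum_scatter(weight, indices, offsets)
print(bag_ids)
print(scatter_result)
```

In this form the main loop has a fixed iteration space of len(indices) by embedding_size, and the data dependence moves into the output row being written, which is the shape of computation a scatter-style op is meant for.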
This appears to have been implemented in b70548edff7f55d5fa8335bf5af67f3d1ba0ba8f and ed13ebfd8dae0d311208eea89826ee77e29271cc.