
IdEmbedding does not support PyTorchEngine

Open dxjjhm opened this issue 4 years ago • 3 comments

```java
IdEmbedding userEmbedding = new IdEmbedding.Builder()
        .setDictionarySize(userCount)
        .setEmbeddingSize(64)
        .build();
userEmbedding.initialize(manager, DataType.FLOAT32, usersND.getShape());

userEmbedding.forward(ps, new NDList(manager.create(new int[] {1, 2, 3, 4})), true)
        .singletonOrThrow();
```

```
Exception in thread "main" java.lang.UnsupportedOperationException: Not supported!
	at ai.djl.ndarray.BaseNDManager.invoke(BaseNDManager.java:285)
	at ai.djl.nn.transformer.MissingOps.gatherNd(MissingOps.java:31)
	at ai.djl.nn.transformer.IdEmbedding.forwardInternal(IdEmbedding.java:72)
	at ai.djl.nn.AbstractBlock.forward(AbstractBlock.java:121)
	at ai.djl.nn.Block.forward(Block.java:122)
```

dxjjhm avatar Sep 06 '21 10:09 dxjjhm

@dxjjhm Thanks for reporting this issue. Currently DJL has only implemented `gather` for the MXNet engine. We will prioritize adding it for PyTorch.

frankfliu avatar Sep 06 '21 15:09 frankfliu
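For context, the operation `IdEmbedding` ultimately needs here is just a batched row lookup in the embedding table: each input id selects one row. A minimal plain-Java sketch of those semantics (illustrative only; the class and method names are invented for this demo and are not the DJL API):

```java
import java.util.Arrays;

// Plain-Java sketch of what an embedding lookup (a row gather) computes.
// `table` has shape [dictionarySize][embeddingSize]; `ids` is the input batch.
public class EmbeddingLookupDemo {
    static float[][] lookup(float[][] table, int[] ids) {
        float[][] out = new float[ids.length][];
        for (int i = 0; i < ids.length; i++) {
            out[i] = table[ids[i]].clone(); // gather row ids[i]
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] table = {{0f, 0f}, {1f, 10f}, {2f, 20f}, {3f, 30f}};
        float[][] batch = lookup(table, new int[] {1, 3});
        System.out.println(Arrays.deepToString(batch)); // [[1.0, 10.0], [3.0, 30.0]]
    }
}
```

This is the computation that `MissingOps.gatherNd` delegates to the engine; without an engine-level `gather`, the PyTorch engine throws the `UnsupportedOperationException` shown above.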

Is there a way to get a batch of embeddings from a batch of input ids?

dxjjhm avatar Sep 07 '21 01:09 dxjjhm

Using the `NDArrayEx` interface, I have already found a solution, as follows:

```java
NDArray inputUserEmbedding = userInput.getNDArrayInternal()
        .embedding(userInput, userEmbeddingTable, SparseFormat.DENSE)
        .singletonOrThrow();
```

dxjjhm avatar Sep 17 '21 01:09 dxjjhm

`gather` has been implemented for the PyTorch engine.

frankfliu avatar Dec 27 '22 20:12 frankfliu
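For reference, a `gather` along axis 0 follows the familiar `torch.gather`-style rule `out[i][j] = input[index[i][j]][j]`. A plain-Java sketch of those semantics (illustrative only; this is a standalone demo, not the DJL implementation):

```java
import java.util.Arrays;

// Plain-Java sketch of gather-along-axis-0 semantics (cf. torch.gather):
// out[i][j] = input[index[i][j]][j].
public class GatherDemo {
    static float[][] gatherAxis0(float[][] input, int[][] index) {
        int rows = index.length;
        int cols = index[0].length;
        float[][] out = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // For each output cell, pick the row given by index[i][j],
                // keeping the column j fixed.
                out[i][j] = input[index[i][j]][j];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] in = {{1f, 2f}, {3f, 4f}};
        int[][] idx = {{1, 0}, {0, 1}};
        System.out.println(Arrays.deepToString(gatherAxis0(in, idx))); // [[3.0, 2.0], [1.0, 4.0]]
    }
}
```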