Pick index not working for multidimensional arrays
Description
NDArray#get fails when given a pick index for a multidimensional array. Executing this code:
NDArray target = manager.arange(6).reshape(3, 2);
NDArray index = manager.create(new long[] {0, 2});
NDArray result = target.get(new NDIndex().addPickDim(index));
Expected Behavior
Expected an NDArray of shape (2, 2), as described in the Javadoc for NDIndex#addPickDim: [[0, 1], [4, 5]]
Error Message
java.lang.IllegalArgumentException: expand shape failed! Cannot expand from (2)to (3, 2)
at ai.djl.pytorch.jni.JniUtils.pick(JniUtils.java:618)
at ai.djl.pytorch.jni.JniUtils.indexAdv(JniUtils.java:464)
at ai.djl.pytorch.engine.PtNDArrayIndexer.get(PtNDArrayIndexer.java:74)
at ai.djl.ndarray.NDArray.get(NDArray.java:523)
at ai.djl.ndarray.NDArray.get(NDArray.java:512)
No, the definition of addPickDim is aligned with MXNet's pick operator: https://mxnet.apache.org/versions/1.6/api/r/docs/api/mx.nd.pick.html. So the output of the following code
NDArray target = manager.arange(6).reshape(3, 2);
NDArray pickIndex = manager.create(new long[] {0, 2}, new Shape(1, 2));
NDArray result = target.get(new NDIndex().addPickDim(pickIndex));
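should be [[0, 5]]: each column j takes the element in row pickIndex[0][j], so the picked values are target[0][0] = 0 and target[2][1] = 5. This feature is not often used, though.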
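To make the expected [[0, 5]] concrete, here is a minimal plain-Java sketch of a pick along axis 0 on nested arrays, without DJL. It assumes the axis-0 behavior implied by the output above; `PickAxis0Sketch` and `pickAxis0` are hypothetical names for illustration, not DJL APIs.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of a pick along axis 0 (not a DJL API):
// result[i][j] = target[index[i][j]][j], so each column j takes the
// element from the row named by the index.
final class PickAxis0Sketch {
    static int[][] pickAxis0(int[][] target, int[][] index) {
        int[][] result = new int[index.length][];
        for (int i = 0; i < index.length; i++) {
            // Assumes each index row is no wider than the target's rows.
            result[i] = new int[index[i].length];
            for (int j = 0; j < index[i].length; j++) {
                result[i][j] = target[index[i][j]][j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Same values as manager.arange(6).reshape(3, 2)
        int[][] target = {{0, 1}, {2, 3}, {4, 5}};
        // Same values as the (1, 2) pick index
        int[][] pickIndex = {{0, 2}};
        // prints [[0, 5]]
        System.out.println(Arrays.deepToString(pickAxis0(target, pickIndex)));
    }
}
```

The result has one row per index row, so a taller index such as {{2, 0}, {1, 1}} would give [[4, 1], [2, 3]] under this model.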
To get [[0, 1], [4, 5]], you need array indexing instead:
NDArray index = manager.create(new long[] {0, 2});
NDArray ret = target.get(index);
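For comparison, the row selection done by target.get(index) can be modeled with a similar plain-Java sketch. It assumes a 1-D index selects whole rows along axis 0, which matches the [[0, 1], [4, 5]] result above; `RowIndexSketch` and `takeRows` are hypothetical names for illustration, not DJL APIs.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of array indexing with a 1-D index
// (not a DJL API): result[k] = target[index[k]], copying whole rows.
final class RowIndexSketch {
    static int[][] takeRows(int[][] target, long[] index) {
        int[][] result = new int[index.length][];
        for (int k = 0; k < index.length; k++) {
            // Copy the row so the result does not alias the target.
            result[k] = target[Math.toIntExact(index[k])].clone();
        }
        return result;
    }

    public static void main(String[] args) {
        // Same values as manager.arange(6).reshape(3, 2)
        int[][] target = {{0, 1}, {2, 3}, {4, 5}};
        // Same values as manager.create(new long[] {0, 2})
        long[] index = {0, 2};
        // prints [[0, 1], [4, 5]]
        System.out.println(Arrays.deepToString(takeRows(target, index)));
    }
}
```

Unlike the pick, each index entry here selects an entire row, which is why the result keeps the target's full width of 2.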
For related examples, see the NDIndex tests: https://github.com/deepjavalibrary/djl/blob/866be61a0cd8a75b98a23efef9dbf6cf13fac910/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java#L153-L161