NDArrays.concat behaves differently from np.concatenate
Description
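Calling NDArrays.concat on an NDList that holds a single 3-D array returns the array with its shape unchanged, whereas np.concatenate on the same 3-D array in Python joins its slices along the first axis.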
Expected Behavior
I expected NDArrays.concat to return an array of shape (20, 5), the same shape np.concatenate returns in Python.
Error Message
No errors, but here are the outputs:
Python:
(2, 10, 5)
(20, 5)
Java:
[(2, 10, 5)]
(2, 10, 5)
How to Reproduce?
Python code:
import numpy as np
arr = np.arange(100).reshape((2, 10, 5))
print(arr.shape)
print(np.concatenate(arr).shape)
Java code:
NDManager manager = NDManager.newBaseManager();
NDList arr = new NDList(manager.arange(100).reshape(new Shape(2, 10, 5)));
System.out.println(Arrays.toString(arr.getShapes()));
System.out.println(NDArrays.concat(arr).getShape());
Steps to reproduce
- Run the Python code
- Run the Java code
What have you tried to solve it?
I did not find another method in Java to get the same result as in Python.
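For reference, my reading of the Python side is that np.concatenate iterates over its first argument, so a single 3-D array is treated as a sequence of its two (10, 5) slices. The sketch below checks that reading with NumPy only; the variable names (implicit, explicit, reshaped, single) are just illustrative and do not come from either library.

```python
import numpy as np

arr = np.arange(100).reshape((2, 10, 5))

# A bare 3-D array is iterated along axis 0, so this joins arr[0] and arr[1].
implicit = np.concatenate(arr)

# The same call with the two slices spelled out explicitly.
explicit = np.concatenate([arr[0], arr[1]])

# Merging the first two axes with a reshape gives the same values.
reshaped = arr.reshape(-1, arr.shape[-1])

# Wrapping the array in a one-element list keeps the original shape.
single = np.concatenate([arr])

print(implicit.shape)                       # (20, 5)
print(np.array_equal(implicit, explicit))   # True
print(np.array_equal(implicit, reshaped))   # True
print(single.shape)                         # (2, 10, 5)
```

If this reading is right, the Java snippet above corresponds to the one-element-list case, so the DJL equivalent of the Python result would presumably need an NDList with one entry per slice, or a reshape, rather than a list holding the whole array. I have not confirmed that on the Java side.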
Environment Info
djl 0.26.0
pytorch 1.13.1