
NDArrays.concat behaves differently from np.concatenate

Status: Open. Opened by vm3538.

Description

Concatenating an NDList with NDArrays.concat does not produce the same result as np.concatenate in Python.

Expected Behavior

I expected NDArrays.concat to return an array with the same shape as the Python result, (20, 5).

Error Message

No errors are raised, but the outputs differ.

Python:

(2, 10, 5)
(20, 5)

Java:

[(2, 10, 5)]
(2, 10, 5)

How to Reproduce?

Python code:

import numpy as np

arr = np.arange(100).reshape((2, 10, 5))

print(arr.shape)
print(np.concatenate(arr).shape)

Java code:

import ai.djl.ndarray.*;
import ai.djl.ndarray.types.Shape;
import java.util.Arrays;

NDManager manager = NDManager.newBaseManager();
NDList arr = new NDList(manager.arange(100).reshape(new Shape(2, 10, 5)));  // one (2, 10, 5) array
System.out.println(Arrays.toString(arr.getShapes()));
System.out.println(NDArrays.concat(arr).getShape());
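The two snippets most likely differ in what is being concatenated, not in NDArrays.concat itself. np.concatenate(arr) treats the 3-D array as a sequence of its two (10, 5) slices along axis 0 and joins them. The Java snippet instead passes an NDList that holds a single (2, 10, 5) array, which corresponds to np.concatenate([arr]) in NumPy. The following is a small NumPy sketch that compares the two calls. The variable names are illustrative only:

```python
import numpy as np

arr = np.arange(100).reshape((2, 10, 5))

# Passing the 3-D array itself: NumPy iterates over axis 0, so the
# "sequence" is two (10, 5) slices, joined along axis 0.
from_array = np.concatenate(arr)

# Passing a one-element list, which mirrors an NDList holding one array:
# there is nothing to join, so the shape stays the same.
from_list = np.concatenate([arr])

print(from_array.shape)  # (20, 5)
print(from_list.shape)   # (2, 10, 5)
```

The second shape matches the Java output. This suggests that NDArrays.concat is consistent with np.concatenate when both receive a list of arrays.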

Steps to reproduce

  1. Run the Python code
  2. Run the Java code

What have you tried to solve it?

I did not find another method in Java to get the same result as in Python.
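For a 3-D input, np.concatenate(arr) along the default axis 0 is equivalent to joining the axis-0 slices explicitly. It is also equivalent to merging the first two axes with a reshape. The NumPy sketch below checks both forms against the original call.

In DJL, the analogous calls would presumably be one of the following:

- NDArray.reshape with a Shape of (-1, 5).
- NDArray.split(2, 0), then squeeze(0) on each piece, then NDArrays.concat(list, 0).

Those DJL calls are assumptions, and this sketch does not exercise them.

```python
import numpy as np

arr = np.arange(100).reshape((2, 10, 5))
expected = np.concatenate(arr)  # the Python result from this issue, shape (20, 5)

# Form 1: list the axis-0 slices explicitly, then join them along axis 0.
# Each arr[i] drops the leading axis, so the pieces are (10, 5), not (1, 10, 5).
slices = [arr[i] for i in range(arr.shape[0])]
via_slices = np.concatenate(slices, axis=0)

# Form 2: merge the first two axes; -1 lets NumPy infer 2 * 10 = 20.
via_reshape = arr.reshape(-1, arr.shape[-1])

print(via_slices.shape, np.array_equal(via_slices, expected))    # (20, 5) True
print(via_reshape.shape, np.array_equal(via_reshape, expected))  # (20, 5) True
```

The reshape form avoids building intermediate slices. The slice form mirrors the "list of arrays" semantics of np.concatenate more literally.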

Environment Info

DJL 0.26.0, PyTorch 1.13.1

vm3538, Feb 21 '24 15:02