[pytorch] optimize memory copy cost for pytorch NDArray
Description
Currently, the public double[] toDoubleArray() method in PyTorch's PtNDArray copies the contents of native memory into heap memory, and then copies them again from heap memory into the array object. This introduces extra copy overhead; the data could instead be copied directly from native memory into the array.
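The extra hop can be illustrated with plain java.nio, independent of DJL. This is a hypothetical sketch, not DJL code: a direct ByteBuffer stands in for the tensor's native storage, and the two paths (native → heap buffer → array vs. native → array) are compared.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

public class CopyCostDemo {
    public static void main(String[] args) {
        // Pretend this direct buffer is the tensor's native storage.
        ByteBuffer nativeBuf = ByteBuffer.allocateDirect(4 * Double.BYTES)
                .order(ByteOrder.nativeOrder());
        for (int i = 0; i < 4; i++) {
            nativeBuf.putDouble(i * 1.5);
        }
        nativeBuf.rewind();

        // Two-copy path: native -> heap ByteBuffer -> double[]
        ByteBuffer heapBuf = ByteBuffer.allocate(nativeBuf.capacity())
                .order(ByteOrder.nativeOrder());
        heapBuf.put(nativeBuf.duplicate()); // copy 1
        heapBuf.rewind();
        double[] viaHeap = new double[4];
        heapBuf.asDoubleBuffer().get(viaHeap); // copy 2

        // One-copy path: read straight from the direct buffer
        double[] direct = new double[4];
        nativeBuf.asDoubleBuffer().get(direct); // single copy

        System.out.println(Arrays.equals(viaHeap, direct)); // true
    }
}
```

Both paths produce the same array; the optimization discussed in this PR is about eliminating the intermediate heap copy.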
Thank you for your response. What you mean is to make ai.djl.pytorch.engine.PtNDArray#toByteBuffer return a DirectBuffer. I had considered this idea initially, but the framework currently assumes it returns a non-direct buffer, and changing that directly would cause issues. For example, in the function ai.djl.ndarray.NDArray#copyTo,
default void copyTo(NDArray array) {
    array.set(toByteBuffer());
}
which calls ai.djl.pytorch.engine.PtNDArray#toByteBuffer to convert to a Buffer, the framework assumes the result is non-direct memory. Therefore, in ai.djl.pytorch.engine.PtNDArray#set:
@Override
public void set(Buffer buffer) {
    int size = Math.toIntExact(size());
    DataType type = getDataType();
    BaseNDManager.validateBuffer(buffer, type, size);
    // TODO how do we handle the exception happened in the middle
    dataRef = null;
    if (buffer.isDirect() && buffer instanceof ByteBuffer) {
        // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
        if (!getDevice().isGpu()) {
            dataRef = (ByteBuffer) buffer;
        }
        JniUtils.set(this, (ByteBuffer) buffer);
        return;
    }
    // int8, uint8, boolean use ByteBuffer, so need to explicitly input DataType
    ByteBuffer buf = manager.allocateDirect(size * type.getNumOfBytes());
    BaseNDManager.copyBuffer(buffer, buf);
    // If NDArray is on the GPU, it is native code responsibility to control the data life cycle
    if (!getDevice().isGpu()) {
        dataRef = buf;
    }
    JniUtils.set(this, buf);
}
Because set() detects that the buffer is not direct, it allocates new direct memory and copies the data, achieving a deep copy. If ai.djl.pytorch.engine.PtNDArray#toByteBuffer returned direct memory instead, the two arrays would end up sharing the same memory. Ideally, a toDirectByteBuffer method should be added to ai.djl.ndarray.NDArray so that direct and non-direct buffers can be distinguished. However, that would involve modifications to multiple engines, such as the ONNX Runtime engine. Although ONNX Runtime's underlying layer provides ai.onnxruntime.OnnxTensor#getBuffer() for obtaining direct memory, it is a private method and cannot be used directly. Therefore, at this stage I have only adjusted PyTorch.
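The deep-copy branch of set() boils down to a java.nio pattern: a non-direct buffer is copied into a freshly allocated direct buffer, so the NDArray never aliases the caller's memory. A standalone sketch of that pattern (the method name toOwnedDirect is hypothetical, not a DJL API):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class DeepCopyDemo {
    // Mimics the non-direct branch of PtNDArray#set: hand native code a
    // direct buffer that the array owns exclusively.
    static ByteBuffer toOwnedDirect(ByteBuffer src) {
        if (src.isDirect()) {
            return src; // direct input is passed through (shared memory!)
        }
        ByteBuffer copy = ByteBuffer.allocateDirect(src.remaining())
                .order(ByteOrder.nativeOrder());
        copy.put(src.duplicate()); // deep copy into owned direct memory
        copy.rewind();
        return copy;
    }

    public static void main(String[] args) {
        ByteBuffer heap = ByteBuffer.wrap(new byte[] {1, 2, 3, 4});
        ByteBuffer owned = toOwnedDirect(heap);
        heap.put(0, (byte) 9); // mutate the caller's buffer afterwards
        System.out.println(owned.get(0)); // 1: the copy is unaffected
        System.out.println(owned.isDirect()); // true
    }
}
```

Note the direct-input branch passes the buffer through unchanged, which is exactly the sharing behavior the discussion below worries about.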
Regarding the byte order issue you mentioned, I observed that the ONNX Runtime engine sets the byte order at the Java level:
/*
* Class: ai_onnxruntime_OnnxTensor
* Method: getBuffer
* Signature: (JJ)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
    (JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
  (void)jobj;  // Required JNI parameter not needed by functions that don't need to access their host object.
  const OrtApi* api = (const OrtApi*)apiHandle;
  OrtValue* ortValue = (OrtValue*)handle;
  JavaTensorTypeShape typeShape;
  OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
  if (code == ORT_OK) {
    size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum);
    size_t sizeBytes = typeShape.elementCount * typeSize;
    uint8_t* arr = NULL;
    code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
    if (code == ORT_OK) {
      return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
    }
  }
  return NULL;
}
private ByteBuffer getBuffer() {
    return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
}

private native ByteBuffer getBuffer(long apiHandle, long nativeHandle);
Following that approach, I also set the byte order at the Java level:
public static ByteBuffer getDirectByteBuffer(PtNDArray ndArray) {
    // Operation is CPU only
    if (!ndArray.getDevice().equals(Device.cpu())) {
        ndArray = ndArray.toDevice(Device.cpu(), false);
    }
    return PyTorchLibrary.LIB.torchDirectByteBuffer(ndArray.getHandle())
            .order(ByteOrder.nativeOrder());
}
Moreover, I am not aware of any API that can set the byte order at the time of NewDirectByteBuffer.
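The reason the order(ByteOrder.nativeOrder()) call is needed on the Java side is that every newly created ByteBuffer, including one returned by JNI's NewDirectByteBuffer, starts out big-endian by default. A quick standalone check:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class ByteOrderDemo {
    public static void main(String[] args) {
        // A fresh ByteBuffer (heap or direct) defaults to big-endian,
        // regardless of the platform's native byte order.
        ByteBuffer buf = ByteBuffer.allocateDirect(4);
        System.out.println(buf.order()); // BIG_ENDIAN

        // After the explicit fix-up, reads interpret bytes natively.
        buf.order(ByteOrder.nativeOrder());
        System.out.println(buf.order().equals(ByteOrder.nativeOrder())); // true
    }
}
```

On a little-endian machine, forgetting this call would make multi-byte reads (getInt, getDouble) decode the tensor's bytes in the wrong order.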
@ewan0x79
I tried to fix the integration test failure, but I could not make it work. I don't think this solution can really work: when GC kicks in and deletes the DirectBuffer, in certain cases the native memory gets trashed. It can cause data corruption in multiple places.
@frankfliu
You're right, there was indeed a problem with my previous approach.
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer(
    JNIEnv* env, jobject jthis, jlong jhandle)
  API_BEGIN()
  const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
  // sparse and mkldnn are required to be converted to dense to access data ptr
  auto tensor = (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn()) ? tensor_ptr->to_dense() : *tensor_ptr;
  tensor = (tensor.is_contiguous()) ? tensor : tensor.contiguous();
  size_t nbytes = tensor.nbytes();
  if (nbytes > 0x7fffffff) {
    env->ThrowNew(ENGINE_EXCEPTION_CLASS, "toDirectByteBuffer() is not supported for large tensor");
    return nullptr;
  }
  // Use tensor.data_ptr() to get the data pointer and create a direct ByteBuffer with NewDirectByteBuffer
  void* data_ptr = tensor.data_ptr();
  jobject directBuffer = env->NewDirectByteBuffer(data_ptr, nbytes);
  return directBuffer;
  API_END_RETURN()
Previously, I overlooked the fact that tensor.contiguous() returns a new tensor. When a new tensor is returned, the local variable tensor on the stack points to that new tensor in native memory. The intention of jobject directBuffer = env->NewDirectByteBuffer(data_ptr, nbytes); is to let the ByteBuffer hold the address of the tensor's data. However, when the native method returns, the local variables on the stack are reclaimed, so the new tensor in native memory is destroyed as well (this part is handled by C++), leaving the ByteBuffer pointing to an invalid address. Calling this method could therefore cause issues.
Now, I have made some modifications. Here's the updated code snippet:
public double[] toDoubleArray() {
    if (getDataType() != DataType.FLOAT64) {
        throw new IllegalStateException(
                "DataType mismatch, Required double" + " Actual " + getDataType());
    }
    if (isSparse() || JniUtils.getLayout(this) == 2 || !isContiguous()) {
        try (final PtNDArray ptNDArray = toContiguous()) {
            return toDoubleArray(ptNDArray);
        }
    } else {
        return toDoubleArray(this);
    }
}
We need to determine whether the tensor is contiguous. If it is, the ByteBuffer can hold the native address of the data directly, as before. If it is not, we construct a new contiguous tensor (moving the non-contiguous data into contiguous memory) and let the ByteBuffer hold that tensor's native address; the new tensor is destroyed only after the data has been copied from native memory into the array.
@frankfliu
Your modifications still have certain issues.
- Firstly, in the function
JNIEXPORT jobject JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchDirectByteBuffer,
void* data_ptr = tensor.data_ptr();
if (tensor_ptr->is_sparse() || tensor_ptr->is_mkldnn() || !tensor_ptr->is_contiguous()) {
    // We have to make a copy anyway
    void* data = new char[nbytes];
    data_ptr = std::memcpy(data, data_ptr, nbytes);
}
return env->NewDirectByteBuffer(data_ptr, nbytes);
you copy non-contiguous memory into a newly allocated byte array and let the ByteBuffer hold its address. However, the JVM does not free the native memory backing a ByteBuffer created this way when it collects the ByteBuffer itself, so this leaks native memory.
- Secondly, letting getByteBuffer return direct memory is not a problem for now, but if similar modifications are applied to other engines in the future, issues may arise. For example, in the ONNX Runtime engine, the method ai.djl.ndarray.NDArrayAdapter#getAlternativeArray assumes the existence of both the PyTorch and ONNX Runtime engines.
private NDArray getAlternativeArray() {
    if (alternativeManager == null) {
        throw new UnsupportedOperationException(UNSUPPORTED_MSG);
    }
    if (alternativeArray == null) {
        alternativeArray = alternativeManager.from(this);
    } else {
        alternativeArray.set(getDataType().asDataType(toByteBuffer()));
    }
    return alternativeArray;
}
In this case, it looks for the PyTorch engine and executes alternativeArray = alternativeManager.from(this). In PyTorch's ai.djl.pytorch.engine.PtNDManager#from,
@Override
public PtNDArray from(NDArray array) {
    if (array == null || array instanceof PtNDArray) {
        return (PtNDArray) array;
    }
    PtNDArray result = create(array.toByteBuffer(), array.getShape(), array.getDataType());
    result.setName(array.getName());
    return result;
}

@Override
public PtNDArray create(Buffer data, Shape shape, DataType dataType) {
    int size = Math.toIntExact(shape.size());
    BaseNDManager.validateBuffer(data, dataType, size);
    if (data.isDirect() && data instanceof ByteBuffer) {
        return JniUtils.createNdFromByteBuffer(
                this, (ByteBuffer) data, shape, dataType, SparseFormat.DENSE, device);
    }
    ByteBuffer buf = allocateDirect(size * dataType.getNumOfBytes());
    copyBuffer(data, buf);
    return JniUtils.createNdFromByteBuffer(
            this, buf, shape, dataType, SparseFormat.DENSE, device);
}
create checks whether the Buffer is direct memory. If ONNX Runtime also returned direct memory from getByteBuffer, two NDArrays might end up sharing the same memory space. I am not sure whether their underlying formats are identical (operations on one tensor might change the underlying data; for example, one could be released while the other is not, or one could change the underlying data format without the other knowing). This seems to pose a significant risk. Of course, this could also be handled by overriding ai.djl.pytorch.engine.PtNDManager#create(java.nio.Buffer, ai.djl.ndarray.types.Shape, ai.djl.ndarray.types.DataType), but that might not be easy.
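The aliasing hazard described here can be reproduced with nothing but java.nio: two views over the same direct buffer see each other's writes, which is exactly what would happen if two NDArrays were created from one direct ByteBuffer without a copy. A hypothetical illustration:

```java
import java.nio.ByteBuffer;

public class SharedMemoryDemo {
    public static void main(String[] args) {
        ByteBuffer storage = ByteBuffer.allocateDirect(4);
        // Two "NDArrays" built from the same direct buffer without copying:
        ByteBuffer arrayA = storage.duplicate();
        ByteBuffer arrayB = storage.duplicate();

        arrayA.put(0, (byte) 42); // write through one view
        System.out.println(arrayB.get(0)); // 42: visible through the other view
    }
}
```

The non-direct branch of create avoids this by copying into freshly allocated direct memory, which is why returning direct buffers from getByteBuffer silently changes ownership semantics.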
@ewan0x79
You are right, my changes would cause a memory leak. I will revert my part.
Codecov Report
Attention: Patch coverage is 79.16667%, with 5 lines in your changes missing coverage. Please review.
Project coverage is 68.43%. Comparing base (6efe660) to head (9c9592a). Report is 220 commits behind head on master.
| Files | Patch % | Lines |
|---|---|---|
| api/src/main/java/ai/djl/ndarray/NDArray.java | 72.72% | 3 Missing :warning: |
| ...ine/src/main/java/ai/djl/pytorch/jni/JniUtils.java | 80.00% | 0 Missing and 2 partials :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## master #3137 +/- ##
============================================
- Coverage 71.03% 68.43% -2.60%
+ Complexity 7199 7031 -168
============================================
Files 694 697 +3
Lines 32614 32765 +151
Branches 3374 3409 +35
============================================
- Hits 23166 22423 -743
- Misses 7842 8732 +890
- Partials 1606 1610 +4