
Segfault from Dataset.prepare() on MPS device

Open Ubehebe opened this issue 2 years ago • 1 comment

Description

I am trying to combine the MNIST tutorial with the directions in #2037 to train a model on the MPS PyTorch backend. I have gotten the MPSTest in that PR to pass on my M2 machine: it successfully creates an NDArray on the MPS device.
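
For reference, the sanity check I ran is roughly this (a sketch in the spirit of the MPSTest from that PR, not the exact test code):

import ai.djl.Device
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.Shape

fun main() {
  // Sketch only: create an NDArray directly on the MPS device and confirm where it lives.
  NDManager.newBaseManager(Device.fromName("mps")).use { manager ->
    val array = manager.ones(Shape(2L, 2L)) // allocated on MPS
    println(array.device)                   // should report the MPS device
  }
}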

Expected Behavior

I expected the MNIST dataset to prepare successfully on the MPS device, so that I could create a model from it. (Please let me know if this is unreasonable, or if I have to prepare the dataset on the CPU and then use some other API to load it onto MPS.)
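
If preparing on the CPU and then copying over is the intended path, this is roughly what I had in mind (untested sketch; the toDevice call is my assumption about how to move data, not something I have confirmed):

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun main() {
  // Untested sketch: prepare the dataset on a CPU manager instead of MPS.
  val cpuManager = NDManager.newBaseManager(Device.cpu())
  val mnist = Mnist.builder()
      .optManager(cpuManager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
  mnist.prepare() // no MPS involved here
  // Then, per batch, move the data to MPS before training (assumption, not verified):
  // val data = batch.data.head().toDevice(Device.fromName("mps"), true)
}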

Error Message

#
# A fatal error has been detected by the Java Runtime Environment:
#
#  SIGSEGV (0xb) at pc=0x0000000197d43f00, pid=94738, tid=8707
#
# JRE version: OpenJDK Runtime Environment Zulu17.32+13-CA (17.0.2+8) (build 17.0.2+8-LTS)
# Java VM: OpenJDK 64-Bit Server VM Zulu17.32+13-CA (17.0.2+8-LTS, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, bsd-aarch64)
# Problematic frame:
# C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
#
# No core dump will be written. Core dumps have been disabled. To enable core dumping, try "ulimit -c unlimited" before starting Java again
#
# If you would like to submit a bug report, please visit:
#   http://www.azul.com/support/
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#

---------------  S U M M A R Y ------------

Command Line: ml.djl.DjlSampleKt

Host: "Mac14,5" arm64 1 MHz, 12 cores, 64G, Darwin 22.4.0, macOS 13.3 (22E252)
Time: Wed Apr  5 07:23:05 2023 PDT elapsed time: 0.326062 seconds (0d 0h 0m 0s)

---------------  T H R E A D  ---------------

Current thread (0x000000014500dc00):  JavaThread "main" [_thread_in_native, id=8707, stack(0x000000016b62c000,0x000000016b82f000)]

Stack: [0x000000016b62c000,0x000000016b82f000],  sp=0x000000016b82d7a0,  free space=2053k
Native frames: (J=compiled Java code, j=interpreted, Vv=VM code, C=native code)
C  [libobjc.A.dylib+0x7f00]  objc_retain+0x10
C  [libtorch_cpu.dylib+0x4458830]  at::native::mps::copy_cast_mps(at::Tensor&, at::Tensor const&, id<MTLBuffer>, objc_object<MTLBuffer>, bool)+0x2ec
C  [libtorch_cpu.dylib+0x445a958]  at::native::mps::mps_copy_(at::Tensor&, at::Tensor const&, bool)+0x1e10
C  [libtorch_cpu.dylib+0x49b6d8]  at::native::copy_impl(at::Tensor&, at::Tensor const&, bool)+0x5cc
C  [libtorch_cpu.dylib+0x49b04c]  at::native::copy_(at::Tensor&, at::Tensor const&, bool)+0x64
C  [libtorch_cpu.dylib+0x10d3f14]  at::_ops::copy_::call(at::Tensor&, at::Tensor const&, bool)+0x120
C  [libtorch_cpu.dylib+0x7c0f0c]  at::native::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbd8
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0xc64120]  at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0xbc
C  [libtorch_cpu.dylib+0x280348c]  c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>), &torch::autograd::VariableType::(anonymous namespace)::_to_copy(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>>>, at::Tensor (c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x448
C  [libtorch_cpu.dylib+0xc63de8]  at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)+0x154
C  [libtorch_cpu.dylib+0xe1bed0]  at::_ops::to_device::call(at::Tensor const&, c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>)+0x140
C  [0.21.0-libdjl_torch.dylib+0x78974]  at::Tensor::to(c10::Device, c10::ScalarType, bool, bool, c10::optional<c10::MemoryFormat>) const+0x8c
C  [0.21.0-libdjl_torch.dylib+0x786c4]  Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTo+0xb4
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub
V  [libjvm.dylib+0x46b270]  JavaCalls::call_helper(JavaValue*, methodHandle const&, JavaCallArguments*, JavaThread*)+0x38c
V  [libjvm.dylib+0x4cfa64]  jni_invoke_static(JNIEnv_*, JavaValue*, _jobject*, JNICallType, _jmethodID*, JNI_ArgumentPusher*, JavaThread*)+0x12c
V  [libjvm.dylib+0x4d30f8]  jni_CallStaticVoidMethod+0x130
C  [libjli.dylib+0x5378]  JavaMain+0x9d4
C  [libjli.dylib+0x76e8]  ThreadJavaMain+0xc
C  [libsystem_pthread.dylib+0x6fa8]  _pthread_start+0x94

Java frames: (J=compiled Java code, j=interpreted, Vv=VM code)
j  ai.djl.pytorch.jni.PyTorchLibrary.torchTo(JI[I)J+0
j  ai.djl.pytorch.jni.JniUtils.to(Lai/djl/pytorch/engine/PtNDArray;Lai/djl/ndarray/types/DataType;Lai/djl/Device;)Lai/djl/pytorch/engine/PtNDArray;+61
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/pytorch/engine/PtNDArray;+23
j  ai.djl.pytorch.engine.PtNDArray.toType(Lai/djl/ndarray/types/DataType;Z)Lai/djl/ndarray/NDArray;+3
j  ai.djl.basicdataset.cv.classification.Mnist.readLabel(Lai/djl/repository/Artifact$Item;)Lai/djl/ndarray/NDArray;+88
j  ai.djl.basicdataset.cv.classification.Mnist.prepare(Lai/djl/util/Progress;)V+146
j  ai.djl.training.dataset.Dataset.prepare()V+2
j  ml.djl.DjlSampleKt.main([Ljava/lang/String;)V+36
v  ~StubRoutines::call_stub

The top of the JVM stack at the moment of the segfault is this line, where the dataset is casting the NDArray to float32:

try (NDArray array = manager.create(new Shape(buf.length), DataType.UINT8)) {
    array.set(buf);
    return array.toType(DataType.FLOAT32, false);
}

The top of the native stack is objc_retain, followed by copy_cast_mps. (I can provide more details from the JVM crash report if that would be valuable.)
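
If it helps narrow things down, I suspect the crash can be isolated to that cast alone. Something like the following (untested sketch mirroring what readLabel does) should exercise the same uint8-to-float32 copy on MPS:

import ai.djl.Device
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape

fun main() {
  // Untested sketch: reproduce the readLabel-style cast on an MPS-backed manager.
  NDManager.newBaseManager(Device.fromName("mps")).use { manager ->
    manager.create(Shape(4L), DataType.UINT8).use { array ->
      array.set(byteArrayOf(0, 1, 2, 3))                  // uint8 payload, as in readLabel
      val floats = array.toType(DataType.FLOAT32, false)  // the cast that appears in the crash
      println(floats)
    }
  }
}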

How to Reproduce?

Here is a minimal repro. It's in Kotlin, but I'm happy to rewrite in Java if you prefer:

import ai.djl.Device
import ai.djl.basicdataset.cv.classification.Mnist
import ai.djl.ndarray.NDManager

fun main(args: Array<String>) {
  val device = Device.fromName("mps")
  val manager = NDManager.newBaseManager(device)
  Mnist.builder()
      .optManager(manager)
      .setSampling(32 /* batchSize */, true /* random */)
      .build()
      .prepare() // segfault!
}

Steps to reproduce

Run the program above, with the pytorch-engine, pytorch-jni, and pytorch-native-cpu-osx-aarch64 jars on the runtime classpath.
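
For completeness, my dependency setup looks roughly like this (build.gradle.kts sketch; the coordinates and version suffixes are my best recollection of the DJL naming scheme, so double-check them against Maven Central):

// build.gradle.kts sketch; versions taken from the Environment Info section below
dependencies {
    implementation("ai.djl:api:0.21.0")
    implementation("ai.djl:basicdataset:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-engine:0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-jni:1.13.1-0.21.0")
    runtimeOnly("ai.djl.pytorch:pytorch-native-cpu-osx-aarch64:1.13.1")
}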

What have you tried to solve it?

Nothing, beyond isolating this repro. (I'm new to ML in general, but have a lot of JVM experience.)

Environment Info

I'm using DJL v0.21.0 and pytorch-jni v1.13.1. System information:

$ uname -mrsv
Darwin 22.4.0 Darwin Kernel Version 22.4.0: Mon Mar  6 20:59:58 PST 2023; root:xnu-8796.101.5~3/RELEASE_ARM64_T6020 arm64

Ubehebe avatar Apr 05 '23 15:04 Ubehebe

MPS has many limitations, and we do observe crashes when using the MPS device. See: https://github.com/deepjavalibrary/djl/pull/2044

frankfliu avatar Apr 05 '23 19:04 frankfliu