mlx-examples

Segmentation fault during training

Open JKwon0331 opened this issue 1 year ago • 22 comments

Hi,

I'm having trouble training my model: I keep hitting a segmentation fault. It happens at a random epoch; for example, in the screenshot below it occurred at epoch 65.

[Screenshot 2024-03-30 at 8:23:14 PM]

Other times it has occurred at epoch 99, 104, and so on. I know it is hard to figure out the cause from so little information, but could you let me know what I should suspect?

JKwon0331 · Mar 31 '24 01:03

Is this one of the examples in this repo? Can you share anything else? There shouldn't be a segfault from MLX, so if you are getting one from MLX, it is a bug. But without more information it's almost impossible for us to debug, so anything you can share is appreciated.

awni · Mar 31 '24 01:03

It would also be useful if you could share your MLX version and platform (OS, machine, etc.).

awni · Mar 31 '24 01:03

Hi, sorry for the late response. I uploaded example code to my GitHub: https://github.com/JKwon0331/mlx_test/tree/main

The MLX example code is in the mlx folder. Running its train.py gives the following:

[Screenshot 2024-03-31 at 1:54:01 AM]

I'm trying to implement Squeezeformer, based on the code from https://github.com/upskyy/Squeezeformer.

As far as I know, MLX has no depthwise convolution, so I implemented one myself as best I could.
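A depthwise convolution filters each channel with its own kernel and never mixes channels. A minimal sketch of that idea, written with NumPy so it runs anywhere and assuming channels-last (N, L, C) input as MLX's conv layers use, stride 1 and no padding. depthwise_conv1d is a hypothetical helper, not the implementation in the linked repo; the same slice-multiply-accumulate pattern should carry over to mlx.core arrays, but that is untested here.

```python
import numpy as np


def depthwise_conv1d(x, w):
    """Depthwise 1-D convolution (cross-correlation), stride 1, no padding.

    x: array of shape (N, L, C), channels last.
    w: array of shape (K, C), one length-K filter per channel.
    Returns an array of shape (N, L - K + 1, C).
    """
    K, C = w.shape
    if x.shape[-1] != C:
        raise ValueError("depthwise conv needs exactly one filter per input channel")
    L_out = x.shape[1] - K + 1
    out = np.zeros((x.shape[0], L_out, C), dtype=np.result_type(x, w))
    # Tap k multiplies the window shifted by k with that tap's per-channel
    # weights; w[k] has shape (C,) and broadcasts over N and L_out, so
    # channels are never mixed.
    for k in range(K):
        out += x[:, k:k + L_out, :] * w[k]
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 10, 3))
    w = rng.standard_normal((4, 3))
    print(depthwise_conv1d(x, w).shape)
```

The loop runs over the K kernel taps (usually small) rather than over positions, so each step is one vectorized multiply-add over the whole batch.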

Also, regarding #625: you should be able to see that issue if you uncomment lines 125 and 133 and comment out lines 126 and 134 of model.py.

Finally, the PyTorch version is also in my GitHub repo; it is a simplified version of https://github.com/upskyy/Squeezeformer.

You can run its train.py as well.

As you can see, training is considerably faster in PyTorch (about 8 s per epoch) than in MLX (about 12 s per epoch).

[Screenshot 2024-03-31 at 1:58:56 AM]

I also compared PyTorch's depthwise convolution with my own depthwise convolution (the one I implemented for MLX), and there was no difference in training speed.
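When comparing per-epoch times, note that MLX evaluates lazily and PyTorch on Apple GPUs runs work asynchronously, so a timer can stop before queued work has finished. A minimal, framework-agnostic sketch of a fairer epoch timer, assuming you supply your own run_epoch and sync callables (hypothetical names); if the training loop already calls .item() every step, as many loops do, the extra sync mostly affects the last step.

```python
import time


def time_epochs(run_epoch, sync, num_epochs=3):
    """Return the wall-clock duration, in seconds, of each training epoch.

    run_epoch: callable that runs one full epoch.
    sync: callable that blocks until all queued work is finished. For MLX this
        could be something like lambda: mx.eval(model.parameters()); for
        PyTorch on Apple GPUs, torch.mps.synchronize.
    """
    durations = []
    for _ in range(num_epochs):
        sync()  # don't charge leftover work from the previous epoch to this one
        start = time.perf_counter()
        run_epoch()
        sync()  # make sure this epoch's work has actually executed
        durations.append(time.perf_counter() - start)
    return durations


if __name__ == "__main__":
    print(time_epochs(lambda: time.sleep(0.05), sync=lambda: None, num_epochs=2))
```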

I'm not sure I implemented everything correctly, but I hope this helps with debugging and with improving the speed.

If you have any questions, please feel free to let me know.

JKwon0331 · Mar 31 '24 07:03

Oh, I'm using MLX 0.9.0 on a 16-inch M2 Max MacBook Pro.

JKwon0331 · Mar 31 '24 07:03

Thanks for the code, that's great. I'm running it now. It's at epoch 60 so far with no segfault yet; let's see.

awni · Mar 31 '24 13:03

As you can see, training is considerably faster in PyTorch (about 8 s per epoch) than in MLX (about 12 s per epoch).

Indeed, I'm not too surprised by that for this model. Training conv nets in MLX still needs some optimization, and some of the features you are using (like the depthwise conv, RNN, etc.) will be a bit slow. We need some time to add these features and optimize further. But this is a great benchmark to have, thanks!

awni · Mar 31 '24 13:03

I ran it for hundreds of epochs on both my M1 Max and an M2 Ultra and was not able to get a segfault. It may have been something we fixed in the latest MLX 🤔. If you still see the segfault after the next release, please let us know.

awni · Apr 02 '24 17:04

Hi, it still occurs for me.

[Screenshot 2024-04-02 at 6:20:52 PM]

Hmm... do you use an external monitor?

I have no idea...

JKwon0331 · Apr 02 '24 23:04

@JKwon0331 we aren't able to reproduce the segfault. Could you share a bit more information:

  1. Operating system version
  2. Output of python -c "import mlx.core as mx; print(mx.__version__)"
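The same details, plus the machine architecture, can be collected in one go. A minimal sketch using the standard platform module; environment_report is a hypothetical helper, and MLX is reported as "not installed" if it cannot be imported.

```python
import platform


def environment_report():
    """Collect the environment details usually requested in a bug report."""
    report = {
        "python": platform.python_version(),
        "machine": platform.machine(),            # e.g. 'arm64' on Apple silicon
        "macos": platform.mac_ver()[0] or "n/a",  # mac_ver() gives '' off macOS
    }
    try:
        import mlx.core as mx
        report["mlx"] = mx.__version__
    except ImportError:
        report["mlx"] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```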

awni · Apr 03 '24 14:04

@awni Here they are.

  1. macOS Sonoma 14.4
  2. 0.9.0

JKwon0331 · Apr 03 '24 19:04

@awni

Hi, I tried it several times. Sometimes it doesn't happen for hundreds of epochs; other times it happens within just a few epochs, as shown below.

[Screenshot 2024-04-03 at 11:55:32 PM]

JKwon0331 · Apr 04 '24 04:04

@JKwon0331 I let it run overnight for several hundred epochs and didn't encounter it. Would it be possible to run it with the debugger attached, so we can get an idea of where it segfaults?

Since we can't reproduce it even after several hours, unfortunately you will have to do some of the digging. Step 1 would be to just run it as is with the debugger attached. Step 2, which may be a bit of a pain, would be to compile MLX in debug mode and run it with the debugger attached. If you can do step 2 and it segfaults, then we will know exactly where in the code it happened.

Let me know if you need help with either of the above.

angeloskath · Apr 04 '24 05:04

A tutorial for lldb (the LLVM debugger) can be found at https://lldb.llvm.org/use/tutorial.html.

However, the simplest way to attach it is the following:

  • Get the PID of the training process; one way is to add import os; print(os.getpid()) to the training script
  • In another terminal, run sudo lldb -p <PID printed in the training run>
  • Type c and press Enter in the debugger to continue the training
  • Wait until it segfaults
  • Run bt all in the debugger to print all backtraces
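As a lighter-weight complement to these steps, Python's standard faulthandler module dumps the Python-level traceback of every thread when the process receives a fatal signal such as SIGSEGV or SIGBUS (python -X faulthandler train.py has the same effect). It only shows Python frames, so the C++ frames inside libmlx still need lldb. A minimal sketch to call at the top of train.py, which also prints the PID for the first step; prepare_for_debugging is a hypothetical helper.

```python
import faulthandler
import os


def prepare_for_debugging(wait_for_debugger=False):
    """Enable Python crash tracebacks and print the PID to attach lldb to."""
    # On SIGSEGV, SIGBUS, SIGFPE, SIGABRT or SIGILL, write the Python stack of
    # every thread to stderr before the process dies.
    faulthandler.enable(all_threads=True)
    pid = os.getpid()
    print(f"PID: {pid} (attach with: sudo lldb -p {pid})", flush=True)
    if wait_for_debugger:
        # Gives you time to attach lldb and type c before training starts.
        input("Attach lldb, continue it with c, then press Enter here...")
    return pid
```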

Just in case, it might be simpler to try a new Python environment first: start a brand new environment, install the latest MLX, and try again to see whether it segfaults.

angeloskath · Apr 04 '24 16:04

Hi, I'm so sorry for the late response.

Following your recommendation, I tried a new Python environment (Python 3.11.8 and MLX 0.9.1). Unfortunately, I got a different error this time: a bus error.

The attached screenshots follow the steps you described. Please let me know if there is anything else I need to do.

[Screenshots 2024-04-07 at 10:41:51 PM, 10:40:56 PM and 10:41:09 PM]

JKwon0331 · Apr 08 '24 03:04

Attached are the screenshots for the segmentation fault.

[Screenshots 2024-04-07 at 11:57:28 PM, 11:57:44 PM and 11:57:53 PM]

JKwon0331 · Apr 08 '24 04:04

Thanks that is very helpful!

There seems to be an issue in the function collapse_contiguous_dims, which is used internally to route to a better kernel when possible. It is still very weird that this only happens on your machine, but we are closer to figuring it out. It is great that it happens from a reshape, because that is the simplest way to call collapse_contiguous_dims, so I have some chance of replicating it by brute force.
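To make "collapsing contiguous dims" concrete: two neighbouring dimensions can be merged when stepping over the whole inner dimension lands exactly on the next element of the outer one, for every array taking part in the copy. A simplified sketch of that idea in Python, assuming every stride vector has exactly one entry per dimension; it is loosely modeled on the loop visible in the lldb listing later in this thread, is not a port of MLX's C++ code, and does not special-case size-1 dimensions.

```python
def collapse_contiguous_dims(shape, strides):
    """Merge adjacent dimensions that are contiguous in every stride vector.

    shape: list of dimension sizes.
    strides: list of stride vectors (e.g. one for the input, one for the
        output), each assumed to have exactly len(shape) entries.
    Returns (new_shape, new_strides).
    """
    if not shape:
        return [], [[] for _ in strides]
    groups = [[0]]
    for i in range(1, len(shape)):
        # Dim i folds into the previous group only if, for every array, one
        # full pass over dim i ends exactly where the next step of dim i-1 starts.
        if all(st[i] * shape[i] == st[i - 1] for st in strides):
            groups[-1].append(i)
        else:
            groups.append([i])
    new_shape = []
    for group in groups:
        size = 1
        for d in group:
            size *= shape[d]
        new_shape.append(size)
    # A merged group steps with the stride of its innermost (last) dimension.
    new_strides = [[st[group[-1]] for group in groups] for st in strides]
    return new_shape, new_strides
```

For example, a row-major (2, 3, 4) array with strides (12, 4, 1) collapses to a single dimension of 24 elements with stride 1, which is the kind of simplification that lets a copy take a simpler kernel.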

I will look into it. Not tonight but possibly tomorrow.

angeloskath · Apr 09 '24 05:04

Sorry for the delayed response. I have been trying to reproduce the bug locally and, even though we know where it happens, no luck yet. It happens in a reshape, which calls copy, so I wrote a small fuzzer and ran several hundred thousand different reshapes, including transpositions, broadcasts, strided slices, etc., but nothing breaks, unfortunately (or fortunately :-) ).

This means I will have to trouble you a bit more, since it may be something specific to your machine. If you can, build from source in debug mode:

$ cd mlx/
$ CMAKE_ARGS='-DCMAKE_BUILD_TYPE=Debug' pip install .

After that, the backtraces will include the line of code where the problem occurs. You can also just run the fuzzer at https://gist.github.com/angeloskath/a3dc38b030c080ae5e4135f0125a94b2 to see whether it triggers the error on your setup.

angeloskath · Apr 12 '24 21:04

Hi, I'm sorry, but I am not familiar with debugging. Could you give me more details?

Do you mean:

  1. Get the PID of the training process, one way would be import os; print(os.getpid())
  2. In another terminal run sudo lldb -p <PID printed in the training run>
  3. c + enter in the debugger to continue the training
  4. Wait until it segfaults
  5. bt all in the debugger to print all backtraces

And before step 3, should I build from source like this?

$ cd mlx/
$ CMAKE_ARGS='-DCMAKE_BUILD_TYPE=Debug' pip install

Actually, cd mlx/ does not work for me; it fails with no such file or directory: mlx.

Thank you

JKwon0331 · Apr 16 '24 02:04

When I ran the fuzzer, there was no error.

JKwon0331 · Apr 16 '24 15:04

Sorry, cd mlx/ was meant to be cd your/local/path/to/mlx/source.

angeloskath · Apr 16 '24 15:04

I'm sorry... I cannot find that directory. I tried Go to Folder in Finder, but I still have no idea what you meant.

JKwon0331 · Apr 17 '24 04:04

I think he assumed that you had cloned the MLX source from GitHub (https://github.com/ml-explore/mlx) and installed it from there with pip install .

jrp2014 · May 03 '24 22:05

Hi @angeloskath, @awni pointed me to this issue because I was having a similar segfault. Mine is also intermittent and difficult to reproduce, but it does happen. I could not reproduce the faults mentioned so far in this issue, but below is a stack trace from a segfault I see when my models are run on real data. My models use a combination of convolution, attention, and recurrent layers, similar to the examples shown.

Hopefully the following stack trace helps. It shows that the stride vectors passed to collapse_contiguous_dims do not all have the same length (one has 4 entries, the other only 2), which causes an out-of-bounds access inside the function.

My machine is a 16-inch MacBook Pro with an M2 Max and 96 GB of RAM. macOS: 14.5, Python: 3.11.9, MLX: 0.15.2.dev20240630+20bb301

* thread #33, stop reason = EXC_BAD_ACCESS (code=1, address=0x346700008)
    frame #0: 0x00000001044543ec libmlx.dylib`std::__1::tuple<std::__1::vector<int, std::__1::allocator<int>>, std::__1::vector<std::__1::vector<unsigned long, std::__1::allocator<unsigned long>>, std::__1::allocator<std::__1::vector<unsigned long, std::__1::allocator<unsigned long>>>>> mlx::core::collapse_contiguous_dims<unsigned long>(shape=size=4, strides=size=2) at utils.h:51:13
   48       for (int i = 1; i < shape.size(); i++) {
   49         bool contiguous = true;
   50         for (const std::vector<stride_t>& st : strides) {
-> 51           if (st[i] * shape[i] != st[i - 1]) {
   52             contiguous = false;
   53           }
   54           if (!contiguous) {
Target 0: (python) stopped.

Frame variables:

(lldb) frame variable
(const std::vector<int> &) shape = size=4: {
  [0] = 128
  [1] = 194
  [2] = 3
  [3] = 8
}
(const std::vector<std::vector<unsigned long> >) strides = size=2 {
  [0] = size=4 {
    [0] = 1568
    [1] = 8
    [2] = 8
    [3] = 1
  }
  [1] = size=2 {
    [0] = 24
    [1] = 1
  }
}
(std::vector<int>) to_collapse = size=5 {
  [0] = 0
  [1] = -1
  [2] = 1
  [3] = -1
  [4] = 2
}
(std::vector<int>) out_shape = size=0 {}
(std::vector<std::vector<unsigned long> >) out_strides = size=0 {}
(int) i = 3
(bool) contiguous = true
(const std::vector<unsigned long> &) st = size=2: {
  [0] = 24
  [1] = 1
}

Full stack trace:

(lldb) bt all
  thread #1, queue = 'com.apple.main-thread'
    frame #0: 0x000000019386af04 IOKit`iokit_user_client_trap + 8
    frame #1: 0x000000019a3bbdb8 IOSurface`-[IOSurfaceSharedEvent waitUntilSignaledValue:timeoutMS:] + 72
    frame #2: 0x0000000104aa546c libmlx.dylib`mlx::core::Event::wait() at NSObject.hpp:216:16
    frame #3: 0x0000000104aa5454 libmlx.dylib`mlx::core::Event::wait() [inlined] MTL::SharedEvent::waitUntilSignaledValue(this=0x0000000484cc9060, value=647, milliseconds=18446744073709551615) at MTLEvent.hpp:137:12
    frame #4: 0x0000000104aa5430 libmlx.dylib`mlx::core::Event::wait(this=0x0000000300c70310) at event.cpp:21:14
    frame #5: 0x0000000104165fe8 libmlx.dylib`mlx::core::eval(outputs=size=0) at transforms.cpp:187:48
    frame #6: 0x00000001014140cc core.cpython-311-darwin.so`init_transforms(nanobind::module_&)::$_4::operator()(this=0x000000010089dd18, args=0x000000016fdfdf38) const at transforms.cpp:603:11
    frame #7: 0x0000000101413fe8 core.cpython-311-darwin.so`_object* nanobind::detail::func_create<false, true, init_transforms(nanobind::module_&)::$_4, void, nanobind::args const&, 0ul, nanobind::scope, nanobind::name, nanobind::arg, nanobind::sig, char [365]>(init_transforms(nanobind::module_&)::$_4&&, void (*)(nanobind::args const&), std::__1::integer_sequence<unsigned long, 0ul>, nanobind::scope const&, nanobind::name const&, nanobind::arg const&, nanobind::sig const&, char const (&) [365])::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) [inlined] _object* nanobind::detail::func_create<false, true, init_transforms(nanobind::module_&)::$_4, void, nanobind::args const&, 0ul, nanobind::scope, nanobind::name, nanobind::arg, nanobind::sig, char [365]>(this=0x000000016fdfdee7, p=0x000000010089dd18, args=0x000000016fdfdfe0, args_flags="", policy=automatic, cleanup=0x000000016fdfe1c0)::$_4&&, void (*)(nanobind::args const&), std::__1::integer_sequence<unsigned long, 0ul>, nanobind::scope const&, nanobind::name const&, nanobind::arg const&, nanobind::sig const&, char const (&) [365])::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const at nb_func.h:202:13
    frame #8: 0x0000000101413ef8 core.cpython-311-darwin.so`_object* nanobind::detail::func_create<false, true, init_transforms(nanobind::module_&)::$_4, void, nanobind::args const&, 0ul, nanobind::scope, nanobind::name, nanobind::arg, nanobind::sig, char [365]>(init_transforms(nanobind::module_&)::$_4&&, void (*)(nanobind::args const&), std::__1::integer_sequence<unsigned long, 0ul>, nanobind::scope const&, nanobind::name const&, nanobind::arg const&, nanobind::sig const&, char const (&) [365])::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(p=0x000000010089dd18, args=0x000000016fdfdfe0, args_flags="", policy=automatic, cleanup=0x000000016fdfe1c0) at nb_func.h:177:14
    frame #9: 0x0000000101475dc8 core.cpython-311-darwin.so`nanobind::detail::nb_func_vectorcall_complex(self=0x000000010089dcf0, args_in=0x00000001007002e8, nargsf=9223372036854775809, kwargs_in=0x0000000000000000) at nb_func.cpp:719:26
    frame #10: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #11: 0x0000000100183cf8 python`_PyEval_Vector + 464
    frame #12: 0x0000000100219c18 python`run_mod + 276
    frame #13: 0x0000000100219988 python`pyrun_file + 148
    frame #14: 0x00000001002193b0 python`_PyRun_SimpleFileObject + 268
    frame #15: 0x0000000100218d28 python`_PyRun_AnyFileObject + 216
    frame #16: 0x000000010023cd88 python`pymain_run_file_obj + 260
    frame #17: 0x000000010023c730 python`pymain_run_file + 72
    frame #18: 0x000000010023bfe4 python`Py_RunMain + 1552
    frame #19: 0x0000000100005864 python`main + 56
    frame #20: 0x000000018fd020e0 dyld`start + 2360
  thread #26
    frame #0: 0x0000000190052b70 libsystem_kernel.dylib`poll + 8
    frame #1: 0x0000000103a24e10 _socket.cpython-311-darwin.so`internal_select + 116
    frame #2: 0x0000000103a24c00 _socket.cpython-311-darwin.so`sock_call_ex + 164
    frame #3: 0x0000000103a25768 _socket.cpython-311-darwin.so`sock_recv_guts + 72
    frame #4: 0x0000000103a23054 _socket.cpython-311-darwin.so`sock_recv + 124
    frame #5: 0x00000001000815c8 python`method_vectorcall_VARARGS + 352
    frame #6: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #7: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #8: 0x000000010007513c python`method_vectorcall + 364
    frame #9: 0x00000001001b9b5c python`_PyEval_EvalFrameDefault + 210328
    frame #10: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #11: 0x000000010007513c python`method_vectorcall + 364
    frame #12: 0x00000001002a087c python`thread_run + 228
    frame #13: 0x000000010022b3c4 python`pythread_wrapper + 48
    frame #14: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #27
    frame #0: 0x000000019004d9ec libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x000000019008b55c libsystem_pthread.dylib`_pthread_cond_wait + 1228
    frame #2: 0x000000010022b7b8 python`PyThread_acquire_lock_timed + 348
    frame #3: 0x000000010029f814 python`acquire_timed + 212
    frame #4: 0x000000010029f9fc python`lock_PyThread_acquire_lock + 72
    frame #5: 0x00000001000812b4 python`method_vectorcall_VARARGS_KEYWORDS + 144
    frame #6: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #7: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #8: 0x000000010007513c python`method_vectorcall + 364
    frame #9: 0x00000001001b9b5c python`_PyEval_EvalFrameDefault + 210328
    frame #10: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #11: 0x000000010007513c python`method_vectorcall + 364
    frame #12: 0x00000001002a087c python`thread_run + 228
    frame #13: 0x000000010022b3c4 python`pythread_wrapper + 48
    frame #14: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #28
    frame #0: 0x000000019004d9ec libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x000000019008b55c libsystem_pthread.dylib`_pthread_cond_wait + 1228
    frame #2: 0x000000010022b7b8 python`PyThread_acquire_lock_timed + 348
    frame #3: 0x000000010029f814 python`acquire_timed + 212
    frame #4: 0x000000010029f9fc python`lock_PyThread_acquire_lock + 72
    frame #5: 0x00000001000812b4 python`method_vectorcall_VARARGS_KEYWORDS + 144
    frame #6: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #7: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #8: 0x000000010007513c python`method_vectorcall + 364
    frame #9: 0x00000001001b9b5c python`_PyEval_EvalFrameDefault + 210328
    frame #10: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #11: 0x000000010007513c python`method_vectorcall + 364
    frame #12: 0x00000001002a087c python`thread_run + 228
    frame #13: 0x000000010022b3c4 python`pythread_wrapper + 48
    frame #14: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #29
    frame #0: 0x000000019004d9ec libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x000000019008b55c libsystem_pthread.dylib`_pthread_cond_wait + 1228
    frame #2: 0x000000010022b7b8 python`PyThread_acquire_lock_timed + 348
    frame #3: 0x000000010029f814 python`acquire_timed + 212
    frame #4: 0x000000010029f9fc python`lock_PyThread_acquire_lock + 72
    frame #5: 0x00000001000812b4 python`method_vectorcall_VARARGS_KEYWORDS + 144
    frame #6: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #7: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #8: 0x000000010007513c python`method_vectorcall + 364
    frame #9: 0x00000001001b9b5c python`_PyEval_EvalFrameDefault + 210328
    frame #10: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #11: 0x000000010007513c python`method_vectorcall + 364
    frame #12: 0x00000001002a087c python`thread_run + 228
    frame #13: 0x000000010022b3c4 python`pythread_wrapper + 48
    frame #14: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #30
    frame #0: 0x000000019004d9ec libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x000000019008b55c libsystem_pthread.dylib`_pthread_cond_wait + 1228
    frame #2: 0x000000010022b7b8 python`PyThread_acquire_lock_timed + 348
    frame #3: 0x000000010029f814 python`acquire_timed + 212
    frame #4: 0x000000010029f9fc python`lock_PyThread_acquire_lock + 72
    frame #5: 0x00000001000812b4 python`method_vectorcall_VARARGS_KEYWORDS + 144
    frame #6: 0x00000001001b54d8 python`_PyEval_EvalFrameDefault + 192276
    frame #7: 0x000000010006f700 python`_PyFunction_Vectorcall + 476
    frame #8: 0x000000010007513c python`method_vectorcall + 364
    frame #9: 0x00000001002a087c python`thread_run + 228
    frame #10: 0x000000010022b3c4 python`pythread_wrapper + 48
    frame #11: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
* thread #33, stop reason = EXC_BAD_ACCESS (code=1, address=0x346700008)
  * frame #0: 0x00000001044543ec libmlx.dylib`std::__1::tuple<std::__1::vector<int, std::__1::allocator<int>>, std::__1::vector<std::__1::vector<unsigned long, std::__1::allocator<unsigned long>>, std::__1::allocator<std::__1::vector<unsigned long, std::__1::allocator<unsigned long>>>>> mlx::core::collapse_contiguous_dims<unsigned long>(shape=size=4, strides=size=2) at utils.h:51:13
    frame #1: 0x0000000104a73948 libmlx.dylib`void mlx::core::copy_gpu_inplace<unsigned long>(in=0x0000000484c798d0, out=0x000000017e39b010, data_shape=size=4, strides_in_pre=size=4, strides_out_pre=size=2, inp_offset=0, out_offset=0, ctype=General, s=0x0000000d84e3f640) at copy.cpp:62:27
    frame #2: 0x0000000104a73800 libmlx.dylib`mlx::core::copy_gpu_inplace(in=0x0000000484c798d0, out=0x000000017e39b010, ctype=General, s=0x0000000d84e3f640) at copy.cpp:157:10
    frame #3: 0x0000000104a73744 libmlx.dylib`mlx::core::copy_gpu(in=0x0000000484c798d0, out=0x000000017e39b010, ctype=General, s=0x0000000d84e3f640) at copy.cpp:39:3
    frame #4: 0x0000000104a7385c libmlx.dylib`mlx::core::copy_gpu(in=0x0000000484c798d0, out=0x000000017e39b010, ctype=General) at copy.cpp:43:3
    frame #5: 0x0000000104b2d328 libmlx.dylib`mlx::core::Reshape::eval_gpu(this=0x0000000d84e3f638, inputs=size=1, out=0x000000017e39b010) at primitives.cpp:276:5
    frame #6: 0x0000000104078588 libmlx.dylib`mlx::core::UnaryPrimitive::eval_gpu(this=0x0000000d84e3f638, inputs=size=1, outputs=size=1) at primitives.h:145:5
    frame #7: 0x0000000104b21050 libmlx.dylib`mlx::core::metal::make_task(mlx::core::array, bool)::$_1::operator()(this=0x0000000484c51788) at metal.cpp:66:23
    frame #8: 0x0000000104b20d90 libmlx.dylib`decltype(std::declval<mlx::core::metal::make_task(mlx::core::array, bool)::$_1&>()()) std::__1::__invoke[abi:ue170006]<mlx::core::metal::make_task(mlx::core::array, bool)::$_1&>(__f=0x0000000484c51788) at invoke.h:340:25
    frame #9: 0x0000000104b20d48 libmlx.dylib`void std::__1::__invoke_void_return_wrapper<void, true>::__call[abi:ue170006]<mlx::core::metal::make_task(mlx::core::array, bool)::$_1&>(__args=0x0000000484c51788) at invoke.h:415:5
    frame #10: 0x0000000104b20d24 libmlx.dylib`std::__1::__function::__alloc_func<mlx::core::metal::make_task(mlx::core::array, bool)::$_1, std::__1::allocator<mlx::core::metal::make_task(mlx::core::array, bool)::$_1>, void ()>::operator()[abi:ue170006](this=0x0000000484c51788) at function.h:193:16
    frame #11: 0x0000000104b1fb5c libmlx.dylib`std::__1::__function::__func<mlx::core::metal::make_task(mlx::core::array, bool)::$_1, std::__1::allocator<mlx::core::metal::make_task(mlx::core::array, bool)::$_1>, void ()>::operator()(this=0x0000000484c51780) at function.h:364:12
    frame #12: 0x000000010415b30c libmlx.dylib`std::__1::__function::__value_func<void ()>::operator()[abi:ue170006](this=0x0000000173702ee8) const at function.h:518:16
    frame #13: 0x000000010415aac4 libmlx.dylib`std::__1::function<void ()>::operator()(this=0x0000000173702ee8) const at function.h:1169:12
    frame #14: 0x0000000104159d9c libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn(this=0x0000000165ddc040) at scheduler.h:54:7
    frame #15: 0x000000010415bd64 libmlx.dylib`decltype(*std::declval<mlx::core::scheduler::StreamThread*>().*std::declval<void (mlx::core::scheduler::StreamThread::*)()>()()) std::__1::__invoke[abi:ue170006]<void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*, void>(__f=0x0000000165dda958, __a0=0x0000000165dda968) at invoke.h:308:25
    frame #16: 0x000000010415bca4 libmlx.dylib`void std::__1::__thread_execute[abi:ue170006]<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*, 2ul>(__t=size=3, (null)=__tuple_indices<2UL> @ 0x0000000173702f7f) at thread.h:227:5
    frame #17: 0x000000010415b5a0 libmlx.dylib`void* std::__1::__thread_proxy[abi:ue170006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(__vp=0x0000000165dda950) at thread.h:238:5
    frame #18: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #34
    frame #0: 0x000000019004d9ec libsystem_kernel.dylib`__psynch_cvwait + 8
    frame #1: 0x000000019008b55c libsystem_pthread.dylib`_pthread_cond_wait + 1228
    frame #2: 0x000000018ffb0b14 libc++.1.dylib`std::__1::condition_variable::wait(std::__1::unique_lock<std::__1::mutex>&) + 28
    frame #3: 0x000000010415a9f0 libmlx.dylib`void std::__1::condition_variable::wait<mlx::core::scheduler::StreamThread::thread_fn()::'lambda'()>(this=0x0000000165de41f0, __lk=0x0000000173c46ed0, __pred=(unnamed class) @ 0x0000000173c46e78) at condition_variable.h:148:5
    frame #4: 0x0000000104159cd8 libmlx.dylib`mlx::core::scheduler::StreamThread::thread_fn(this=0x0000000165de4180) at scheduler.h:46:14
    frame #5: 0x000000010415bd64 libmlx.dylib`decltype(*std::declval<mlx::core::scheduler::StreamThread*>().*std::declval<void (mlx::core::scheduler::StreamThread::*)()>()()) std::__1::__invoke[abi:ue170006]<void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*, void>(__f=0x0000000165d29708, __a0=0x0000000165d29718) at invoke.h:308:25
    frame #6: 0x000000010415bca4 libmlx.dylib`void std::__1::__thread_execute[abi:ue170006]<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*, 2ul>(__t=size=3, (null)=__tuple_indices<2UL> @ 0x0000000173c46f7f) at thread.h:227:5
    frame #7: 0x000000010415b5a0 libmlx.dylib`void* std::__1::__thread_proxy[abi:ue170006]<std::__1::tuple<std::__1::unique_ptr<std::__1::__thread_struct, std::__1::default_delete<std::__1::__thread_struct>>, void (mlx::core::scheduler::StreamThread::*)(), mlx::core::scheduler::StreamThread*>>(__vp=0x0000000165d29700) at thread.h:238:5
    frame #8: 0x000000019008af94 libsystem_pthread.dylib`_pthread_start + 136
  thread #35, queue = 'com.Metal.CommandQueueDispatch'
    frame #0: 0x000000019386af04 IOKit`iokit_user_client_trap + 8
    frame #1: 0x00000001af118a30 IOGPU`IOGPUCommandQueueSubmitCommandBuffers + 164
    frame #2: 0x00000001af1086b4 IOGPU`-[IOGPUMetalCommandQueue _submitCommandBuffers:count:] + 356
    frame #3: 0x00000001af108528 IOGPU`-[IOGPUMetalCommandQueue submitCommandBuffers:count:] + 72
    frame #4: 0x000000019a3f6368 Metal`-[_MTLCommandQueue _submitAvailableCommandBuffers] + 492
    frame #5: 0x000000018feda3e8 libdispatch.dylib`_dispatch_client_callout + 20
    frame #6: 0x000000018fedd8ec libdispatch.dylib`_dispatch_continuation_pop + 600
    frame #7: 0x000000018fef17f0 libdispatch.dylib`_dispatch_source_latch_and_call + 420
    frame #8: 0x000000018fef03b4 libdispatch.dylib`_dispatch_source_invoke + 832
    frame #9: 0x000000018fee1898 libdispatch.dylib`_dispatch_lane_serial_drain + 368
    frame #10: 0x000000018fee2544 libdispatch.dylib`_dispatch_lane_invoke + 380
    frame #11: 0x000000018feed2d0 libdispatch.dylib`_dispatch_root_queue_drain_deferred_wlh + 288
    frame #12: 0x000000018feecb44 libdispatch.dylib`_dispatch_workloop_worker_thread + 404
    frame #13: 0x000000019008700c libsystem_pthread.dylib`_pthread_wqthread + 288
  thread #36
    frame #0: 0x0000000190085d20 libsystem_pthread.dylib`start_wqthread
  thread #37
    frame #0: 0x0000000190085d20 libsystem_pthread.dylib`start_wqthread
  thread #38
    frame #0: 0x0000000000000000

johnsonjzhou · Jun 30 '24 10:06

Here is some code that can reproduce the above error semi-regularly on my machine:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import random
from functools import partial
from tqdm import tqdm


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.F = nn.Sequential(
            nn.Conv2d(
                in_channels=2,
                out_channels=4,
                kernel_size=(1, 5),
                padding=(0, 0)
            ),
            nn.BatchNorm(num_features=4),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=4,
                out_channels=8,
                kernel_size=(1, 3),
                padding=(0, 0)
            ),
            nn.BatchNorm(num_features=8),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=8,
                out_channels=16,
                kernel_size=(1, 3),
                padding=(0, 0)
            ),
            nn.BatchNorm(num_features=16),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=(1, 3),
                padding=(0, 0)
            ),
            nn.BatchNorm(num_features=32),
            nn.ReLU()
        )
        self.G = nn.GRU(input_size=32, hidden_size=16)
        self.P = nn.Linear(input_dims=16, output_dims=5)

    def __call__(self, x: mx.array) -> mx.array:
        x = self.F(x)
        x = self.G(x)[..., -1, :]  # Summarise axis 2 by keeping the GRU's last step
        return x

def train():

    # N: batch size
    # T: number of time steps
    # L: length of signal segment
    # C: number of channels
    N, T, L, C = (64, 8, 512, 2)

    ds = [
        (
            # x, the L dim can be variable
            mx.random.uniform(
                low=-1,
                high=1,
                shape=(N, T, random.choice(range(16, L)), C)
            ),
            # y
            mx.random.randint(
                low=0,
                high=5,
                shape=(N, T)
            )
        )
        for _ in range(5000)
    ]

    model = Model()
    mx.eval(model.parameters())

    optimizer = optim.AdamW(learning_rate=1e-4, weight_decay=0.01)
    state = [model.state, optimizer.state, mx.random.state]

    def loss_fn(model, x, y):
        logits = model(x)
        y_pred = model.P(logits)
        losses = nn.losses.cross_entropy(y_pred, y, reduction="none")
        losses = mx.sum(losses, axis=1)
        losses = mx.mean(losses)
        return losses

    @partial(mx.compile, inputs=state, outputs=state)
    def train_step(x, y):
        model.train()
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, x, y)
        optimizer.update(model, grads)
        return loss

    for epoch in range(1000):
        batches = random.sample(ds, len(ds))
        batches = tqdm(batches, desc=f"Epoch {epoch}:")
        for x, y in batches:
            x, y = mx.array(x), mx.array(y)
            mx.eval(state)
            loss = train_step(x, y)
            batches.set_postfix({"loss": loss.item()})
            continue

    return

if __name__ == "__main__":
    train()

johnsonjzhou · Jul 01 '24 01:07

There are quite a few calls to collapse_contiguous_dims where the input stride vectors do not have the same size. It looks like the function expects them to, but I'm not sure yet 🤔
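The frame variables captured earlier in this thread are consistent with that. A small Python replay of the loop from the lldb listing (utils.h lines 48 to 54), assuming the truncated if (!contiguous) branch breaks out of the inner loop, finds where the C++ code would first read past the end of a stride vector; first_out_of_bounds_read is a hypothetical helper.

```python
def first_out_of_bounds_read(shape, strides):
    """Replay the contiguity check and report the first out-of-range stride read.

    Mirrors the loop in the lldb listing: for each dimension i >= 1, examine the
    stride vectors in order and stop at the first one that is not contiguous
    (assumed behaviour of the truncated `if (!contiguous)` branch).
    Returns (i, k) for the first access strides[k][i] past the end of its
    vector, or None if every access stays in range.
    """
    for i in range(1, len(shape)):
        for k, st in enumerate(strides):
            if i >= len(st):
                return i, k  # C++ would read st[i] out of bounds here
            if st[i] * shape[i] != st[i - 1]:
                break  # not contiguous: the remaining stride vectors are skipped
    return None


# Values copied from the `frame variable` dump in this thread.
SHAPE = [128, 194, 3, 8]
STRIDES = [[1568, 8, 8, 1], [24, 1]]

if __name__ == "__main__":
    print(first_out_of_bounds_read(SHAPE, STRIDES))
```

With the recorded values the replay stops at i = 3 on the two-element stride vector, matching the i = 3 and st = size=2 shown in the frame dump (dimensions 1 and 2 are rejected by the first stride vector, so the short vector is only reached at i = 3). Because std::vector::operator[] does no bounds checking, such a read returns garbage or faults depending on what memory follows the vector, which would be one plausible reason the crash is intermittent.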

awni · Jul 01 '24 13:07

Were you able to reproduce it on your end?

johnsonjzhou · Jul 02 '24 07:07

No, not yet. Can you give a rough estimate of how often and when the segfault shows up?

awni · Jul 02 '24 16:07

If it shows up, it does so within the first few epochs, most often around 30-40% of the way through the first epoch. If not, it runs fine forever. So to reproduce it, stop the run and start it again. I have tried setting a random seed, but it does not fault at the same location.
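Since the crash either shows up early or not at all, the stop-and-restart loop can be automated. A minimal sketch, assuming a POSIX system (macOS included), where subprocess reports a child killed by a signal as a negative return code; run_until_crash is a hypothetical helper. Each run gets a time budget, after which it is killed and a fresh run starts.

```python
import signal
import subprocess
import sys

# Signals reported in this thread: segmentation fault and bus error.
CRASH_SIGNALS = {signal.SIGSEGV, signal.SIGBUS}


def run_until_crash(script_args, max_runs=20, timeout=None):
    """Restart `python <script_args>` until it dies from SIGSEGV or SIGBUS.

    timeout: seconds each run may take before it is killed and a fresh run
        starts (None waits forever).
    Returns (run_number, signal_name) for the first crash, or None.
    """
    cmd = [sys.executable, *script_args]
    for run in range(1, max_runs + 1):
        try:
            result = subprocess.run(cmd, timeout=timeout)
        except subprocess.TimeoutExpired:
            continue  # survived the time window; subprocess.run killed it
        # On POSIX a negative return code means "terminated by signal -code".
        if result.returncode < 0 and -result.returncode in CRASH_SIGNALS:
            return run, signal.Signals(-result.returncode).name
    return None
```

For example, run_until_crash(["repro.py"], max_runs=10, timeout=600) would restart a hypothetical repro.py for at most ten 10-minute attempts and report which attempt crashed and with which signal.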

johnsonjzhou · Jul 03 '24 00:07

Try these weights. For some reason, with them it segfaults fairly consistently within the first half of the first epoch: test_weights.safetensors.zip

johnsonjzhou · Jul 03 '24 02:07

We believe this is fixed in the latest MLX (0.16). There was a bug in the part of the code where you all seem to have been getting segfaults. Since we can't reproduce your issue, though, we can't be sure. Would you mind running with the latest MLX to check whether you still see segfaults?
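Before re-testing, it helps to confirm which build is actually installed. A small sketch that compares version strings like the ones reported in this thread against 0.16.0, the release believed to contain the fix; parse_mlx_version and has_reported_fix are hypothetical helpers, and a dev build such as 0.16.0.devYYYYMMDD is treated as 0.16.0 here, which may be too generous. In practice you would pass it mx.__version__.

```python
import re


def parse_mlx_version(version):
    """Return (major, minor, patch) from strings like '0.15.2.dev20240630+20bb301'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def has_reported_fix(version, fixed_in=(0, 16, 0)):
    """True if `version` is at least `fixed_in`; dev/local suffixes are ignored."""
    return parse_mlx_version(version) >= fixed_in


if __name__ == "__main__":
    for v in ["0.9.0", "0.9.1", "0.15.2.dev20240630+20bb301", "0.16.0"]:
        print(v, has_reported_fix(v))
```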

awni · Jul 11 '24 22:07