Mysterious crash on TPU pod
Description
I'm getting a very mysterious crash on a TPU pod. It only happens during my eval step (not train step), and only happens on a multi-worker pod (not a single VM). Nothing is printed to stdout or stderr, even with logging.DEBUG. The process just... exits. All I could gather from /tmp/tpu_logs was this line:
E1224 22:35:53.165575 26868 real_program_continuator.cc:1331] 0x0x0_TC0: *** Halt is unexpected, all pending programs will fail. This is likely an XLA compiler bug, please file a bug to xla-tpu@ and CC tfrt-devs@.
The progress bar gets to exactly 32 steps before the crash, so I'm assuming some sort of queue is filling up (this is on a v4-128).
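One thing I can try to test the queue theory is blocking on each step's result so any failure surfaces synchronously; a minimal sketch (eval_step, params, and batches below are toy stand-ins, not my actual code):

```python
import jax
import jax.numpy as jnp

# Toy stand-ins for my real eval_step and data loader; only the blocking
# call at the end of the loop is the point of this sketch.
eval_step = jax.jit(lambda params, batch: jnp.mean((batch @ params) ** 2))
params = jnp.ones((8, 8))
batches = [jnp.ones((4, 8)) for _ in range(64)]

for step, batch in enumerate(batches):
    metrics = eval_step(params, batch)
    # Drain JAX's async dispatch queue every step so a failure shows up at
    # the step that actually triggered it, rather than ~32 steps later.
    jax.block_until_ready(metrics)
```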
Is there anything else I can do to diagnose this? I've attached the full logs from worker 0 for reference.
tpu_driver.WARNING.txt tpu_driver.ERROR.txt tpu_driver.INFO.txt
What jax/jaxlib version are you using?
jax 0.4.23, jaxlib 0.4.23, libtpu-nightly 0.1.dev20231213
Which accelerator(s) are you using?
TPUv4
Additional system info?
Python 3.10.12, tpu-vm-v4-base
NVIDIA GPU info
No response
Update: I implemented FSDP, and now it happens with the train step too, which is a bit more of an issue. Still works on a single VM but not the pod. I'm also getting these errors now:
Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[64,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63} to {devices=[1,4,16]<=[16,4]T(1,0) last_tile_dim_replicate} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
This doesn't seem fatal, but it's also difficult to address, as I couldn't find any documentation explaining what that syntax means (e.g., devices=[1,4,16]<=[16,4]T(1,0) last_tile_dim_replicate). I also don't see why my setup should run into this, since it's a fairly standard FSDP setup: the weight matrices are sharded along the mesh axis of size 4, and the data is sharded along both axes, i.e., across all 64 devices. I'm using in_shardings and out_shardings to enforce this.
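For concreteness, the setup looks roughly like this; the mesh shape matches what I described above, but the names and the placeholder update are simplified stand-ins for my actual code:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Simplified sketch of the setup described above, not the actual training
# code. The 64 chips are arranged as a (data=16, fsdp=4) mesh.
mesh = Mesh(np.array(jax.devices()).reshape(16, 4), axis_names=("data", "fsdp"))

# Weight matrices are sharded along the fsdp axis of size 4; the batch is
# sharded across all 64 devices, i.e. along both mesh axes.
param_sharding = NamedSharding(mesh, P("fsdp", None))
batch_sharding = NamedSharding(mesh, P(("data", "fsdp"), None))

def train_step_fn(params, batch):
    # Placeholder update; the real step does a full forward/backward pass.
    return params - 1e-3 * jnp.mean(batch) * params

train_step = jax.jit(
    train_step_fn,
    in_shardings=(param_sharding, batch_sharding),
    out_shardings=param_sharding,
)
```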
I'm also getting these errors on some, but not all of the workers:
A program or fatal error occurred, all pending programs will fail, and results may be corrupted. This is likely an XLA compiler bug or your job hit a bad TPU hardware. If you suspect a compiler bug please file a bug to xla-tpu@ and CC tfrt-devs@. If the error is repeatedly happening on the same machine and you suspect faulty hardware, consider reporting a bad machine: go/allocator-faq#report-a-bad-machine.
Here is the full log for reference: tpu_driver.ERROR.txt
I understand that the problem probably can't be solved outright from the information I've provided, but I would really appreciate any guidance: this is now completely blocking, and as someone outside of Google it's unclear to me how to proceed with debugging.
Oof, sorry you're blocked! Let's see if we can figure it out...
Would you be able to dump the HLO of the crashing computation and attach it here or email it to me? To dump HLO, set the env var XLA_FLAGS=--xla_dump_to=/path/to/some/dir. This may dump multiple computations; if you're not sure which one is crashing, just attach all of them. There are also multiple files dumped per computation; please include all of them if you can. Please also include the corresponding tpu_driver.INFO from after generating the dumped HLO, if possible, just to make sure all the info we have is lined up properly. Let me know if you need help with any of this.
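One way to set the flag (just a sketch; the dump path is an arbitrary example) is to export it in your launch environment, or set it in Python before JAX initializes:

```python
import os

# Must be set before jax is imported/initialized so XLA picks it up;
# /tmp/hlo_dump is just an example path.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/hlo_dump"
)

import jax
import jax.numpy as jnp

# Every jitted computation compiled from here on will be dumped as HLO
# text/proto files under /tmp/hlo_dump.
jax.jit(lambda x: x * 2)(jnp.arange(4))
```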
Also, have you checked if you get this error with earlier versions of jax? It'd be helpful to know if this is a recent regression.
I've also filed a Google-internal issue with the XLA:TPU team to see if they have any ideas (b/318437743 for Googlers' reference).
Thanks @skye! I did try with jax/jaxlib 0.4.20 at some point and the same thing happened, as far as I could tell.
In collecting the HLO dump, I discovered that it's not actually the train step that is crashing now, but the very next function call, which is apply_ema_decay. This is shocking to me, since apply_ema_decay is a simple elementwise operation on the parameters. I was even checking the persistent compilation cache to deduce where the crash happened, and I saw that apply_ema_decay was not in there, so I assumed that it did not make it past the train step. However, since apply_ema_decay is in the HLO dump, it means that it must have crashed while compiling (but after lowering) apply_ema_decay (right?).
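For context, apply_ema_decay is essentially just the following (a simplified sketch of my function, not the exact code):

```python
import jax

# Simplified sketch of apply_ema_decay: a plain elementwise update that
# moves the EMA copy of the parameters toward the current parameters.
@jax.jit
def apply_ema_decay(ema_params, params, decay=0.999):
    return jax.tree_util.tree_map(
        lambda e, p: e * decay + p * (1.0 - decay), ema_params, params
    )
```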
Anyway, here is the information requested, from both worker 0 and worker 6 (which is mentioned in the worker 0 logs, and appears to have some additional log entries). I left out a single file, the .after_optimization-buffer-assignment.txt for the train step, since that file alone is 4.7 GB. Let me know if you need it.
Hello, sorry for the delay! Thanks for collecting all the info, this is super useful.
Is the train step jit__unnamed_wrapped_function_? I think this is the crashing computation, although I'm not 100% sure. The log indicates the TPU runtime is crashing while executing some computation. Since TPU executions run asynchronously, I'm guessing that apply_ema_decay is being compiled concurrently while the train step is running, and the whole process crashes before that compilation finishes (for reasons unrelated to it). Note that worker 0 has the beginnings of the apply_ema_decay dump, but worker 6 doesn't.
Anyway, I've shared this with the XLA:TPU team. Hopefully I'll have some more updates for you soon!
Actually, one thing I've noticed: module_0521.jit__unnamed_wrapped_function_.before_optimizations.txt appears to be different between workers 0 and 6. If different processes in your TPU pod somehow end up compiling different versions of the same function (e.g. if they're fed dynamically-shaped data that varies across processes), this can cause mysterious runtime crashes.
Can you try setting the env var LIBTPU_INIT_ARGS=--xla_tpu_use_enhanced_launch_barrier=true on all hosts and rerunning? If the problem is indeed mismatched XLA programs, that flag should cause the TPU runtime to fail with a different error. It'll still be pretty mysterious, but if you can share a tpu_driver.INFO log or two after trying the env var, I can help check for the new error message.
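(In case it's useful, one way to set it, sketched below, is via os.environ at the very top of the script launched on every host, before JAX initializes the TPU runtime; the same "set it before initialization" caveat as with XLA_FLAGS applies.)

```python
import os

# Must be set on every host before jax initializes the TPU runtime,
# e.g. at the very top of the main script launched on each worker.
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_use_enhanced_launch_barrier=true"

import jax

# Touching the devices forces TPU initialization with the flag applied.
print(jax.device_count())
```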
(We're working on enabling this extra checking by default, and eventually would like to make it fail in a clear way, instead of arcane messages in the tpu_driver logs.)
Oops, maybe not actually. After cleaning up spurious differences in the HLO, I don't think they're significantly different. It doesn't hurt to run with the extra flag just to make sure, but we'll likely have to wait for feedback from the XLA:TPU team.
I also hit this mysterious crash, with no errors. I'm adding additional partition rules to see if it helps. Examining tpu_driver.ERROR.txt gives "You probably want to enrich the sharding annotations to prevent this from happening." You should probably look into making this error more verbose.