`jax.profiler.trace` repeatedly fails to display entire trace
Description
On various platforms, JAX versions, and backends, `jax.profiler.trace` emits a truncated trace, so the trace viewer does not show the entire profiled run.
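The report does not include the code that produced the truncated trace. The sketch below shows how a trace is typically captured with `jax.profiler.trace(log_dir)`; the workload (the hypothetical `step` function and matrix size) is an assumption for illustration only. One thing worth ruling out is asynchronous dispatch: JAX returns before device work finishes, so if nothing is blocked on inside the `with` block, the trace can stop while work is still running. This is offered as something to check, not as the confirmed cause of this issue.

```python
import os
import tempfile

import jax
import jax.numpy as jnp


def profile_matmul(log_dir: str, n: int = 1024) -> jax.Array:
    """Capture a profiler trace of a small jitted workload into `log_dir`.

    `step` and the matrix size are hypothetical; substitute the real workload.
    """

    @jax.jit
    def step(x):
        # tanh keeps values bounded so repeated iterations stay finite
        return jnp.tanh(x @ x)

    x = jnp.ones((n, n), dtype=jnp.float32)
    # Warm-up call so compilation is not part of the captured trace
    step(x).block_until_ready()

    with jax.profiler.trace(log_dir):
        y = x
        for _ in range(10):
            y = step(y)
        # Block inside the context: device work must finish before the trace stops
        y.block_until_ready()
    return y


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp(prefix="jax-trace-")
    profile_matmul(out_dir, n=256)
    # List whatever the profiler wrote under the log directory
    for root, _, files in os.walk(out_dir):
        for name in files:
            print(os.path.join(root, name))
```

If the trace is still truncated with the explicit `block_until_ready()` inside the context, that points away from async dispatch as the explanation.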
System info (python version, jaxlib version, accelerator, etc.)
Here is one such example, from a single-GPU machine:
```
>>> import jax; jax.print_environment_info()
jax: 0.4.16.dev20240518
jaxlib: 0.4.14
numpy: 1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [gpu(id=0)]
process_count: 1
```

```
$ nvidia-smi
Sat May 18 01:58:54 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4070 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   51C    P3               8W /  80W |   2434MiB /  8188MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```