
`jax.profiler.trace` repeatedly fails to display the entire trace

Open: jon-chuang opened this issue 1 year ago (0 comments)

Description

Across various platforms, versions, and backends, `jax.profiler.trace` emits a truncated trace.
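For anyone trying to reproduce this, below is a minimal sketch of how such a trace is typically captured. It is an illustration under stated assumptions, not the reporter's code. The `step` workload, the `capture_trace` helper, and the temporary log directory are hypothetical. The explicit `block_until_ready()` before leaving the `with` block is included because JAX dispatches work asynchronously, so device work still in flight when the context exits may not land in the trace. This issue does not establish whether that is related to the truncation reported here.

```python
import glob
import os
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical workload: a few matmuls so the trace contains device activity.
    for _ in range(3):
        x = jnp.tanh(x @ x)
    return x


def capture_trace(log_dir: str) -> list[str]:
    """Hypothetical helper: profile a few steps and return the trace files written."""
    x = jnp.ones((512, 512), dtype=jnp.float32)
    # Compile outside the profiled region so the trace shows steady-state steps.
    step(x).block_until_ready()
    with jax.profiler.trace(log_dir):
        for _ in range(10):
            x = step(x)
        # JAX dispatches asynchronously; wait for the device work to finish
        # before the profiler session is stopped and exported.
        x.block_until_ready()
    # Assumption: the profiler exports TensorBoard-style *.xplane.pb files
    # under <log_dir>/plugins/profile/<run>/.
    pattern = os.path.join(log_dir, "plugins", "profile", "*", "*.xplane.pb")
    return glob.glob(pattern)


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp(prefix="jax-trace-")
    print(capture_trace(out_dir))
```

The resulting directory can then be opened in TensorBoard's profiler plugin to check whether all ten `step` calls appear.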

System info (python version, jaxlib version, accelerator, etc.)

Here is one such environment:

```
>>> import jax; jax.print_environment_info()
jax:    0.4.16.dev20240518
jaxlib: 0.4.14
numpy:  1.26.0
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [gpu(id=0)]
process_count: 1
```

```
$ nvidia-smi
Sat May 18 01:58:54 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4070 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   51C    P3               8W /  80W |   2434MiB /  8188MiB |      7%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
```

jon-chuang, May 18 '24 05:05