
Remove excessive warn message in maybe_get_jax as it creates too many log lines during training

Open rajkthakur opened this issue 4 months ago • 10 comments

🐛 Bug

The maybe_get_jax() function in torch_xla/_internal/jax_workarounds.py, merged in #9521, currently emits a warning message when JAX is not installed. While informative, this warning produces an excessive number of log lines during training workloads, cluttering the logs and making it difficult to spot genuinely important debug messages.

To Reproduce

Steps to reproduce the behavior:

  1. Create Python Virtual Environment (python3 -m venv ptxla_28) on Ubuntu 22.04
  2. pip install torch==2.8.0 torchvision; pip install torch_xla==2.8.0
  3. Create a small Python script (let's call it trigger_warning.py):
import sys
sys.path.insert(0, 'ptxla_28/lib/python3.10/site-packages')
from torch_xla._internal.jax_workarounds import maybe_get_jax
# Calling maybe_get_jax() in an environment without JAX triggers the warning.
maybe_get_jax()
  4. Execute the script: bash -c "source ptxla_28/bin/activate && python trigger_warning.py"
  5. You should see warning messages like the ones below:
WARNING:root:Defaulting to PJRT_DEVICE=CPU
WARNING:root:You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]

Expected behavior

Remove or suppress this warning message, or limit it to display only once per process/session instead of for every invocation.
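One way to get the once-per-process behavior is to memoize the warning helper. Below is a minimal sketch of that idea, not the actual torch_xla implementation; the helper name _warn_jax_missing_once is made up for illustration:

import functools
import logging

@functools.lru_cache(maxsize=None)
def _warn_jax_missing_once():
    # lru_cache memoizes the (argument-free) call, so the warning body runs
    # only the first time it is invoked in a given process.
    logging.warning('You are trying to use a feature that requires jax/pallas. '
                    'You can install Jax/Pallas via pip install torch_xla[pallas]')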

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.8.0
  • Relevant Code: https://github.com/pytorch/xla/blob/0f56dec9a33a993d4c14cb755bdd25490cabba21/torch_xla/_internal/jax_workarounds.py#L61

Additional context

The current behavior results in thousands of lines of repeated warnings when running workloads that do not require JAX, negatively impacting developer experience. Reducing or removing this warning will significantly clean up logs for users running long or large-scale training jobs, improving usability without sacrificing relevant error reporting.

rajkthakur avatar Aug 19 '25 20:08 rajkthakur

+1

subhashgahlot161278 avatar Aug 19 '25 20:08 subhashgahlot161278

Marking this as a performance issue since it affects lazy mode training performance for multiple models.

jeffhataws avatar Aug 21 '25 15:08 jeffhataws

cc @qihqi

ysiraichi avatar Aug 21 '25 15:08 ysiraichi

Hi,

Sure, we can add an option to disable this warning via environment variables and the like.

However, a function usually calls maybe_get_jax because it actually needs to use JAX (say pallas, shard_as, call_jax, assume_pure). If you can figure out which function is calling maybe_get_jax unnecessarily, that would be better; we could then remove the maybe_get_jax call from that function.

If you actually used a function that needs JAX but you don't have JAX installed, then that function is likely not doing what it is supposed to do (say shard_as becoming a no-op), and that is no good.

qihqi avatar Aug 21 '25 18:08 qihqi

Since JAX is an optional dependency, we should introduce a new env variable that applies to maybe_get_jax, so that the function returns None if the variable is not set.

Additional context: We have observed up to 20% performance degradation just because this check sits in the hot path of some of our test cases. Returning early and reliably based on the env variable would help restore the lost performance.
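A rough sketch of the proposed gate, assuming a hypothetical opt-in variable named TORCHXLA_ENABLE_JAX (the real jax_workarounds.py function does more than this):

import logging
import os

def maybe_get_jax():
    # Hypothetical fast path: when the opt-in variable is not set, return None
    # immediately, skipping the import attempt and the warning on the hot path.
    if os.environ.get('TORCHXLA_ENABLE_JAX', '0') != '1':
        return None
    try:
        import jax
        return jax
    except ImportError:
        logging.warning('You are trying to use a feature that requires jax/pallas. '
                        'You can install Jax/Pallas via pip install torch_xla[pallas]')
        return None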

rajkthakur avatar Aug 22 '25 20:08 rajkthakur

@rajkthakur sounds good. Feel free to submit a PR

qihqi avatar Aug 22 '25 21:08 qihqi

I'm using torch, and I'm still being spammed by:

[WARNING][root] You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]
[WARNING][root] You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]
[WARNING][root] You are trying to use a feature that requires jax/pallas.You can install Jax/Pallas via pip install torch_xla[pallas]
... (the same line repeated many more times)

But the spam is not the main problem; my concern is whether it points to a problem behind the scenes.

steveepreston avatar Oct 09 '25 14:10 steveepreston

@rajkthakur You're saying this is decreasing performance? Can you tell where it happens? Maybe we can monkey-patch it.
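For reference, a user-side workaround sketch (assuming the warning keeps the wording shown above and is still emitted via logging.warning on the root logger) is to attach a logging filter rather than patching the function itself:

import logging

class _DropJaxPallasWarning(logging.Filter):
    def filter(self, record):
        # Suppress only the jax/pallas install hint; let all other records through.
        return 'requires jax/pallas' not in record.getMessage()

# The warning is emitted with logging.warning(), i.e. on the root logger,
# so a filter attached there catches it before any handler runs.
logging.getLogger().addFilter(_DropJaxPallasWarning())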

steveepreston avatar Oct 09 '25 15:10 steveepreston

@steveepreston The issue is resolved in PT/XLA 2.8.1. There were multiple reasons contributing to the perf degradation, this being one of them, but the major issue was a compile flag change, specifically in 2.8.0.

rajkthakur avatar Oct 10 '25 02:10 rajkthakur

@rajkthakur Sadly, after upgrading to torch_xla 2.8.0 I saw a large increase in training time, from ~8 min to ~24 min. Even upgrading to 2.8.1 didn't fix this.

steveepreston avatar Oct 11 '25 02:10 steveepreston