Torchax regressing basic TRN tests (ValueError: Invalid value "TRACE" for JAX flag jax_logging_level)
🐛 Bug
Test:
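import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# NUM_DEVICES and compare_npy_diff are helpers from the original test suite (not shown);
# the shape, mesh and sharding-spec arguments come from pytest parametrization.

def test_sharded_matmul(tensor_a_shape, tensor_b_shape, mesh_shape, sharding_spec_a, sharding_spec_b):
    cpu_device = torch.device("cpu")
    neuron_device = xm.xla_device()
    # 2D device mesh with named axes "tp1" and "tp2".
    device_ids = np.array(range(NUM_DEVICES))
    mesh = Mesh(device_ids, mesh_shape, ("tp1", "tp2"))
    tensor_a_cpu = torch.rand(tensor_a_shape, dtype=torch.float32, device=cpu_device)
    tensor_b_cpu = torch.rand(tensor_b_shape, dtype=torch.float32, device=cpu_device)
    tensor_a_neuron = tensor_a_cpu.to(neuron_device)
    tensor_b_neuron = tensor_b_cpu.to(neuron_device)
    # The first mark_sharding call is where torchax (and therefore jax) gets imported.
    xs.mark_sharding(tensor_a_neuron, mesh, sharding_spec_a)
    xs.mark_sharding(tensor_b_neuron, mesh, sharding_spec_b)
    tensor_c_neuron = torch.matmul(tensor_a_neuron, tensor_b_neuron)
    # (None, None) replicates the result across the mesh.
    xs.mark_sharding(tensor_c_neuron, mesh, (None, None))
    result_cpu = torch.matmul(tensor_a_cpu, tensor_b_cpu)
    assert compare_npy_diff(result_cpu, tensor_c_neuron.to(cpu_device))
    xm.mark_step()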
Stacktrace:
aws_neuron_venv/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py:610: in mark_sharding
    tx = maybe_get_torchax()
aws_neuron_venv/lib/python3.10/site-packages/torch_xla/_internal/jax_workarounds.py:51: in maybe_get_torchax
    import torchax
aws_neuron_venv/lib/python3.10/site-packages/torchax/__init__.py:4: in <module>
    import jax
aws_neuron_venv/lib/python3.10/site-packages/jax/__init__.py:25: in <module>
    from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
aws_neuron_venv/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:20: in <module>
    from jax._src import config
aws_neuron_venv/lib/python3.10/site-packages/jax/_src/config.py:1721: in <module>
    optional_enum_state(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
name = 'jax_logging_level'
enum_values = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
default = 'TRACE'
help = 'Set the corresponding logging level on all jax loggers. Only string values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] are accepted. If None, the logging level will not be set. Includes C++ logging.'
    def optional_enum_state(
        name: str,
        enum_values: Sequence[str],
        default: str | None,
        help: str,
        *,
        update_global_hook: Callable[[str | None], None] | None = None,
        update_thread_local_hook: Callable[[str | None], None] | None = None,
        include_in_jit_key: bool = False,
    ) -> State[str | None]:
      """Set up thread-local state and return a contextmanager for managing it.
      See docstring for ``bool_state``.
      Args:
        name: string, converted to lowercase to define the name of the config
          option (and absl flag). It is converted to uppercase to define the
          corresponding shell environment variable.
        enum_values: list of strings representing the possible values for the
          option.
        default: optional string, default value.
        help: string, used to populate the flag help information as well as the
          docstring of the returned context manager.
      Returns:
        A contextmanager to control the thread-local state value.
      """
      if default is not None and not isinstance(default, str):
        raise TypeError(f"Default value must be of type str or None, got {default} "
                        f"of type {getattr(type(default), '__name__', type(default))}")
      name = name.lower()
      default = os.getenv(name.upper(), default)
      if default is not None and default not in enum_values:
>       raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
E       ValueError: Invalid value "TRACE" for JAX flag jax_logging_level
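The failing check can be reproduced without JAX. The sketch below mirrors the logic of optional_enum_state shown in the traceback: the upper-cased flag name is read as an environment variable that overrides `default`, and any value outside `enum_values` raises. Because `default` is reassigned from `os.getenv` before the check, the 'TRACE' shown in the traceback may have come either from JAX_LOGGING_LEVEL or from the caller's default. The function name check_enum_flag is hypothetical and is not part of JAX.

```python
import os
from typing import Optional, Sequence

# Accepted values, copied from enum_values in the traceback above.
JAX_LOGGING_LEVELS = ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]


def check_enum_flag(name: str, enum_values: Sequence[str],
                    default: Optional[str]) -> Optional[str]:
    """Hypothetical stand-in for the validation step in optional_enum_state.

    The environment variable NAME.upper() takes precedence over `default`,
    and the resulting value must be one of `enum_values` (or None).
    """
    name = name.lower()
    value = os.getenv(name.upper(), default)
    if value is not None and value not in enum_values:
        raise ValueError(f'Invalid value "{value}" for JAX flag {name}')
    return value


if __name__ == "__main__":
    # An unsupported value in the environment fails regardless of the default.
    os.environ["JAX_LOGGING_LEVEL"] = "TRACE"
    try:
        check_enum_flag("jax_logging_level", JAX_LOGGING_LEVELS, None)
    except ValueError as err:
        print(err)  # Invalid value "TRACE" for JAX flag jax_logging_level

    # An accepted value such as "NOTSET" passes the check.
    os.environ["JAX_LOGGING_LEVEL"] = "NOTSET"
    print(check_enum_flag("jax_logging_level", JAX_LOGGING_LEVELS, None))  # NOTSET
```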
Environment
torch-2.8.0.dev20250522+cpu
jax-0.6.1.dev20250424
jaxlib-0.6.1.dev20250424
Looking in indexes: https://pypi.org/simple, file:///home/ubuntu/pip, https://download.pytorch.org/whl/nightly/cpu, https://download.pytorch.org/whl/test/cpu
...
...
Collecting jax@ https://storage.googleapis.com/jax-releases/nightly/jax/jax-0.6.1.dev20250424-py3-none-any.whl (from torch-xla==2.8.*->-r /home/ubuntu/pip/requirements.txt (line 8))
Downloading https://storage.googleapis.com/jax-releases/nightly/jax/jax-0.6.1.dev20250424-py3-none-any.whl (2.5 MB)
Collecting jaxlib@ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.6.1.dev20250424-cp310-cp310-manylinux2014_x86_64.whl (from torch-xla==2.8.*->-r /home/ubuntu/pip/requirements.txt (line 8))
Downloading https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.6.1.dev20250424-cp310-cp310-manylinux2014_x86_64.whl (88.1 MB)
- Reproducible on XLA backend [CPU/TPU/CUDA]: TRN
- torch_xla version: 2.8
cc: @jeffhataws
FYI @qihqi
"TRACE" is not in the enum list. Has the list changed, or was the "TRACE" usage committed before the enum list change?
Hmm, would export JAX_LOGGING_LEVEL="NOTSET" help?
It does, but this will be a major blocker for us with 2.8. Is nobody else seeing this? It doesn't seem to be a Neuron-specific problem but rather one with the JAX version (see default = 'TRACE' in the stacktrace).
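The export workaround can also be applied from Python, for example at the top of a pytest conftest.py, provided it runs before anything imports jax (directly or via torchax), since the traceback shows the flag being validated while jax/_src/config.py is imported. This is a minimal sketch; the helper name pin_jax_logging_level is hypothetical. It keeps a valid user setting, and it also sets the variable when it is unset, on the assumption that the rejected value may come from the default rather than from the environment.

```python
import os
from typing import Optional

# Values accepted by jax_logging_level, per enum_values in the traceback above.
ACCEPTED_LEVELS = ("NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")


def pin_jax_logging_level(fallback: str = "NOTSET") -> Optional[str]:
    """Hypothetical helper: ensure JAX_LOGGING_LEVEL holds a value JAX accepts.

    Call it before anything imports jax (directly or via torchax), because
    JAX validates the flag while its config module is being imported.
    Returns the previous value (None if the variable was unset).
    """
    if fallback not in ACCEPTED_LEVELS:
        raise ValueError(f"fallback must be one of {ACCEPTED_LEVELS}, got {fallback!r}")
    previous = os.environ.get("JAX_LOGGING_LEVEL")
    if previous not in ACCEPTED_LEVELS:
        # Covers both an unset variable and a rejected value such as "TRACE";
        # same effect as export JAX_LOGGING_LEVEL="NOTSET" in those cases.
        os.environ["JAX_LOGGING_LEVEL"] = fallback
    return previous
```

Calling pin_jax_logging_level() as the first statement of conftest.py would typically apply the workaround to a whole test run, while a valid choice such as "WARNING" is left untouched.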
@qihqi, do you have any updates on this issue? It looks like we now also hit it when using xp.Trace.
We recently updated JAX to 0.6.2; does that help?
Closing this, as the JAX dependency has now been made optional.
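Making a dependency optional usually means importing it lazily and tolerating its absence. The following is a minimal sketch of that general pattern, not a description of how torch_xla implemented the change. The name maybe_import is hypothetical; the real helper in the traceback is maybe_get_torchax in torch_xla/_internal/jax_workarounds.py, whose body is not shown here.

```python
import importlib
from types import ModuleType
from typing import Optional


def maybe_import(module_name: str) -> Optional[ModuleType]:
    """Hypothetical helper: import `module_name` if it is installed, else return None.

    Only ImportError (including ModuleNotFoundError) is swallowed, so an
    installed but misconfigured package, such as the ValueError raised from
    jax's config above, still surfaces instead of being silently hidden.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None
```

For example, maybe_import("torchax") returns None on a machine without torchax (or without jax), so a caller can skip the JAX-backed path instead of failing at import time.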