
Torchax regressing basic TRN tests (ValueError: Invalid value "TRACE" for JAX flag jax_logging_level)

Open · rpsilva-aws opened this issue · 6 comments

🐛 Bug

Test:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh

# NUM_DEVICES and compare_npy_diff come from the surrounding test suite.
def test_sharded_matmul(tensor_a_shape, tensor_b_shape, mesh_shape, sharding_spec_a, sharding_spec_b):
    cpu_device = torch.device("cpu")
    neuron_device = xm.xla_device()

    device_ids = np.array(range(NUM_DEVICES))

    mesh = Mesh(device_ids, mesh_shape, ("tp1", "tp2"))
    tensor_a_cpu = torch.rand(tensor_a_shape, dtype=torch.float32, device=cpu_device)
    tensor_b_cpu = torch.rand(tensor_b_shape, dtype=torch.float32, device=cpu_device)

    tensor_a_neuron = tensor_a_cpu.to(neuron_device)
    tensor_b_neuron = tensor_b_cpu.to(neuron_device)
    # mark_sharding is what ends up importing torchax (and, transitively, jax).
    xs.mark_sharding(tensor_a_neuron, mesh, sharding_spec_a)
    xs.mark_sharding(tensor_b_neuron, mesh, sharding_spec_b)
    tensor_c_neuron = torch.matmul(tensor_a_neuron, tensor_b_neuron)
    xs.mark_sharding(tensor_c_neuron, mesh, (None, None))
    result_cpu = torch.matmul(tensor_a_cpu, tensor_b_cpu)

    assert compare_npy_diff(result_cpu, tensor_c_neuron.to(cpu_device))
    xm.mark_step()

Stacktrace:

aws_neuron_venv/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py:610: in mark_sharding
    tx = maybe_get_torchax()
aws_neuron_venv/lib/python3.10/site-packages/torch_xla/_internal/jax_workarounds.py:51: in maybe_get_torchax
    import torchax
aws_neuron_venv/lib/python3.10/site-packages/torchax/__init__.py:4: in <module>
    import jax
aws_neuron_venv/lib/python3.10/site-packages/jax/__init__.py:25: in <module>
    from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
aws_neuron_venv/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:20: in <module>
    from jax._src import config
aws_neuron_venv/lib/python3.10/site-packages/jax/_src/config.py:1721: in <module>
    optional_enum_state(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'jax_logging_level'
enum_values = ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
default = 'TRACE'
help = 'Set the corresponding logging level on all jax loggers. Only string values from ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] are accepted. If None, the logging level will not be set. Includes C++ logging.'

    def optional_enum_state(
        name: str,
        enum_values: Sequence[str],
        default: str | None,
        help: str,
        *,
        update_global_hook: Callable[[str | None], None] | None = None,
        update_thread_local_hook: Callable[[str | None], None] | None = None,
        include_in_jit_key: bool = False,
    ) -> State[str | None]:
      """Set up thread-local state and return a contextmanager for managing it.

      See docstring for ``bool_state``.

      Args:
        name: string, converted to lowercase to define the name of the config
          option (and absl flag). It is converted to uppercase to define the
          corresponding shell environment variable.
        enum_values: list of strings representing the possible values for the
          option.
        default: optional string, default value.
        help: string, used to populate the flag help information as well as the
          docstring of the returned context manager.

      Returns:
        A contextmanager to control the thread-local state value.
      """
      if default is not None and not isinstance(default, str):
        raise TypeError(f"Default value must be of type str or None, got {default} "
                        f"of type {getattr(type(default), '__name__', type(default))}")
      name = name.lower()
      default = os.getenv(name.upper(), default)
      if default is not None and default not in enum_values:
>       raise ValueError(f"Invalid value \"{default}\" for JAX flag {name}")
E       ValueError: Invalid value "TRACE" for JAX flag jax_logging_level
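
The failure occurs while jax itself is being imported: optional_enum_state reads the JAX_LOGGING_LEVEL environment variable (default = os.getenv(name.upper(), default)) and rejects any value outside the accepted enum. A minimal sketch that reproduces the same error without torch_xla in the loop, assuming the same jax nightly and that "TRACE" is coming in via the environment:

import os

# Any value outside ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
# is rejected while jax._src.config is imported.
os.environ["JAX_LOGGING_LEVEL"] = "TRACE"

import jax  # ValueError: Invalid value "TRACE" for JAX flag jax_logging_level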

Environment

  • torch-2.8.0.dev20250522+cpu
  • jax-0.6.1.dev20250424
  • jaxlib-0.6.1.dev20250424
  • Reproducible on XLA backend [CPU/TPU/CUDA]: TRN
  • torch_xla version: 2.8

Relevant pip output:

Looking in indexes: https://pypi.org/simple, file:///home/ubuntu/pip, https://download.pytorch.org/whl/nightly/cpu, https://download.pytorch.org/whl/test/cpu
...
Collecting jax@ https://storage.googleapis.com/jax-releases/nightly/jax/jax-0.6.1.dev20250424-py3-none-any.whl (from torch-xla==2.8.*->-r /home/ubuntu/pip/requirements.txt (line 8))
  Downloading https://storage.googleapis.com/jax-releases/nightly/jax/jax-0.6.1.dev20250424-py3-none-any.whl (2.5 MB)
Collecting jaxlib@ https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.6.1.dev20250424-cp310-cp310-manylinux2014_x86_64.whl (from torch-xla==2.8.*->-r /home/ubuntu/pip/requirements.txt (line 8))
  Downloading https://storage.googleapis.com/jax-releases/nightly/nocuda/jaxlib-0.6.1.dev20250424-cp310-cp310-manylinux2014_x86_64.whl (88.1 MB)

rpsilva-aws avatar May 23 '25 00:05 rpsilva-aws

cc: @jeffhataws

rpsilva-aws avatar May 23 '25 00:05 rpsilva-aws

FYI @qihqi

rpsilva-aws avatar May 23 '25 00:05 rpsilva-aws

"TRACE" is not in the enum list. Has the list changed? Or the "TRACE" usage is committed before the enum list change?

jeffhataws avatar May 29 '25 19:05 jeffhataws

Hmm, would export JAX_LOGGING_LEVEL="NOTSET" help?

qihqi avatar Jun 05 '25 23:06 qihqi
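
For reference, a minimal sketch of applying that workaround from Python rather than the shell; the key constraint is that the variable must be set before the first import of jax anywhere in the process (here that import happens lazily inside xs.mark_sharding, per the trace above):

import os

# Override any inherited value; must run before jax is first imported.
os.environ["JAX_LOGGING_LEVEL"] = "NOTSET"

import torch
import torch_xla.distributed.spmd as xs
# Subsequent xs.mark_sharding(...) calls can now import torchax/jax cleanly.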

It does, but this will be a major blocker for us with 2.8. Do we not see this in any other case? This doesn't seem like a Neuron-specific problem, but rather one with the JAX version (from the stack trace snippet: default = 'TRACE').

rpsilva-aws avatar Jun 06 '25 05:06 rpsilva-aws

@qihqi do you have any updates on this issue? It looks like we now also see this issue when we use xp.Trace.

jeffhataws avatar Jun 12 '25 21:06 jeffhataws

We recently updated JAX to 0.6.2; does it help there?

qihqi avatar Jul 09 '25 02:07 qihqi

Closing this, as the JAX dependency is now optional.

qihqi avatar Aug 05 '25 00:08 qihqi
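
For context on the fix: the trace shows mark_sharding reaching torch_xla._internal.jax_workarounds.maybe_get_torchax, which imported torchax (and, transitively, jax) unconditionally. Making the JAX dependency optional amounts to guarding that import; a rough sketch of the pattern, not the exact torch_xla code:

def maybe_get_torchax():
    # Hypothetical sketch: return the torchax module if it (and jax) can be
    # imported, otherwise None so callers fall back to non-torchax paths.
    try:
        import torchax  # transitively imports jax
        return torchax
    except ImportError:
        return None

Note that an ImportError guard alone would not have absorbed this particular failure, which surfaced as a ValueError raised during jax's own import.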