
`jax.numpy.tan` call adds `chlo.tan` call to exported StableHLO; `stablehlo.tan` should be preferred

Open • joaospinto opened this issue 1 year ago • 4 comments

Description

Is this intended behavior? I would have expected StableHLO ops to be preferred over CHLO ops whenever an equivalent StableHLO op exists, and stablehlo.tan does exist.

>>> import jax.numpy as np
>>> f = lambda x: np.tan(x)
>>> import jax
>>> jit_f = jax.jit(f)
>>> lowered = jit_f.lower(1.0)
>>> print(lowered.compiler_ir("stablehlo"))
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = chlo.tan %arg0 : tensor<f32> -> tensor<f32>
    return %0 : tensor<f32>
  }
}
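To check which dialect a given JAX install emits for `tan` without reading the whole module, a small script can scan the lowered text. This is a minimal sketch, not part of JAX: `tan_dialects` and `report_tan_lowering` are hypothetical helper names, and the regex assumes the printed op form shown above (`%0 = chlo.tan %arg0 ...`). It relies on `jax.jit(...).lower(...)` and `Lowered.as_text()`, which returns StableHLO text by default; which dialect shows up depends on the installed jax/jaxlib versions.

```python
import re

# Hypothetical helper pattern: matches the dialect prefix of a `tan` op in
# printed MLIR, e.g. `%0 = chlo.tan %arg0` or the generic form
# `%0 = "stablehlo.tan"(%arg0)`. The trailing \b excludes ops like `tanh`.
_TAN_OP = re.compile(r'=\s*"?(\w+)\.tan\b')


def tan_dialects(ir_text: str) -> set[str]:
    """Return the set of dialect prefixes (e.g. {'chlo'}) used for tan ops."""
    return set(_TAN_OP.findall(ir_text))


def report_tan_lowering() -> set[str]:
    """Lower jnp.tan for a scalar float32 input and report the tan dialect(s).

    Hypothetical convenience function; JAX is imported lazily so that
    tan_dialects() can be used on saved IR text without JAX installed.
    """
    import jax
    import jax.numpy as jnp

    # Lower (but do not compile) the jitted function for a Python float.
    lowered = jax.jit(jnp.tan).lower(1.0)
    # as_text() returns the StableHLO module as a string by default.
    ir_text = lowered.as_text()
    print(ir_text)
    return tan_dialects(ir_text)
```

Calling `report_tan_lowering()` prints the module and returns the dialect set; going by the transcripts in this thread, that would be `{'chlo'}` before the change discussed below and `{'stablehlo'}` with it.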

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.12.4 (main, Jun  6 2024, 18:26:44) [Clang 15.0.0 (clang-1500.3.9.4)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='LUSM493FC90MM', release='23.6.0', version='Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:30 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6030', machine='arm64')

joaospinto · Aug 26 '24 23:08

I suspect that stablehlo.tan was added after that code was written. Do you want to send a PR switching it to stablehlo? It should be a tiny change.

hawkinsp · Aug 26 '24 23:08

Sure, I'll give it a go.

joaospinto · Aug 26 '24 23:08

Created a pull request (feel free to review, @hawkinsp): https://github.com/google/jax/pull/23261/files

joaospinto · Aug 27 '24 00:08

FWIW, here is the same session on my stablehlo.tan branch with the change applied:

(.jax) ➜  jax git:(stablehlo.tan) python
Python 3.12.4 (main, Jun  6 2024, 18:26:44) [Clang 15.0.0 (clang-1500.3.9.4)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax.numpy as np
>>> f = lambda x: np.tan(x)
>>> import jax
>>> jit_f = jax.jit(f)
>>> lowered = jit_f.lower(1.0)
>>> print(lowered.compiler_ir("stablehlo"))
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.tan %arg0 : tensor<f32>
    return %0 : tensor<f32>
  }
}

joaospinto · Aug 27 '24 00:08