
print literal types in jaxprs

Open mattjj opened this issue 9 months ago • 0 comments

import jax
import jax.numpy as jnp
print(jax.make_jaxpr(lambda x: jnp.sin(jnp.sin(x)) + 2)(3))

Before:

{ lambda ; a:i32[]. let
    b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
    c:f32[] = sin b
    d:f32[] = sin c
    e:f32[] = add d 2.0
  in (e,) }

After:

{ lambda ; a:i32[]. let
    b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
    c:f32[] = sin b
    d:f32[] = sin c
    e:f32[] = add d 2.0:f32[]
  in (e,) }
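
To check where the printed type annotation would come from, here is a minimal sketch that walks the equations of the same jaxpr and reports the dtype and shape stored on each literal operand. The helper name `literal_types` is hypothetical, and detecting literals by the presence of a `.val` attribute is an assumption about how literals differ from variables, not a documented API contract.

```python
import jax
import jax.numpy as jnp


def literal_types(closed_jaxpr):
    """Collect (primitive name, value, dtype, shape) for each literal operand."""
    found = []
    for eqn in closed_jaxpr.jaxpr.eqns:
        for v in eqn.invars:
            # Assumption: literal operands expose `.val`; variables do not.
            if hasattr(v, "val"):
                found.append(
                    (eqn.primitive.name, v.val, v.aval.dtype, v.aval.shape)
                )
    return found


# Same function and input as in the example above.
closed = jax.make_jaxpr(lambda x: jnp.sin(jnp.sin(x)) + 2)(3)
for prim, val, dtype, shape in literal_types(closed):
    print(f"{prim}: literal {val} has dtype {dtype} and shape {shape}")
```

The literal's abstract value already carries the dtype and shape, so printing it next to the value only surfaces information the jaxpr holds anyway.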

Not sure about this one...

mattjj · May 24 '24 04:05