print literal types in jaxprs
```python
import jax
import jax.numpy as jnp
print(jax.make_jaxpr(lambda x: jnp.sin(jnp.sin(x)) + 2)(3))
```
Before:
```
{ lambda ; a:i32[]. let
    b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
    c:f32[] = sin b
    d:f32[] = sin c
    e:f32[] = add d 2.0
  in (e,) }
```
After:
```
{ lambda ; a:i32[]. let
    b:f32[] = convert_element_type[new_dtype=float32 weak_type=True] a
    c:f32[] = sin b
    d:f32[] = sin c
    e:f32[] = add d 2.0:f32[]
  in (e,) }
```
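The new `2.0:f32[]` suffix is the abstract value (dtype and shape) that the literal operand already carries inside the jaxpr. As a rough illustration, here is a minimal sketch that reads the same information programmatically. It assumes JAX's internal jaxpr structures (`ClosedJaxpr.jaxpr.eqns`, `eqn.invars`, `Literal.val`, `Literal.aval`), whose import location has moved between releases: newer versions expose `Literal` from `jax.extend.core`, older ones from `jax.core`. The helper `literal_types` is hypothetical and is not part of this PR.

```python
import jax
import jax.numpy as jnp

try:
    # Assumption: newer JAX releases expose Literal under jax.extend.core.
    from jax.extend.core import Literal
except ImportError:
    # Assumption: older releases exported it directly from jax.core.
    from jax.core import Literal


def literal_types(closed_jaxpr):
    """Hypothetical helper: list (primitive, value, dtype, shape) for each
    literal operand among the top-level equations of a ClosedJaxpr."""
    found = []
    for eqn in closed_jaxpr.jaxpr.eqns:
        for v in eqn.invars:
            # Literals are inlined constants; ordinary operands are Vars.
            if isinstance(v, Literal):
                found.append((eqn.primitive.name, v.val, v.aval.dtype, v.aval.shape))
    return found


closed = jax.make_jaxpr(lambda x: jnp.sin(jnp.sin(x)) + 2)(3)
for name, val, dtype, shape in literal_types(closed):
    print(name, val, dtype, shape)
```

This only inspects top-level equations. Literals inside nested sub-jaxprs (for example the bodies of `cond` or `pjit` equations) would need a recursive walk.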
Not sure about this one...