jaxpr-viz
Issue with jax.grad
Hey,
Thanks a lot for this nice package! I was trying to get the graph for the gradient operation, but it does not do anything:
from jax import make_jaxpr
import jax
import jpviz
def func(x):
    return x**2
print(make_jaxpr(jax.grad(func))(1.0)) # This works and shows a series of ops
dot_graph = jpviz.draw(jax.grad(func), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph) # This shows only a single box!
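To check that the traced gradient really contains several operations, so that the single box is a drawing problem rather than an empty trace, the primitives can be listed straight from the make_jaxpr output. This is a minimal sketch using only JAX itself (no jpviz); the exact primitive names and their count depend on the JAX version:

```python
import jax
from jax import make_jaxpr


def func(x):
    return x**2


# Trace the gradient function to a ClosedJaxpr without evaluating it.
closed = make_jaxpr(jax.grad(func))(1.0)

# Each equation in the top-level jaxpr is one primitive application.
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(names)  # several entries, e.g. including "mul" (exact list varies by JAX version)

# Sanity check: d/dx x**2 at x = 1.0 is 2.0.
print(jax.grad(func)(1.0))
```

If this prints more than one primitive while the drawing shows one box, the ops are there and only the visualization is collapsing them.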
> Thanks a lot for this nice package!
Thanks, glad you are making use of it!
> I was trying to get the graph for the gradient operation, but it does not do anything
Thanks for the report and example, I'll take a look at this.
I think the issue here is that the function isn't being jit-compiled, so wrapping it in jax.jit like this
import jax
import jpviz
def func(x):
    return x**2
dot_graph = jpviz.draw(jax.jit(jax.grad(func)), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph)
should work?
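As for why jit might help, one plausible reason (an assumption, not confirmed in this thread) is that make_jaxpr on a jitted function yields a single call equation that carries the whole gradient computation as a nested jaxpr, which jpviz presumably expands into separate nodes. The sketch below only inspects that structure with JAX itself; it assumes the nested body is stored under the equation's "jaxpr" parameter, and the call primitive's name ("pjit" or "jit") depends on the JAX version:

```python
import jax
from jax import make_jaxpr


def func(x):
    return x**2


# Tracing the jitted gradient yields one top-level call equation.
closed = make_jaxpr(jax.jit(jax.grad(func)))(1.0)
outer_eqns = closed.jaxpr.eqns
print([eqn.primitive.name for eqn in outer_eqns])  # one name, e.g. "pjit" (varies by JAX version)

# Assumption: the call's body is a ClosedJaxpr stored under params["jaxpr"].
inner = outer_eqns[0].params["jaxpr"]
inner_names = [eqn.primitive.name for eqn in inner.jaxpr.eqns]
print(inner_names)  # the gradient ops, one level down
```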
I've also just released an update that fixed some old node labelling issues.