jaxpr-viz

Issue with jax.grad

Open · SNMS95 opened this issue 1 year ago · 2 comments

Hey,

Thanks a lot for this nice package! I was trying to draw the graph for the gradient operation, but it does not seem to do anything:

```python
from jax import make_jaxpr
import jax
import jpviz

def func(x):
    return x**2

print(make_jaxpr(jax.grad(func))(1.0))  # This works and shows a series of ops
dot_graph = jpviz.draw(jax.grad(func), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph)  # This shows only a single box!
```

SNMS95 · Dec 19 '24 10:12

> Thanks a lot for this nice package!

Thanks, glad you are making use of it!

> I was trying to draw the graph for the gradient operation, but it does not seem to do anything

Thanks for the report and example, I'll take a look at this.

zombie-einstein · Dec 19 '24 12:12

I think the issue here is that the function is not being JIT-compiled, so

```python
import jax
import jpviz

def func(x):
    return x**2

# Wrap the gradient in jax.jit before passing it to jpviz.draw
dot_graph = jpviz.draw(jax.jit(jax.grad(func)), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph)
```

should work?
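One way to see what wrapping in `jax.jit` changes is to compare the two traced programs with `jax.make_jaxpr`. This is a sketch using only core JAX; it makes no claim about how jpviz itself walks the graph. It assumes a JAX version where `jax.jit` traces to a single call equation that stores the nested program under the `"jaxpr"` parameter (the primitive is printed as `pjit` in many releases and may be named differently in others).

```python
import jax

def func(x):
    return x**2

# Trace the gradient directly: the top-level jaxpr lists the individual primitives.
plain = jax.make_jaxpr(jax.grad(func))(1.0)

# Trace the jitted gradient: the top-level jaxpr holds a single call equation
# whose params carry the nested (closed) jaxpr of the gradient computation.
jitted = jax.make_jaxpr(jax.jit(jax.grad(func)))(1.0)

print("without jit:", [eqn.primitive.name for eqn in plain.jaxpr.eqns])
print("with jit:   ", [eqn.primitive.name for eqn in jitted.jaxpr.eqns])

# Unwrap the single jit equation to reach the gradient ops nested inside it.
inner = jitted.jaxpr.eqns[0].params["jaxpr"]
print("inside jit: ", [eqn.primitive.name for eqn in inner.jaxpr.eqns])
```

The gradient ops show up either as a flat list at the top level or nested one level down inside the jit call, which is why the two versions can produce differently shaped graphs.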

I've also just released an update that fixed some old node labelling issues.

zombie-einstein · Dec 20 '24 00:12