No JAX dispatch for `mul_without_zeros`
Describe the issue:
The ProdWithoutZeros Op arises in the gradient of pt.prod. This currently cannot be compiled in JAX mode unless we specifically pass no_zeros_in_input=True. I guess we would just need a JAX dispatch for this Op? Or maybe a mapping to the correct jax.lax function?
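For reference, here is a minimal sketch of the workaround mentioned above: promising prod that there are no zeros in the input, so ProdWithoutZeros never enters the gradient graph (the input values are just placeholders):

import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
# with no_zeros_in_input=True the gradient uses an ordinary division by x,
# so the whole graph has JAX dispatches and compiles in JAX mode
z = pt.prod(x, no_zeros_in_input=True)
gz = pytensor.grad(z, x)
f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])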
Reproducible code example:
import pytensor
import pytensor.tensor as pt
x = pt.dvector('x')
z = pt.prod(x, no_zeros_in_input=False)
gz = pytensor.grad(z, x)
f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])
Error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
195 for thunk, node, old_storage in zip(
196 thunks, order, post_thunk_old_storage
197 ):
--> 198 thunk()
199 for old_s in old_storage:
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
[... skipping hidden 12 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
54 if to_reduce:
55 # In this case, we need to use the `jax.lax` function (if there
56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
Cell In[61], line 1
----> 1 f_z([1, 2, 3, 4])
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
967 t0_fn = time.perf_counter()
968 try:
969 outputs = (
--> 970 self.vm()
971 if output_subset is None
972 else self.vm(output_subset=output_subset)
973 )
974 except Exception:
975 restore_defaults()
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
200 old_s[0] = None
201 except Exception:
--> 202 raise_with_op(fgraph, node, thunk)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:531, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
526 warnings.warn(
527 f"{exc_type} error does not allow us to add an extra error message"
528 )
529 # Some exception need extra parameter in inputs. So forget the
530 # extra long error message in that case.
--> 531 raise exc_value.with_traceback(exc_trace)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
194 try:
195 for thunk, node, old_storage in zip(
196 thunks, order, post_thunk_old_storage
197 ):
--> 198 thunk()
199 for old_s in old_storage:
200 old_s[0] = None
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
654 def thunk(
655 fgraph=self.fgraph,
656 fgraph_jit=fgraph_jit,
657 thunk_inputs=thunk_inputs,
658 thunk_outputs=thunk_outputs,
659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
663 compute_map[o_var][0] = True
[... skipping hidden 12 frame]
File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
11 tensor_variable_4 = elemwise_fn_2(tensor_variable_3, x)
12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
15 tensor_variable_6 = dimshuffle_1(tensor_variable_5)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
52 to_reduce = sorted(axis, reverse=True)
54 if to_reduce:
55 # In this case, we need to use the `jax.lax` function (if there
56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
59 return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
Apply node that caused the error: Switch(Eq.0, True_div.0, Switch.0)
Toposort index: 13
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([1., 2., 3., 4.])]
Outputs clients: [['output']]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3548, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_27218/3109327815.py", line 5, in <module>
gz = pytensor.grad(z, x)
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 607, in grad
_rval: Sequence[Variable] = _populate_grad_dict(
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in _populate_grad_dict
rval = [access_grad_cache(elem) for elem in wrt]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in <listcomp>
rval = [access_grad_cache(elem) for elem in wrt]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1362, in access_grad_cache
term = access_term_cache(node)[idx]
File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1192, in access_term_cache
input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
PyTensor version information:
PyTensor 2.17.4
Context for the issue:
I want the gradient of a product in JAX mode
I took out the bug label since it's more like a missing feature.
Is ProdWithoutZeros just pt.prod(x[pt.neq(x, 0)])? In that case I would suggest dispatching to something like that in JAX. jax.numpy.prod even accepts a where argument already: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html
Although I am not sure it's gonna like that dynamic boolean array. Maybe it needs to be implemented as a Scan?
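As a rough sketch of the where-based idea: passing the mask through the where keyword keeps the shapes static, so jit should be fine with it, unlike the boolean indexing x[x != 0]:

import jax
import jax.numpy as jnp

def prod_without_zeros(x):
    # mask out zeros via `where`; the identity element for prod is 1
    return jnp.prod(x, where=(x != 0))

jax.jit(prod_without_zeros)(jnp.array([1.0, 0.0, 3.0, 4.0]))  # -> 12.0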
The graph JAX produces for grad of Prod is absurd? They just enumerate all the cases where a zero might be (with some bisect logic)?
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.prod(x)

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
Maybe rewrite to something like this?
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.exp(jnp.sum(jnp.log(x)))

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
Amusingly, this was suggested as a workaround here before the actual thing we have now was implemented.
That would fail with negative inputs, no?
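A quick illustration of why: log is nan for negative entries, so the sum-of-logs trick only recovers the product when all inputs are positive:

import jax.numpy as jnp

jnp.exp(jnp.sum(jnp.log(jnp.array([-2.0, 3.0]))))  # nan, whereas the product is -6.0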
Would prod(switch(eq(x, 0), 1, x)) work?
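A minimal JAX sketch of that idea, with jnp.where standing in for switch:

import jax
import jax.numpy as jnp

def prod_without_zeros(x):
    # replace zeros by 1 before taking the product
    return jnp.prod(jnp.where(x == 0, 1.0, x))

jax.grad(prod_without_zeros)(jnp.array([2.0, 0.0, 3.0]))  # -> [3., 0., 2.]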
Isn't this exactly the JAX graph, but without the bisect logic to avoid checking every single value in x elementwise?
I don't see anywhere where JAX is checking for zeros.
The PR that implemented the grad is here: https://github.com/google/jax/pull/675/files
It's a bit odd; is that really better than just a switch statement?
In any case, you are asking to rewrite a gradient, and unfortunately, because grads are eager, you don't know whether this ProdWithoutZeros came from the grad of a prod or from some other source. So to be safe you would need to implement a dispatch for that exact Op.
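For concreteness, a rough sketch of what such a dispatch might look like, following the jax_funcify.register pattern used elsewhere in pytensor.link.jax.dispatch. The import paths and the op.axis attribute are my assumptions, and dtype/acc_dtype handling is glossed over:

import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import ProdWithoutZeros


@jax_funcify.register(ProdWithoutZeros)
def jax_funcify_ProdWithoutZeros(op, **kwargs):
    axis = op.axis  # assumed: CAReduce ops expose their reduction axes here

    def prod_without_zeros(x):
        # zeros are replaced by 1, so only non-zero entries contribute
        return jnp.prod(jnp.where(jnp.equal(x, 0), 1, x), axis=axis)

    return prod_without_zeros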