
No JAX dispatch for `mul_without_zeros`

Open · jessegrabowski opened this issue · 9 comments

Describe the issue:

The ProdWithoutZeros Op arises in the gradient of pt.prod. This currently cannot be compiled in JAX mode unless we specifically pass no_zeros_in_input=True. I guess we would just need a JAX dispatch for this Op? Or maybe a mapping to the correct jax.lax function?

Reproducible code example:

import pytensor
import pytensor.tensor as pt
x = pt.dvector('x')
z = pt.prod(x, no_zeros_in_input=False)
gz = pytensor.grad(z, x)

f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    195 for thunk, node, old_storage in zip(
    196     thunks, order, post_thunk_old_storage
    197 ):
--> 198     thunk()
    199     for old_s in old_storage:

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     54 if to_reduce:
     55     # In this case, we need to use the `jax.lax` function (if there
     56     # is one), and not the `jnp` version.
---> 57     jax_op = getattr(jax.lax, scalar_fn_name)
     58     init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[61], line 1
----> 1 f_z([1, 2, 3, 4])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    200             old_s[0] = None
    201 except Exception:
--> 202     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:531, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    526     warnings.warn(
    527         f"{exc_type} error does not allow us to add an extra error message"
    528     )
    529     # Some exception need extra parameter in inputs. So forget the
    530     # extra long error message in that case.
--> 531 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    194 try:
    195     for thunk, node, old_storage in zip(
    196         thunks, order, post_thunk_old_storage
    197     ):
--> 198         thunk()
    199         for old_s in old_storage:
    200             old_s[0] = None

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    663         compute_map[o_var][0] = True

    [... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     11 tensor_variable_4 = elemwise_fn_2(tensor_variable_3, x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
     15 tensor_variable_6 = dimshuffle_1(tensor_variable_5)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     52 to_reduce = sorted(axis, reverse=True)
     54 if to_reduce:
     55     # In this case, we need to use the `jax.lax` function (if there
     56     # is one), and not the `jnp` version.
---> 57     jax_op = getattr(jax.lax, scalar_fn_name)
     58     init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
     59     return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51   warnings.warn(message, DeprecationWarning, stacklevel=2)
     52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
Apply node that caused the error: Switch(Eq.0, True_div.0, Switch.0)
Toposort index: 13
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([1., 2., 3., 4.])]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3548, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_27218/3109327815.py", line 5, in <module>
    gz = pytensor.grad(z, x)
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 607, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1362, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1192, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyTensor version information:

Pytensor 2.17.4

Context for the issue:

I want to compute the gradient of a product in JAX mode.

jessegrabowski · Nov 30 '23

I took out the bug label since it's more like a missing feature.

Is ProdWithoutZeros just pt.prod(x[pt.neq(x, 0)])? In that case I would suggest dispatching to something like that in JAX. jax.numpy.prod even accepts a where argument already: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html

Although I am not sure it's gonna like that dynamic boolean array. Maybe it needs to be implemented as a Scan?
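
For illustration, here is a quick, untested sketch of that suggestion in plain JAX (the function name is made up for the example). Using the where argument together with an explicit initial value keeps the mask elemwise and the shapes static, so it should sidestep the dynamic-shape problem that boolean indexing like x[x != 0] would hit under jit:

import jax
import jax.numpy as jnp

# Sketch: multiply only the non-zero entries by masking them out of the
# reduction. `initial=1.0` is the identity of the product.
def prod_without_zeros(x):
    return jnp.prod(x, where=(x != 0), initial=1.0)

x = jnp.array([1.0, 0.0, 3.0, 4.0])
print(prod_without_zeros(x))           # 12.0
print(jax.jit(prod_without_zeros)(x))  # shapes stay static, so jit is fine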

ricardoV94 · Dec 01 '23

The graph JAX produces for grad of Prod is absurd? They just enumerate all the cases where a zero might be (with some bisect logic)?

import jax
import jax.numpy as jnp

def prod(x):
    return jnp.prod(x)

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))

ricardoV94 · Dec 01 '23

Maybe rewrite to something like this?

import jax
import jax.numpy as jnp

def prod(x):
    return jnp.exp(jnp.sum(jnp.log(x)))

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))

Amusingly, this was suggested as a workaround here before the actual thing we have now was implemented.

jessegrabowski · Dec 01 '23

That would fail with negative inputs, no?

ricardoV94 · Dec 16 '23

Would prod(switch(eq(x, 0), 1, x)) work?
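
A small, untested sketch of that idea in plain JAX (the name is just for illustration): replacing zeros with the product identity before reducing gives the same forward value as ProdWithoutZeros and keeps the gradient finite, unlike the exp(sum(log(x))) rewrite above.

import jax
import jax.numpy as jnp

# Sketch of the switch-based formulation: swap zeros for 1 (the product
# identity) and take an ordinary product.
def prod_without_zeros_switch(x):
    return jnp.prod(jnp.where(x == 0, 1.0, x))

x = jnp.array([1.0, 0.0, 3.0, 4.0])
print(prod_without_zeros_switch(x))            # 12.0
print(jax.grad(prod_without_zeros_switch)(x))  # finite everywhere, even at the zero entry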

ricardoV94 · Dec 16 '23

Isn't this exactly the jax graph but without the bisect logic to avoid checking every single value in x elemwise?

jessegrabowski · Dec 16 '23

I don't see anywhere where JAX is checking for zeros.

ricardoV94 · Dec 16 '23

The PR that implemented the grad is here: https://github.com/google/jax/pull/675/files

It's a bit odd; is that really better than just a switch statement?

ricardoV94 · Dec 16 '23

In any case you are asking to rewrite a gradient, and unfortunately, because grads are eager, you don't know whether this ProdWithoutZeros came from the grad of a prod or from something else. So to be safe you would need to implement a dispatch for that exact Op.
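
For reference, a rough sketch of what such an Op-level dispatch might look like, assuming pytensor's usual jax_funcify registration pattern and jnp.prod's where/initial arguments; dtype/acc_dtype handling is omitted and none of this has been tested:

import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import ProdWithoutZeros

# Hypothetical dispatch: reduce over the Op's axes while masking out zeros,
# with the product identity as the initial value.
@jax_funcify.register(ProdWithoutZeros)
def jax_funcify_ProdWithoutZeros(op, **kwargs):
    axis = op.axis

    def prod_without_zeros(x):
        return jnp.prod(x, axis=axis, where=(x != 0), initial=1.0)

    return prod_without_zeros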

ricardoV94 · Dec 16 '23