
Adding conditionals for torch

Ch0ronomato opened this issue 1 year ago · 3 comments

Description

Add the branching/control-flow ops to the PyTorch backend (a sketch for `Ifelse` follows the list):

  • [x] Ifelse
  • [x] ScalarLoop
  • [ ] Scan
  • [x] OpFromGraph (WIP in #956)
  • [ ] Blockwise
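
For orientation, here is a minimal sketch of what an `Ifelse` dispatch could look like. It illustrates the registration pattern only and is not necessarily the merged implementation; it assumes `IfElse` exposes the number of branch outputs as `n_outs` and receives inputs as `(condition, then-values..., else-values...)`:

```python
from pytensor.ifelse import IfElse
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify


@pytorch_funcify.register(IfElse)
def pytorch_funcify_IfElse(op, **kwargs):
    n_outs = op.n_outs

    def ifelse(cond, *args):
        # Inputs are (condition, then-branch values..., else-branch values...).
        res = args[:n_outs] if bool(cond) else args[n_outs:]
        return res[0] if n_outs == 1 else res

    return ifelse
```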

Ch0ronomato — Jul 17, 2024

Hey @ricardoV94, could I get some clarity on `ScalarLoop`? I was under the impression that it might just work, since I don't see any explicit tests for Numba or JAX. What work is needed for `ScalarLoop`? Here is an example test I wrote, which may itself be invalid:


```python
import numpy as np

from pytensor import function
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop

# `pytorch_mode` is assumed to come from the PyTorch backend test helpers,
# e.g. pytensor.compile.mode.get_mode("PYTORCH").


def test_scalar_loop():
    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")
    update = x0 + const

    op = ScalarLoop(init=[x0], constant=[const], update=[update])
    x = op(n_steps, x0, const)

    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
    np.testing.assert_allclose(fn(5, 0, 1), 5)
    np.testing.assert_allclose(fn(5, 0, 2), 10)
    np.testing.assert_allclose(fn(4, 3, -1), -1)
```
Running it fails in the generic `ScalarOp` dispatch, which falls back to `nfunc_spec` (an attribute `ScalarLoop` does not define), so a dedicated dispatch appears to be needed:

```
op = ScalarLoop(), node = ScalarLoop(n_steps, x0, const)
kwargs = {'input_storage': [[None], [None], [None]], 'output_storage': [[None]], 'storage_map': {ScalarLoop.0: [None], const: [None], x0: [None], n_steps: [None]}}
nfunc_spec = None

    @pytorch_funcify.register(ScalarOp)
    def pytorch_funcify_ScalarOp(op, node, **kwargs):
        """Return pytorch function that implements the same computation as the Scalar Op.

        This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
        even though it's dispatched on the Scalar Op.
        """

        nfunc_spec = getattr(op, "nfunc_spec", None)
        if nfunc_spec is None:
>           raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
E           NotImplementedError: Dispatch not implemented for Scalar Op ScalarLoop

pytensor/link/pytorch/dispatch/scalar.py:19: NotImplementedError
```

Ch0ronomato — Jul 25, 2024

You haven't seen JAX/Numba code because `ScalarLoop` isn't supported in those backends yet either.

I suggest checking the `perform` method to get an idea of how the Op works.

ricardoV94 — Jul 26, 2024
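
Following that suggestion, here is a minimal sketch of a dedicated dispatch. It assumes the update graph is exposed as `op.fgraph` with the carried state ordered before the constants, and it ignores the `until`/while variant; this is an illustration under those assumptions, not the merged implementation:

```python
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
    # Assumption: compile the inner update graph once. `op.fgraph` maps
    # (carried state..., constants...) to the updated state.
    inner_fn = pytorch_funcify(op.fgraph, **kwargs)
    n_update = len(op.fgraph.outputs)

    def scalar_loop(n_steps, *args):
        # The first `n_update` runtime args are the initial carry;
        # the remainder are the loop constants.
        carry, constants = list(args[:n_update]), args[n_update:]
        for _ in range(int(n_steps)):
            res = inner_fn(*carry, *constants)
            carry = list(res) if isinstance(res, (tuple, list)) else [res]
        return carry[0] if n_update == 1 else tuple(carry)

    return scalar_loop
```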

For `Blockwise` you should be able to apply `vmap` repeatedly, once per batch dimension. If PyTorch had an equivalent of `np.vectorize`, that would be all we need (see the sketch below).

ricardoV94 — Sep 01, 2024
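
A hedged sketch of that idea with `torch.vmap`; `vectorize_over_batch_dims` is a hypothetical helper name, and broadcasting of batch dimensions (which `Blockwise` supports) is not handled here:

```python
import torch


def vectorize_over_batch_dims(core_fn, n_batch_dims):
    # Emulate np.vectorize over the leading batch dimensions by
    # wrapping `core_fn` in torch.vmap once per batch dimension.
    fn = core_fn
    for _ in range(n_batch_dims):
        fn = torch.vmap(fn)
    return fn


# Example: a matrix-vector product with two leading batch dimensions.
batched_matvec = vectorize_over_batch_dims(torch.matmul, 2)
A = torch.randn(4, 3, 5, 6)
x = torch.randn(4, 3, 6)
out = batched_matvec(A, x)  # shape (4, 3, 5)
```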

I'm going to close this out for now. We have the larger list of ops tracked elsewhere, and I'm not actively working on the `Scan` op.

Ch0ronomato — Jan 06, 2025