
Add support for random Generators in Numba backend

Open · ricardoV94 opened this pull request

Description

TODO:

  • [x] Merge code duplication into elemwise_codegen
  • [x] Handle broadcast via size argument
  • [x] Handle more than normal
  • [x] Add expand_dims in make_node. This turns out to be useful in other places too, such as vectorization and dim analysis. It also canonicalizes equivalent graphs, since the expand_dims is always implicit anyway
  • [x] Handle non-scalar inputs / outputs. This would also solve Blockwise for the Numba backend!
  • [x] Find a way to copy RNGs if inplace=False. Not a huge priority, but a nice-to-have for consistency
  • [x] Remove RandomState from codebase. I vote for no deprecation as this is functionality that was broken in Numba anyway (no way to reseed), and non-default for other backends.
  • [x] Split sneaky bugfix of RandomVariable vectorize into separate ~~commit~~ PR (#738)
  • [ ] Split dtype removal into separate PR
  • [ ] Split new RV implementations into separate commits
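The expand_dims item relies on NumPy's broadcasting rule: an operand with fewer dimensions is implicitly left-padded with length-1 axes, so a graph with an explicit expand_dims and one without compute the same thing. A minimal NumPy sketch of that equivalence (it illustrates the broadcasting rule only, not PyTensor's make_node):

```python
import numpy as np

a = np.arange(3.0)   # shape (3,)
b = np.ones((2, 3))  # shape (2, 3)

# Broadcasting implicitly treats `a` as if it had shape (1, 3)...
implicit = a + b
# ...which is exactly what an explicit expand_dims produces.
explicit = np.expand_dims(a, axis=0) + b

assert implicit.shape == explicit.shape == (2, 3)
assert np.array_equal(implicit, explicit)
```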

Related Issue

  • [x] Closes #316
  • [x] Closes #701

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

ricardoV94 · Apr 05 '24 20:04

Size and rng copy are now supported
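For context on what these two features mean at the NumPy level, a minimal sketch: `size` can add batch dimensions to the left of the parameters' broadcast shape, and copying a Generator duplicates its state so draws from the copy leave the original untouched. Using `copy.deepcopy` for the copy is an assumption of this sketch; the PR's Numba implementation may copy RNGs differently.

```python
import copy

import numpy as np

rng = np.random.default_rng(123)

# `size` may add batch dimensions to the left of the parameters' broadcast
# shape: loc has shape (3,), and the draws have shape (2, 3).
loc = np.array([0.0, 10.0, 100.0])
draws = rng.normal(loc=loc, scale=1.0, size=(2, 3))
assert draws.shape == (2, 3)

# copy.deepcopy duplicates the Generator together with its bit-generator
# state, so the copy and the original produce the same stream independently.
rng_copy = copy.deepcopy(rng)
from_copy = rng_copy.normal(size=4)
from_original = rng.normal(size=4)  # the original was not advanced by the copy
assert np.array_equal(from_copy, from_original)
```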

ricardoV94 · Apr 17 '24 09:04

@aseyboldt Elemwise is working again in case you want to benchmark before/after

ricardoV94 · Apr 23 '24 05:04

I've been peeling away some of the changes into separate PRs to make this more manageable to review. The good news is that everything works in at least some version of this PR :)

ricardoV94 · May 14 '24 10:05

Looks like this currently breaks pymc, I'm getting this import error:

File ~/git/pymc-dev/pymc/pytensorf.py:46
     44 from pytensor.tensor.random.op import RandomVariable
     45 from pytensor.tensor.random.type import RandomType
---> 46 from pytensor.tensor.random.var import (
     47     RandomGeneratorSharedVariable,
     48     RandomStateSharedVariable,
     49 )
     50 from pytensor.tensor.rewriting.shape import ShapeFeature
     51 from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable

ImportError: cannot import name 'RandomStateSharedVariable' from 'pytensor.tensor.random.var' (/home/adr/git/pytensor/pytensor/tensor/random/var.py)

aseyboldt · May 14 '24 10:05

I did some benchmarks with the changed Elemwise. The good news is that it is about 400x faster; the bad news is that it no longer seems to compute the right thing :-)

import pytensor.tensor as pt
import pytensor
import numpy as np

N = 1000
x = pt.vector(shape=(N,))
y = pt.vector(shape=(N,))
z = x[None, :] + 2*y[:, None]

rng = np.random.default_rng(42)
u = rng.normal(size=N)
v = rng.normal(size=N)
pt_func = pytensor.function([x, y], z, mode="NUMBA")
func = pt_func.vm.jit_fn

func(u, v)[0].shape
# (1, 1000)

Even though the graph wants to produce a (1000, 1000):

Add [id A] <Matrix(float64, shape=(1000, 1000))> 3
 ├─ ExpandDims{axis=0} [id B] <Matrix(float64, shape=(1, 1000))> 2
 │  └─ <Vector(float64, shape=(1000,))> [id C] <Vector(float64, shape=(1000,))>
 └─ Mul [id D] <Matrix(float64, shape=(1000, 1))> 1
    ├─ [[2.]] [id E] <Matrix(float64, shape=(1, 1))>
    └─ ExpandDims{axis=1} [id F] <Matrix(float64, shape=(1000, 1))> 0
       └─ <Vector(float64, shape=(1000,))> [id G] <Vector(float64, shape=(1000,))>

Also strange: There doesn't seem to be fusion?
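For comparison, a plain NumPy reference for the benchmark graph: the outer sum has shape (1000, 1000), and entry (i, j) is u[j] + 2 * v[i]. Checking the compiled function's output against a reference like this (for instance with `np.testing.assert_allclose`) would catch the shape problem above; the sketch only builds the reference and checks its shape and one entry.

```python
import numpy as np

N = 1000
rng = np.random.default_rng(42)
u = rng.normal(size=N)
v = rng.normal(size=N)

# NumPy reference for z = x[None, :] + 2 * y[:, None]
expected = u[None, :] + 2 * v[:, None]
assert expected.shape == (N, N)

# Entry (i, j) combines element j of x with element i of y
i, j = 3, 7
assert np.isclose(expected[i, j], u[j] + 2 * v[i])
```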

aseyboldt · May 14 '24 11:05

> Looks like this currently breaks pymc, I'm getting this import error:
>
> ImportError: cannot import name 'RandomStateSharedVariable' from 'pytensor.tensor.random.var'

The major label indicates breaking changes, and PyMC always pins a specific PyTensor version range, so it won't pick this up when a new version is released. Several other changes in this PR are also breaking, including removing the fake dtype input and using None instead of () for no size.

ricardoV94 · May 14 '24 11:05

> Also strange: There doesn't seem to be fusion?

There is no fusion when inputs broadcast together.

ricardoV94 · May 14 '24 11:05

Funny that this failing code wasn't picked up by any of our tests xD. Or maybe I never ran the whole Numba suite locally.

Edit: I like the speedups, maybe it's an okay trade-off?

ricardoV94 · May 14 '24 11:05

> There is no fusion when inputs broadcast together.

True, I didn't think that through...

> Edit: I like the speedups, maybe it's an okay trade-off?

Yeah, you gotta make some sacrifices if you want good performance. And that way we can easily do all our computations on the CPU, and don't have to buy those expensive nvidia thingies.

aseyboldt · May 14 '24 11:05

Funnily enough, RVs are broadcasting correctly:

import pytensor.tensor as pt
import pytensor

N = 1000
x = pt.random.normal(size=(1, N,))
y = pt.random.normal(size=(N, 1))
z = pt.random.normal(x, pt.exp(y))

pt_func = pytensor.function([], z, mode="NUMBA")
# pytensor.dprint(pt_func, print_type=True)
assert pt_func().shape == (1000, 1000)
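The same broadcasting can be checked against NumPy's Generator, which broadcasts `loc` and `scale` when `size` is None. A small NumPy analogue of the graph above (it mirrors the shapes, not PyTensor's implementation):

```python
import numpy as np

N = 1000
rng = np.random.default_rng(0)

# Mirror the graph above: both parameters are themselves random draws
x = rng.normal(size=(1, N))
y = rng.normal(size=(N, 1))

# loc of shape (1, N) and scale of shape (N, 1) broadcast to (N, N)
z = rng.normal(x, np.exp(y))
assert z.shape == (N, N)
```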

ricardoV94 · May 14 '24 11:05

@aseyboldt I think your Elemwise example is failing in main as well?

Edit: No it's not, must be some caching stuff that tricked me

ricardoV94 · May 14 '24 11:05

The commit "Refactor encoding helper" introduces the Elemwise bug.

ricardoV94 · May 14 '24 11:05

Ah, just something dumb: output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs]) iterates over node.inputs instead of node.outputs.
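The fix is presumably to read the patterns from node.outputs. To see why reading them from the inputs goes wrong for the benchmark graph, here is a pure-Python sketch of how an Elemwise output's broadcastable flags relate to its inputs' flags. `output_broadcastable` is a hypothetical helper for illustration; PyTensor already stores the result on each output's `type.broadcastable`. That the first input pattern matches the (1, 1000) shape reported above is consistent with, but not proof of, how the bug showed up.

```python
import numpy as np


def output_broadcastable(input_patterns):
    """Hypothetical helper: for same-ndim Elemwise inputs, an output axis is
    broadcastable (static length 1) only if it is broadcastable in every input."""
    return tuple(all(flags) for flags in zip(*input_patterns))


# Broadcastable flags of the two Add inputs in the graph above:
# ExpandDims{axis=0} has shape (1, 1000), Mul has shape (1000, 1).
input_patterns = [(True, False), (False, True)]

# Collecting flags from node.inputs yields one entry per *input*; the first
# entry describes a (1, 1000) array, not the (1000, 1000) output.
from_inputs = tuple(input_patterns)
assert from_inputs[0] == (True, False)

# The output's actual pattern: no axis is broadcastable.
assert output_broadcastable(input_patterns) == (False, False)
assert np.broadcast_shapes((1, 1000), (1000, 1)) == (1000, 1000)
```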

ricardoV94 · May 14 '24 11:05

> Ah just something dumb: output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])

Let me guess, AI autocomplete? :-)

aseyboldt · May 14 '24 11:05

>> Ah just something dumb: output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])
>
> Let me guess, AI autocomplete? :-)

Actually no, I think I changed that on purpose. Before, it was setting the output bc to False for the outputs, which didn't make sense either? I think it was vanilla copy-pasta human error when I tried to fix it haha

I pushed the fixes

ricardoV94 · May 14 '24 11:05

I'll need to rebase to get your new jit options in as well. Question: does the overload need to include them because they're part of the dispatch signature?

ricardoV94 · May 14 '24 11:05

Yes, overload calls njit on the returned function, and uses jit_options for that njit call.

aseyboldt · May 14 '24 11:05

I rebased; the no-Cython-wrapper change made the random test file quite a bit faster!

ricardoV94 · May 14 '24 12:05

> I rebased; the no-Cython-wrapper change made the random test file quite a bit faster!

Or maybe not. Anyway, it doesn't feel too bad.

ricardoV94 · May 14 '24 12:05

This is finally ready for review!

ricardoV94 · May 24 '24 10:05

For the Numba iteration I went with a wrapper OpFromGraph (OFG) that adds the core_shape as an explicit input. I felt this was the cleanest solution.
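Why an explicit core_shape helps: a vectorized loop over batch dimensions has to allocate an output of shape batch_shape + core_shape before iterating, and receiving core_shape as an input means the loop does not have to infer it per distribution. The NumPy sketch below shows that allocation pattern; `vectorized_dirichlet` is a hypothetical helper, not the PR's OpFromGraph wrapper or its Numba code.

```python
import numpy as np


def vectorized_dirichlet(rng, alphas, core_shape):
    """Hypothetical sketch: draw one Dirichlet sample per batch entry.

    core_shape is passed in explicitly so the full output can be allocated
    before the loop over batch indices starts.
    """
    batch_shape = alphas.shape[:-1]
    out = np.empty(batch_shape + tuple(core_shape))
    for idx in np.ndindex(*batch_shape):
        out[idx] = rng.dirichlet(alphas[idx])
    return out


rng = np.random.default_rng(1)
alphas = np.ones((4, 2, 3))  # batch shape (4, 2), core shape (3,)
draws = vectorized_dirichlet(rng, alphas, core_shape=(3,))
assert draws.shape == (4, 2, 3)
assert np.allclose(draws.sum(axis=-1), 1.0)  # each core draw lies on the simplex
```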

ricardoV94 · May 24 '24 11:05

Codecov Report

Attention: Patch coverage is 81.85596% with 131 lines in your changes missing coverage. Please review.

Project coverage is 80.68%. Comparing base (fc21336) to head (69111c8). Report is 206 commits behind head on main.

| File | Patch % | Lines |
|---|---|---|
| pytensor/link/numba/dispatch/random.py | 56.56% | 96 Missing :warning: |
| pytensor/link/numba/dispatch/vectorize_codegen.py | 90.26% | 11 Missing and 11 partials :warning: |
| pytensor/tensor/random/op.py | 90.32% | 3 Missing and 6 partials :warning: |
| pytensor/link/jax/dispatch/random.py | 86.66% | 2 Missing :warning: |
| pytensor/tensor/random/rewriting/jax.py | 77.77% | 0 Missing and 2 partials :warning: |

@@            Coverage Diff             @@
##             main     #691      +/-   ##
==========================================
- Coverage   80.85%   80.68%   -0.18%     
==========================================
  Files         162      163       +1     
  Lines       47043    47036       -7     
  Branches    11514    11530      +16     
==========================================
- Hits        38038    37951      -87     
- Misses       6750     6832      +82     
+ Partials     2255     2253       -2     

| File | Coverage Δ |
|---|---|
| pytensor/compile/builders.py | 77.50% <100.00%> (+0.05%) :arrow_up: |
| pytensor/compile/mode.py | 84.40% <ø> (ø) |
| pytensor/link/numba/dispatch/basic.py | 85.54% <100.00%> (-0.15%) :arrow_down: |
| pytensor/link/numba/dispatch/elemwise.py | 91.91% <100.00%> (+3.19%) :arrow_up: |
| pytensor/link/numba/dispatch/scan.py | 95.89% <100.00%> (ø) |
| pytensor/tensor/blockwise.py | 83.50% <100.00%> (-0.42%) :arrow_down: |
| pytensor/tensor/random/basic.py | 99.22% <100.00%> (-0.08%) :arrow_down: |
| pytensor/tensor/random/rewriting/basic.py | 92.25% <100.00%> (-1.34%) :arrow_down: |
| pytensor/tensor/random/rewriting/numba.py | 100.00% <100.00%> (ø) |
| pytensor/tensor/random/type.py | 85.45% <ø> (-5.18%) :arrow_down: |

... and 8 more

... and 4 files with indirect coverage changes

codecov[bot] · May 24 '24 13:05

> Add numba implementations for the random variables. (By using existing numba impls or adapting the numpy implementation and writing numba code)

Yup.

> Some minor changes to the numba rewrites: Adding rewrites for the inner graph of scans. (I thought we had a flag for ops that show if they have an inner graph and trigger them through that somehow? But I couldn't find that code)

We only trigger them at link time; this has always been the case (except for some rewrites that try to do stuff with existing Scans from the outside). Scan has an extra knob in that it allows users to specify its inner compilation Mode (no other Ops with inner graphs allow this; I guess OFG does indirectly, but it's not tested/documented). This is actually a source of trouble when we have a graph that may be compiled into distinct backends (which happens in a typical PyMC workflow with external samplers).

We have to rethink the Scan having a user-customizable mode. Anyway this PR simply made sure we include the Numba required rewrite for RVs when dispatching the Scan impl.

> Change random variables so that they specify a signature. Maybe this should be mentioned in the release notes because it affects for instance pymc implementations or custom distributions?

This is not a breaking change; it will issue a warning telling people/PyMC to use signatures instead. More important, imo, is to announce the actual Op node signature change: dtype is no longer part of it, and size=[] is no longer the same as size=None. Will keep in mind to be verbose in the release notes.
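For reference, NumPy's Generator already distinguishes the two cases: size=None lets the parameters determine the output shape, while size=() requests a scalar output and fails when the parameters are not scalar. A small sketch of that NumPy behaviour (it shows only the NumPy semantics, which the PyTensor change presumably mirrors, not PyTensor's own code):

```python
import numpy as np

rng = np.random.default_rng(0)
loc = np.zeros(3)

# size=None: the output shape comes from broadcasting the parameters
assert rng.normal(loc, 1.0, size=None).shape == (3,)

# size=(): an explicitly scalar output, incompatible with a length-3 loc,
# so NumPy raises a ValueError
try:
    rng.normal(loc, 1.0, size=())
    raised = False
except ValueError:
    raised = True
assert raised
```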

> Introduce a numba rewrite to set the core shape. I think a little more docstring would be nice there, because otherwise it is hard to tell what that is for.

Sure, will add

ricardoV94 · May 24 '24 15:05

Expanded the explanation of the core-shape numba rewrite

ricardoV94 · May 24 '24 15:05

@aseyboldt any suggestions for how to avoid suppressing all Numba warnings? :) The message match stopped working with the new store_core_outputs for some reason I can't figure out (I tried a bunch of strings/regexes without success).

It's the last commit

ricardoV94 · May 24 '24 15:05

Why not just message="Cannot cache compiled function"? filterwarnings accepts a regex for the message, so maybe something in the string is interpreted in a strange way?

Edit: On second thought, I think it would be better to have the function names in there...
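On the regex point: `warnings.filterwarnings` compiles `message` as a case-insensitive regular expression and matches it against the start of the warning text, so a prefix is enough and `re.escape` keeps any special characters literal. The sketch shows those semantics with a hypothetical `FakeNumbaWarning` class standing in for NumbaWarning (so it runs without numba) and a hypothetical message suffix; it does not explain why the filter failed in the Numba test suite.

```python
import re
import warnings


class FakeNumbaWarning(UserWarning):
    """Hypothetical stand-in for NumbaWarning so the sketch runs without numba."""


# Hypothetical full warning text: only its prefix matters for the filter.
full_message = 'Cannot cache compiled function "store_core_outputs" (details omitted)'

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record everything not explicitly ignored
    # `message` is a regex matched case-insensitively against the *start* of
    # the warning text; re.escape keeps any special characters literal.
    warnings.filterwarnings(
        "ignore",
        message=re.escape('Cannot cache compiled function "store_core_outputs"'),
        category=FakeNumbaWarning,
    )
    warnings.warn(full_message, FakeNumbaWarning)  # prefix matches: ignored
    warnings.warn("some other warning", FakeNumbaWarning)  # no match: recorded

assert [str(w.message) for w in caught] == ["some other warning"]
```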

aseyboldt · May 24 '24 16:05

I tried that, it did not work for some reason. Either I did something dumb or numba magic is somehow sidestepping the usual warning functionality.

If you want to try it, the relevant test is https://github.com/pymc-devs/pytensor/blob/1e96b8942bc42fd48af50e138c5b2458083c6091/tests/link/numba/test_basic.py#L1071-L1078

ricardoV94 · May 24 '24 16:05

Using this works for me:

import warnings

from numba.core.errors import NumbaWarning

# Ignore only the cache warnings for these two functions,
# instead of silencing every NumbaWarning
warnings.filterwarnings(
    "ignore",
    message='Cannot cache compiled function "numba_funcified_fgraph"',
    category=NumbaWarning,
)

warnings.filterwarnings(
    "ignore",
    message='Cannot cache compiled function "store_core_outputs"',
    category=NumbaWarning,
)

aseyboldt · May 24 '24 16:05

With that the test fails for me, maybe it's something odd in my testing environment? I'll push and see what happens on the CI

ricardoV94 · May 24 '24 16:05

Also fails in the CI: https://github.com/pymc-devs/pytensor/actions/runs/9226914530/job/25387761898?pr=691#step:6:2878

ricardoV94 · May 24 '24 16:05