
Use Numba Generators for random graphs and deprecate shared RandomState variables

Status: Open · ricardoV94 opened this issue 2 years ago · 15 comments

Description

RandomState objects are legacy in NumPy, and we can save some complexity by letting go of them in PyTensor.

We were "obliged" to keep them because, until now, that was the only kind of RNG that Numba supported.

ricardoV94 · May 23 '23 08:05

I'm trying to understand this issue to start a PR. What actually needs to be done? Numba functions can take numpy random generators without any hassle now. For example, this works:

```python
import numba as nb
import numpy as np

@nb.njit
def draw_nb(rng, loc, scale, size):
    return rng.normal(loc=loc, scale=scale, size=size)

rng = np.random.default_rng()
draw_nb(rng, 0.0, 1.0, (10,))
```

So would it be enough to just chop out all of the extra machinery related to random states in numba/dispatch/random.py, plus the check in make_numba_random_fn? Or is there something deeper going on?

Numba still doesn't support broadcasting from parameters, so all the rest of the machinery seems like it needs to stay (though I personally find it quite difficult to follow; it would be nice to refactor it to be clearer).
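For reference, "broadcasting from parameters" means the following in plain NumPy. Array-valued loc and scale are broadcast against each other, and, when size is given, against size as well. A Numba implementation has to reproduce this behaviour by hand. This is a NumPy-only illustration; it does not exercise Numba.

```python
import numpy as np

rng = np.random.default_rng(0)

loc = np.zeros(5)                  # shape (5,)
scale = np.array([[1.0], [2.0]])   # shape (2, 1)

# Without size, the output shape is the broadcast of the parameter shapes
draws = rng.normal(loc=loc, scale=scale)
print(draws.shape)  # (2, 5)

# With size, every parameter must broadcast to size,
# which then acts as the full output shape
draws_sized = rng.normal(loc=loc, scale=scale, size=(3, 2, 5))
print(draws_sized.shape)  # (3, 2, 5)
```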

jessegrabowski · Jan 07 '24 18:01

> Numba still doesn't support broadcasting from parameters, so all the rest of the machinery seems like it needs to stay (though I personally find it quite difficult to follow; it would be nice to refactor it to be clearer).

This is the big hurdle. Generators without broadcasting are pretty useless, so we have to support it if we want to really support RVs in the numba backend (and phase out RandomState).

ricardoV94 · Jan 07 '24 21:01

This is the PR where Generator support was added to Aesara. The main complexity is writing the broadcasting logic with Python strings, which is the usual numba backend PITA: https://github.com/aesara-devs/aesara/pull/1245
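To make the "broadcasting logic with Python strings" approach concrete, here is a minimal, NumPy-only sketch. It makes three assumptions: the RV has scalar core parameters (ndim_supp == 0), size is always given, and one draw is taken per output element. make_broadcast_sampler is a hypothetical name, not PyTensor's or Aesara's API. In a Numba backend, the generated source would additionally be compiled with numba.njit.

```python
import numpy as np


def make_broadcast_sampler(method_name, n_params):
    """Generate a broadcasting sampler from a source-code template.

    Hypothetical helper illustrating string templating: the number of
    parameters is baked into the generated source, so the loop body
    never needs *args.
    """
    params = [f"p{i}" for i in range(n_params)]
    args = ", ".join(params)
    bcast_lines = "\n".join(
        f"    {p}_b = np.broadcast_to({p}, size)" for p in params
    )
    call_args = ", ".join(f"{p}_b[idx]" for p in params)
    src = f"""
def sampler(rng, size, {args}):
{bcast_lines}
    out = np.empty(size, dtype=np.float64)
    for idx in np.ndindex(size):
        out[idx] = rng.{method_name}({call_args})
    return out
"""
    namespace = {"np": np}
    exec(src, namespace)  # defines `sampler` inside `namespace`
    return namespace["sampler"]


normal_sampler = make_broadcast_sampler("normal", 2)
samples = normal_sampler(np.random.default_rng(1), (10, 5, 5), np.zeros((5, 5)), 1.0)
reference = np.random.default_rng(1).normal(np.zeros((5, 5)), 1.0, size=(10, 5, 5))
print(np.allclose(samples, reference))  # True
```

The per-element loop draws in C order, the same order NumPy uses for its vectorized draw. That is why both arrays match for the same seed, as in the Numba comparison later in this thread.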

ricardoV94 · Jan 07 '24 21:01

But I'm saying I'm pretty sure you can directly plug a generator into what they already have? There's even a numba datatype for numpy random generators (nb.types.NumPyRandomGenerator)

jessegrabowski · Jan 07 '24 21:01

I think the only challenge is the broadcasting logic, which I think can't be written as a Python function without writing it for every Op?

I don't remember exactly where things broke.

Also the strict API for RVs requires copying the RNG if the Op is not inplace. Not sure if this is relevant.
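A sketch of what "copying the RNG if the Op is not inplace" means in practice. normal_perform is a hypothetical perform-style function, not PyTensor's implementation. It returns the (possibly new) RNG together with the draws. copy.deepcopy is used because it is guaranteed to give the copy its own bit-generator state.

```python
from copy import deepcopy

import numpy as np


def normal_perform(rng, loc, scale, size, inplace):
    """Hypothetical perform-style function for a normal RV.

    Returns (rng_used, draws), mirroring an RV whose outputs are the
    updated RNG and the draws.
    """
    if not inplace:
        # Draw from an independent copy so the caller's generator is untouched
        rng = deepcopy(rng)
    draws = rng.normal(loc, scale, size=size)
    return rng, draws


rng = np.random.default_rng(42)
state_before = rng.bit_generator.state

next_rng, _ = normal_perform(rng, 0.0, 1.0, (3,), inplace=False)
print(rng.bit_generator.state == state_before)  # True: input RNG not advanced
print(next_rng is rng)                          # False

next_rng, _ = normal_perform(rng, 0.0, 1.0, (3,), inplace=True)
print(rng.bit_generator.state == state_before)  # False: input RNG was consumed
```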

ricardoV94 · Jan 07 '24 21:01

What did you mean by "what they already have"?

ricardoV94 · Jan 07 '24 21:01

Feel free to open a PR if it seems like some minimal changes do the job (or even if they don't).

Unfortunately, I've lost enough context on this issue that I can't help just by thinking it through.

ricardoV94 · Jan 07 '24 21:01

The broadcasting is done with a loop, so it's actually not too bad. Here is a basic sketch:

```python
import numba as nb
import numpy as np

@nb.njit
def draw_nb(rng, loc, scale, size):
    loc_bcast = np.broadcast_to(loc, size)
    scale_bcast = np.broadcast_to(scale, size)

    bcast_samples = np.empty(size, dtype=np.float64)

    for idx in np.ndindex(size):
        bcast_samples[idx] = rng.normal(loc_bcast[idx], scale_bcast[idx])
    return bcast_samples

rng = np.random.default_rng(1)
loc = np.zeros((5, 5))
scale = 1.0
size = (10, 5, 5)
samples_np = rng.normal(loc=loc, scale=scale, size=size)

rng = np.random.default_rng(1)
samples_nb = draw_nb(rng, loc, scale, size)

np.allclose(samples_np, samples_nb)  # True
```

I guess it just seems like all the special packing/unpacking of the random state that is done in the numba linker can just go. But is that the only thing causing problems?

Sure, I'll open a PR. I'm also curious why this only causes problems with certain models (like those using scan) and not with others.

jessegrabowski · Jan 07 '24 21:01

Can you write that function in a way that works for any RV Op regardless of the number of parameters and ndim_supp/ndims_params?

I imagine that's the abstraction that complicates things.
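In plain Python, a single helper can cover any number of scalar parameters by taking *params plus a per-element draw function. This sketch assumes ndim_supp == 0, and broadcast_scalar_rv is a hypothetical name. It gives the reference semantics only. It does not show that the same variadic pattern compiles under Numba, which is where the difficulty discussed here lies.

```python
import numpy as np


def broadcast_scalar_rv(core_draw, rng, size, *params):
    """Draw from a scalar-core RV with NumPy-style parameter broadcasting.

    Hypothetical helper: `core_draw(rng, *scalars)` returns one draw, and
    any number of parameters is supported.
    """
    params = [np.asarray(p) for p in params]
    if size is None:
        # Output shape is the broadcast of the parameter shapes
        size = np.broadcast_shapes(*(p.shape for p in params))
    size = tuple(size)
    bcast = [np.broadcast_to(p, size) for p in params]
    out = np.empty(size, dtype=np.float64)
    for idx in np.ndindex(size):
        out[idx] = core_draw(rng, *(b[idx] for b in bcast))
    return out


# Normal (2 parameters) and exponential (1 parameter) share the same helper
normal = broadcast_scalar_rv(
    lambda rng, mu, sigma: rng.normal(mu, sigma),
    np.random.default_rng(0), None, np.zeros(3), np.ones((2, 1)),
)
expo = broadcast_scalar_rv(
    lambda rng, beta: rng.exponential(beta),
    np.random.default_rng(0), (4,), 2.0,
)
print(normal.shape, expo.shape)  # (2, 3) (4,)
```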

ricardoV94 · Jan 07 '24 21:01

Yes, the boxing and unboxing of RandomState is supposed to go away completely.

ricardoV94 · Jan 07 '24 21:01

No, you have to write overloads on a case-by-case basis. But that should be a hassle, not a blocker. Also it already exists.

The current code jumps through a lot of hoops to make a function that generates generalized boilerplate code. I guess it's elegant, but it makes it really hard to read. A bunch of individual functions would be easier to maintain IMO.

jessegrabowski · Jan 07 '24 21:01

It's also much more error-prone, though. For instance, this is what the full function for NormalRV might look like (functionality-wise):

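```python
from copy import deepcopy

import numpy as np


def normal_rv_dispatch(op, node):
    # op: the NormalRV Op; node: its Apply node (node.outputs[1] holds the draws)

    def normal(rng, size, dtype_code, loc, scale):
        if size is not None:
            loc = np.broadcast_to(loc, size)
            scale = np.broadcast_to(scale, size)
        else:
            loc, scale = np.broadcast_arrays(loc, scale)
            size = loc.shape

        # A non-inplace RV must not advance the caller's RNG
        if not op.inplace:
            rng = deepcopy(rng)

        out = np.empty(size, dtype=node.outputs[1].dtype)

        # Populate out with one draw per broadcast element
        for idx in np.ndindex(size):
            out[idx] = rng.normal(loc[idx], scale[idx])

        return rng, out

    return normal
```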

Can we use any helpers to avoid repeating all this boilerplate? Also it gets slightly more tricky with:

  1. multivariate RVs, where the output shape has an additional core shape, and
  2. non-scalar core parameters, where we broadcast the batch dims but not the core dims (there's a faulty Python/PyTensor implementation of this in the random module)
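A sketch of "broadcast the batch dims but not the core dims", using a multivariate normal as the example. broadcast_batch_dims is a hypothetical helper. Its second argument plays the role of ndims_params: the number of trailing core dims of each parameter. It is not PyTensor's random-module implementation.

```python
import numpy as np


def broadcast_batch_dims(params, ndims_params):
    """Broadcast only the leading (batch) dims of each parameter.

    Hypothetical helper: ndims_params[i] trailing dims of params[i] are
    treated as core dims and left untouched.
    """
    params = [np.asarray(p) for p in params]
    batch_shapes = [p.shape[: p.ndim - nd] for p, nd in zip(params, ndims_params)]
    batch_shape = np.broadcast_shapes(*batch_shapes)
    return batch_shape, [
        np.broadcast_to(p, batch_shape + p.shape[p.ndim - nd:])
        for p, nd in zip(params, ndims_params)
    ]


rng = np.random.default_rng(0)
mean = np.zeros((4, 1, 3))                              # batch (4, 1), core (3,)
cov = np.eye(3)[None] * np.arange(1, 3)[:, None, None]  # batch (2,), core (3, 3)

batch_shape, (mean_b, cov_b) = broadcast_batch_dims([mean, cov], [1, 2])
core_shape = mean.shape[-1:]  # MvNormal's support shape comes from the mean's core dim
out = np.empty(batch_shape + core_shape)
for idx in np.ndindex(batch_shape):
    out[idx] = rng.multivariate_normal(mean_b[idx], cov_b[idx])
print(out.shape)  # (4, 2, 3)
```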

ricardoV94 · Jan 07 '24 22:01

So I guess there are two issues being discussed here:

  1. Can we eliminate the RandomStream objects and just use RandomGenerators? (I think yes.)
  2. Do any changes need to be made to the numba overloads for RVs to make this happen? (I think no, but it might be nice to think about anyway.)

As I understand it, only issue (1) is causing problems. My motivation is to be able to use nutpie on any PyMC model. Currently I get an error saying that random generators are not supported for some models (I think scan-based ones? But I don't have an example off the top of my head).

Anyway I'll open a PR that dumps the random streams and see what it breaks when I have some time.

jessegrabowski · Jan 07 '24 22:01

I also bet numba will complain about the use of size or something and perhaps require we define two functions depending on whether size is provided.

And size probably has to be converted to a fixed-size tuple. These are small things that add up to more boilerplate. Then multiply that by all the RV Ops we have between PyTensor and PyMC.
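A plain-Python sketch of defining two functions depending on whether size is provided. The choice is made once, at dispatch time, and a given size (a tuple or an integer array) is normalized to a fixed-length tuple. make_normal_sampler and its size_len argument are hypothetical. The assumption that this keeps Numba typing simple is not tested here.

```python
import numpy as np


def make_normal_sampler(size_is_none, size_len=None):
    """Hypothetical dispatch-time specialization on whether `size` is given.

    Each returned closure only ever sees one kind of `size`
    (None, or something convertible to a tuple of length size_len).
    """
    if size_is_none:
        def sampler(rng, size, loc, scale):
            # Output shape comes from broadcasting the parameters
            loc, scale = np.broadcast_arrays(loc, scale)
            out = np.empty(loc.shape)
            for idx in np.ndindex(loc.shape):
                out[idx] = rng.normal(loc[idx], scale[idx])
            return out
    else:
        def sampler(rng, size, loc, scale):
            # Normalize e.g. an int array to a fixed-length tuple of ints
            size = tuple(int(size[i]) for i in range(size_len))
            loc = np.broadcast_to(loc, size)
            scale = np.broadcast_to(scale, size)
            out = np.empty(size)
            for idx in np.ndindex(size):
                out[idx] = rng.normal(loc[idx], scale[idx])
            return out
    return sampler


rng = np.random.default_rng(0)
no_size = make_normal_sampler(size_is_none=True)
with_size = make_normal_sampler(size_is_none=False, size_len=2)
print(no_size(rng, None, np.zeros(3), 1.0).shape)                # (3,)
print(with_size(rng, np.array([2, 3]), np.zeros(3), 1.0).shape)  # (2, 3)
```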

For me unreadable string manipulation doesn't seem too bad in comparison (but I can be persuaded).

Maybe we can make it more readable?

Regardless, hopefully numba will just do the work and implement the overloads properly so we don't have to -.-

ricardoV94 · Jan 07 '24 22:01

> Anyway I'll open a PR that dumps the random streams and see what it breaks when I have some time.

Thanks!

ricardoV94 · Jan 07 '24 22:01