
Equinox models integration

Open juanitorduz opened this issue 1 year ago • 4 comments

It would be nice to have equinox_module and random_equinox_module functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, since Equinox seems to be under quite active development.

Would this be a good addition?

I could give it a shot in the upcoming months, but I will need some guidance :) Still, I am also happy if a more experienced dev wants to give it a go. XD.

juanitorduz, Dec 27 '23 08:12

Hi @juanitorduz, if you need this feature, please feel free to put it in contrib.module. I guess you can mimic random_flax_module for an implementation. If you need to clarify something, please leave a comment in this issue thread.

fehiepsi, Dec 28 '23 21:12

Great! Makes sense. Thank you @fehiepsi! I'll give it a try in the upcoming months!

juanitorduz, Dec 28 '23 21:12

I've been using the following in my package flowjax to register parameters for Equinox modules:


```python
from collections.abc import Callable

import equinox as eqx
import numpyro
from jaxtyping import PyTree


def register_params(
    name: str,
    model: PyTree,
    filter_spec: Callable | PyTree = eqx.is_inexact_array,
):
    """Register numpyro params for an arbitrary pytree.

    This partitions the parameters and static components, registers the parameters using
    numpyro.param, then recombines them. This should be called from within an inference
    context to have an effect, e.g. within a numpyro model or guide function.

    Args:
        name: Name for the parameter set.
        model: The pytree (e.g. an equinox module, flowjax distribution/bijection).
        filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a
            callable `leaf -> bool`, or a PyTree with prefix structure matching `model`
            with True/False values. Defaults to `eqx.is_inexact_array`.

    """
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    return eqx.combine(params, static)
```

It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training I just use eqx.combine(trained_params, model) to retrieve the trained module.

danielward27, Jan 26 '24 17:01

Thank you @danielward27! This will be a great entry point! (I am planning to tackle this sometime in February.)

juanitorduz, Jan 26 '24 17:01