Use `bayeux` to access a wide range of samplers

GStechschulte opened this issue 1 year ago · 17 comments

I have been following @ColCarroll's bayeux library and thought it would be interesting to see how Bambi could incorporate it to offer users a wider range of samplers (more than nuts_blackjax and nuts_numpyro).

Edit: the samplers are now accessed programmatically using the inference_method argument. This removes the code previously needed for nuts_blackjax and nuts_numpyro. If a user passes any MCMC inference method other than PyMC's own mcmc sampler, Bambi will use bayeux to call that sampler.

import bambi as bmb

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]

model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()

idata = model.fit(inference_method="blackjax_hmc")

However, when cleaning the InferenceData, I am getting an xarray error:

"name": "ValueError",
	"message": "('chain', 'draw') must be a permuted list of FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}), unless `...` is included"

It seems xarray does not like something bayeux is doing with the InferenceData.

Another thought is that using bayeux with Bambi is so easy

import bayeux as bx

model = bmb.Model(...)  # any Bambi model
model.build()

bx_model = bx.Model.from_pymc(model.backend.model)
bx_model.<some sampling func>  # e.g. bx_model.mcmc.blackjax_nuts(seed=...)

that maybe we should just add documentation explaining how to use bayeux with Bambi, to avoid the overhead on Bambi's side?

To Do:

  • [ ] add additional tests in test_alternative_samplers.py
  • [ ] update notebooks using JAX based samplers

GStechschulte avatar Feb 04 '24 14:02 GStechschulte

Really cool! A few suggestions --

  • bayeux could be invisible here, and you could access all the methods programmatically. That's done here, but I can factor that out into a function that gives methods instead of strings -- currently it only adds a method if the underlying library (e.g. optax) is installed. I'm not sure how to avoid using a string at some point. You could have an API like model.fit.bx that initializes the bayeux.Model? (See the sketch after this list.)
  • I'm happy to add a from_bambi constructor on the bayeux side to make your second option even easier.
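
For illustration, the programmatic listing could look roughly like this (a sketch, assuming the from_pymc constructor mentioned above and bayeux's methods property, which returns a dict of the available algorithms):

import bayeux as bx

# Build a bayeux model from the PyMC model underlying a built Bambi model.
bx_model = bx.Model.from_pymc(model.backend.model)
# Only algorithms whose backing library (e.g. optax, blackjax) is installed
# are listed here.
print(bx_model.methods)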

ColCarroll avatar Feb 04 '24 14:02 ColCarroll

What kind of API does bayeux have? Could we enable support for external samplers if we define a specific API that we support (need)? (Users could create a class for external samplers if needed?)

Of course, that does not mean we couldn't still have text-based support for certain libraries.

ahartikainen avatar Feb 04 '24 15:02 ahartikainen

bayeux is inspired by arviz, in that it just provides a representation of a model that is general enough for most samplers, but it does make the decision to specialize in JAX-based models (most of the algorithms use autodiff, vectorization is baked in, and automatic function inverses are also used). If you've got a sampler that accepts a JAX-based log density, you could use bayeux with it (or contribute it to bayeux!)
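
For reference, a standalone bayeux model is little more than a JAX-compatible log density plus a test point (a minimal sketch following the pattern in bayeux's documentation; treat the exact argument names as assumptions):

import jax
import bayeux as bx

# The model is defined by an (unnormalized) log density and a point
# where it can be evaluated.
normal_density = bx.Model(
    log_density=lambda x: -(x * x),
    test_point=1.0,
)

idata = normal_density.mcmc.blackjax_nuts(seed=jax.random.PRNGKey(0))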

ColCarroll avatar Feb 04 '24 15:02 ColCarroll

Really cool! A few suggestions --

  • bayeux could be invisible here, and you could access all the methods programmatically. That's done here, but I can factor that out into a function that gives methods instead of strings -- currently it only adds a method if the underlying library (e.g. optax) is installed. I'm not sure how to avoid using a string at some point. You could have an API like model.fit.bx that initializes the bayeux.Model?
  • I'm happy to add a from_bambi constructor on the bayeux side to make your second option even easier.

Thanks for the suggestions! That makes sense. I like the second option, but I will run some ideas past the others first before asking for the feature. Thanks!

GStechschulte avatar Feb 05 '24 14:02 GStechschulte


In this example, the error occurs because bayeux appends _0 to party_id_dim in the posterior dims. This results in Bambi discarding all posterior dims, because the dims in the PyMC model are inconsistent with the dims of the InferenceData returned by bayeux.

For example:

print(bayeux_idata.posterior.dims, pymc_idata.posterior.dims)
(FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}),
 FrozenMappingWarningOnValuesAccess({'chain': 4, 'draw': 1000, 'party_id_dim': 2, 'party_id:age_dim': 3}))

GStechschulte avatar Feb 09 '24 05:02 GStechschulte

In this example, the error occurs because bayeux appends _0 to party_id_dim in the posterior dims. This results in Bambi discarding all posterior dims, because the dims in the PyMC model are inconsistent with the dims of the InferenceData returned by bayeux.

For example:

print(bayeux_idata.posterior.dims, pymc_idata.posterior.dims)
(FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}),
FrozenMappingWarningOnValuesAccess({'chain': 4, 'draw': 1000, 'party_id_dim': 2, 'party_id:age_dim': 3}))

Update: I have added logic to the idata cleaning step to: (1) identify bayeux idata and remove the trailing numeric suffix from the _dims, and (2) rename the posterior dims to be consistent with the PyMC model coords. A sketch of the idea follows below.

Although this works for simple models, I haven't tried this logic with more complex Bambi models such as HSGP, or with models that have a large number of dims and/or factors. Since the idata contains very important data, I also think it could be worthwhile, for now, not to clean the idata when the user calls samplers from bayeux, to avoid unknown effects appearing in the inference data.
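
For concreteness, the renaming could look roughly like the following (a hypothetical sketch, not the actual implementation; strip_bayeux_suffix is an illustrative name):

import re
import xarray as xr

def strip_bayeux_suffix(posterior: xr.Dataset) -> xr.Dataset:
    """Rename dims like 'party_id_dim_0' back to 'party_id_dim' so they
    match the coords of the underlying PyMC model."""
    rename_map = {
        str(dim): re.sub(r"_\d+$", "", str(dim))
        for dim in posterior.dims
        if re.search(r"_dim_\d+$", str(dim))
    }
    return posterior.rename(rename_map)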

GStechschulte avatar Feb 10 '24 09:02 GStechschulte

Would it be possible to allow access to the optimization methods from bayeux as well via Bambi?


zwelitunyiswa avatar Feb 10 '24 16:02 zwelitunyiswa

I think you could -- @GStechschulte has a good outline here. If @tomicapretto thinks this is a reasonable idea in principle, I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit and allows optimization and VI.

ColCarroll avatar Feb 11 '24 02:02 ColCarroll

I think this is really cool, thanks @GStechschulte, and thanks @ColCarroll for bayeux. I'm not sure I'm aware of all the details, but what is the reason bayeux appends _dim_0 to dimension names? As far as I remember, that was an xarray thing. Or is it that bayeux does not receive dimension names from the PyMC model and thus appends _dim_0?

Another thing: I see we're replacing blackjax, jax, jaxlib, and numpyro with bayeux. However, as far as I know bayeux does not install these dependencies, so doing pip install bambi[jax] won't give users access to JAX-based samplers, right? (I'm not very familiar with bayeux, so I may be wrong.)

tomicapretto avatar Feb 18 '24 18:02 tomicapretto

bayeux will pull those in, but I agree it is better to be explicit and require the dependencies (in case bayeux makes weird decisions).

I'll double check on the naming conventions!

ColCarroll avatar Feb 18 '24 21:02 ColCarroll

Oh right, yes: bayeux has no concept of the dimensions from pymc. That would have to be implemented as a post-processing step to rename the arviz dimensions.

ColCarroll avatar Feb 19 '24 15:02 ColCarroll

Oh right, yes: bayeux has no concept of the dimensions from pymc. That would have to be implemented as a post-processing step to rename the arviz dimensions.

Thanks for the answer, it makes much more sense now!

tomicapretto avatar Feb 19 '24 16:02 tomicapretto

@ColCarroll thanks for the information.

@tomicapretto I can apply this post-processing step on Bambi's side.

GStechschulte avatar Feb 19 '24 17:02 GStechschulte

@ColCarroll thanks for the information.

@tomicapretto I can apply this post-processing step on Bambi's side.

Sounds great, just let me know if you need help or a second opinion :)

tomicapretto avatar Feb 19 '24 17:02 tomicapretto

Two updates:

  1. I added a processing step: when Bambi cleans the idata, it renames the idata dims and coordinates to match those of the underlying PyMC model.
  2. I explicitly added the JAX-based sampler dependencies.

Regarding

I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit and allows optimization and VI.

@ColCarroll I'd be happy to collaborate and see how you would do this 👍🏼

GStechschulte avatar Feb 19 '24 19:02 GStechschulte

@ColCarroll thanks a lot for the review! I will incorporate these in the coming days.

GStechschulte avatar Feb 20 '24 21:02 GStechschulte

Thanks for the reviews and suggestions @ColCarroll 😄 I have taken your suggestions and implemented them in slightly different ways.

  • Bambi has separate inference method calls for VI, MCMC, and Laplace approx. within the PyMCModel class. I think we should maintain this encapsulation rather than making all bayeux methods callable from a single method? This does result in duplicate code, though, e.g.
import operator

import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(self.model)
bx_method = operator.attrgetter(inference_method)(bx_model.<method>)

is written for MCMC, VI, Laplace approx., and optimize.

  • When passing bayeux methods to inference_method, I have kept the naming convention as tfp_hmc, blackjax_nuts, etc., since this is also how the dict values are returned by bx_model.methods. (A usage sketch follows below.)
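
To make the dispatch concrete, here is roughly how such a string resolves to a sampler call (a sketch assuming bayeux MCMC methods accept a JAX PRNG key as seed):

import operator

import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(model.backend.model)
# The user-facing string, e.g. "blackjax_nuts", is looked up on the MCMC
# namespace of the bayeux model...
sampler = operator.attrgetter("blackjax_nuts")(bx_model.mcmc)
# ...and called to produce an arviz InferenceData.
idata = sampler(seed=jax.random.PRNGKey(0))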

GStechschulte avatar Mar 01 '24 07:03 GStechschulte

@tomicapretto should we allow optimization methods? I think it is really cool, but model.predict will not work out of the box, as optimization methods return OptimizerResults rather than an idata object. I added it, but we can also delete it.

Update: In this PR I will leave optimization methods out. Then, in a separate PR, add optimization algorithms and get the OptimizerResults object to work with model.predict.

GStechschulte avatar Mar 01 '24 07:03 GStechschulte

It would be cool if it could. It would allow non-Bayesians an entry point into Bayesian software/methods with a simple change of sampler.

zwelitunyiswa avatar Mar 01 '24 10:03 zwelitunyiswa

It would be cool if it could. It would allow non-Bayesians an entry point into Bayesian software/methods with a simple change of sampler.

That's true. Maybe in this PR I will leave optimization methods out. Then, in a separate PR, add optimization algorithms and get the OptimizerResults object to work with model.predict. What do you think?

GStechschulte avatar Mar 01 '24 11:03 GStechschulte

Keeping the encapsulation sounds good! Re: optimization, I often run with a bunch of particles, which ends up looking like draws from a posterior. They definitely aren't! But they duck-type like they are, and many methods might still mostly work -- in particular, you can store the results in arviz. This is a little dangerous, because if a user gets an arviz.InferenceData, they might reasonably expect it to contain (possibly correlated) draws from a posterior distribution.

So I guess I'm recommending not to do this, and do what you're suggesting instead!

Just for interest's sake, this colab will probably go into the bayeux docs in a few days; it shows how to work with a Bayesian neural network. I run adam optimization with 128 particles. Most of the particles find different minima.

ColCarroll avatar Mar 01 '24 15:03 ColCarroll

@tomicapretto should we allow optimization methods? I think it is really cool, but model.predict will not work out of the box, as optimization methods return OptimizerResults rather than an idata object. I added it, but we can also delete it.

Update: In this PR I will leave optimization methods out. Then, in a separate PR, add optimization algorithms and get the OptimizerResults object to work with model.predict.

I agree it's good to have them but also good to leave them out for now.

I don't know all the details of the optimization methods. I guess we would get a single number for every parameter in the posterior, right? If that is the case, how do they differ from MAP? Is it possible to use those optimization results to get distributions for the parameters? I mention this because all the functions in the interpret submodule show both point estimates (i.e. a point or a line) and uncertainty (i.e. a band or a range). If we can't get any measure of uncertainty, the plots will look different and we'll need to account for that in the implementation.

tomicapretto avatar Mar 02 '24 19:03 tomicapretto

I don't know all the details of the optimization methods. I guess we would get a single number for every parameter in the posterior, right? If that is the case, how do they differ from MAP? Is it possible to use those optimization results to get distributions for the parameters? I mention this because all the functions in the interpret submodule show both point estimates (i.e. a point or a line) and uncertainty (i.e. a band or a range). If we can't get any measure of uncertainty, the plots will look different and we'll need to account for that in the implementation.

Yeah, it really depends on the parameters of the optimization method called. If you use a lot of particles, you get more than one parameter estimate, so it "looks" like a posterior. Nonetheless, this is definitely one aspect we need to consider.

GStechschulte avatar Mar 03 '24 07:03 GStechschulte

@GStechschulte I think this is a great addition and the PR is in great shape. Just want to know your opinion on my suggestion about handling, for now, the old argument values as well.

tomicapretto avatar Mar 04 '24 14:03 tomicapretto

@tomicapretto @GStechschulte will the samplers work with Bambi's mixture models, like the zero-inflated and hurdle models?

Z.

zwelitunyiswa avatar Mar 04 '24 17:03 zwelitunyiswa

LGTM in terms of working with bayeux. It looks like CI needs to be fixed, and @tomicapretto should give final approval on it being in the Bambi style.

Thanks for taking this on!

Thanks for the reviews, patience, and bayeux 😄

GStechschulte avatar Mar 04 '24 17:03 GStechschulte

@GStechschulte I think this is a great addition and the PR is in great shape. Just want to know your opinion on my suggestion about handling, for now, the old argument values as well.

Thanks a lot! I replied above and made the change 👍🏼

GStechschulte avatar Mar 04 '24 17:03 GStechschulte

@tomicapretto @GStechschulte will the samplers work with Bambi's mixture models, like the zero-inflated and hurdle models?

Z.

Yup! Though some samplers are better suited to certain problems than others 😄

GStechschulte avatar Mar 04 '24 17:03 GStechschulte

Ugh, pylint is making the CI fail. It says it cannot import bayeux. However, when I check the logs of the "Install Bambi and all its dependencies" step, I can see that bayeux was installed.

GStechschulte avatar Mar 04 '24 18:03 GStechschulte