Use `bayeux` to access a wide range of samplers
I have been following @ColCarroll's `bayeux` library and thought it would be interesting to see how Bambi could incorporate it to offer users a wider range of samplers (more than `nuts_blackjax` and `nuts_numpyro`).
Edit: Now I access the samplers programmatically using the `inference_method` arg. This removes previously needed code for `nuts_blackjax` and `nuts_numpyro`. If a user passes an MCMC inference method other than the PyMC MCMC sampler `mcmc`, Bambi will use bayeux to call that sampler.
import bambi as bmb

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]
model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()
idata = model.fit(inference_method="blackjax_hmc")
However, when cleaning the InferenceData, I am getting an xarray error:

ValueError: ('chain', 'draw') must be a permuted list of FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}), unless `...` is included

It seems xarray does not like something that bayeux is doing with the InferenceData.
Another thought: using bayeux with Bambi is already so easy

model = bmb.Model(...)
model.build()
bx_model = bx.Model.from_pymc(model.backend.model)
bx_model.<some sampling func>

that maybe we just add documentation explaining how to use bayeux with Bambi and avoid the overhead on Bambi's side?
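To make that second option concrete, here is a minimal sketch of the workflow. The sampler name `blackjax_nuts` and the `seed` argument follow bayeux's documented MCMC interface, so treat them as assumptions to be checked against the bayeux docs rather than the only option:

import bambi as bmb
import bayeux as bx
import jax

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]
model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()

# Wrap the PyMC model that backs the Bambi model and call a bayeux sampler directly.
bx_model = bx.Model.from_pymc(model.backend.model)
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.PRNGKey(0))  # assumed bayeux MCMC entry point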
To Do:
- [ ] add additional tests in `test_alternative_samplers.py`
- [ ] update notebooks using JAX based samplers
Really cool! A few suggestions --
- `bayeux` could be invisible here, and you could access all the methods programmatically. That's done here, but I can factor that out into a function that gives methods instead of strings -- currently it only adds a method if the underlying library, e.g. `optax`, is installed. I'm not sure how to avoid using a string at some point. You could have an API like `model.fit.bx` that initializes the `bayeux.Model`?
- I'm happy to add a `from_bambi` constructor on the bayeux side to make your second option even easier.
What kind of API does bayeux have? Could we enable support for external samplers by defining a specific API that we support (need)? (Users could then create a class for external samplers if needed.)
Of course, that does not mean we couldn't also have text-based support for certain libraries.
`bayeux` is inspired by `arviz`, in that it just provides a representation of a model that is general enough for most samplers, but it does make the decision that it is specialized to JAX-based models (most of the algorithms use autodiff, vectorization is baked in, and automatic function inverses are also used). If you've got a sampler that accepts a JAX-based log density, you could use bayeux with it (or contribute it to bayeux!)
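As a rough illustration of that contract, here is a minimal sketch of handing bayeux a plain JAX log density. The `bx.Model(log_density=..., test_point=...)` constructor and the `mcmc` namespace are taken from bayeux's documentation, so verify them against the current release:

import bayeux as bx
import jax
import jax.numpy as jnp

# Any JAX-compatible log density (up to an additive constant) can be wrapped.
def log_density(x):
    return -0.5 * jnp.sum(jnp.square(x))  # standard normal in three dimensions

bx_model = bx.Model(log_density=log_density, test_point=jnp.zeros(3))
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.PRNGKey(0))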
Thanks for the suggestions! That makes sense. I am liking the second option, but I will run some ideas past the others first before asking for the feature. Thanks!
In this example, the error is because bayeux is appending `_0` to `party_id_dim` in the posterior dims. This results in Bambi discarding all posterior dims (see https://github.com/GStechschulte/bambi/blob/9f1d9d179071abbb4cc6255242132829aae80faf/bambi/backend/pymc.py#L261C4-L261C86), because the dims in the PyMC model are inconsistent with the dims of the InferenceData returned by bayeux.
For example:

print(bayeux_idata.posterior.dims, pymc_idata.posterior.dims)
(FrozenMappingWarningOnValuesAccess({'chain': 8, 'draw': 500, 'party_id_dim_0': 2, 'party_id:age_dim_0': 3}),
 FrozenMappingWarningOnValuesAccess({'chain': 4, 'draw': 1000, 'party_id_dim': 2, 'party_id:age_dim': 3}))
Update: I have added logic in the cleaning of idata to (1) identify bayeux idata and remove the trailing numeric suffix from the dims, and (2) rename the posterior dims to be consistent with the PyMC model coords.
Although this works for simple models, I haven't tried this logic with more complex models in Bambi such as HSGP, or with models that have a large number of dims and/or factors. Since the idata contains very important data, I also think it could be worthwhile, for the moment, to not clean the idata when the user calls samplers from bayeux, in order to avoid unknown effects appearing in the inference data.
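For reference, the renaming step amounts to something like the following sketch. Here `bayeux_idata` and `pm_model` are illustrative stand-ins for the InferenceData returned by bayeux and the underlying PyMC model; the actual logic in Bambi's backend may differ:

# Map bayeux's dim names back to the PyMC coords,
# e.g. 'party_id_dim_0' -> 'party_id_dim', then rename them in the posterior group.
rename_map = {
    dim: dim.rsplit("_", 1)[0]
    for dim in bayeux_idata.posterior.dims
    if dim.rsplit("_", 1)[0] in pm_model.coords
}
posterior = bayeux_idata.posterior.rename(rename_map)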
Would it be possible to allow access to the optimization methods from Bayeux as well via Bambi?
I think you could -- @GStechschulte has a good outline here. If @tomicapretto thinks this is a reasonable idea in principle, I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit and allows optimization and VI.
I think this is really cool, thanks @GStechschulte and thanks @ColCarroll for `bayeux`. I'm not sure I am aware of all the details, but what is the reason why `bayeux` is appending `_dim_0` to dimension names? As far as I remember that was an `xarray` thing. Or is it that `bayeux` is not receiving dimension names from the PyMC model and thus it appends `_dim_0`?
Another thing, I see we're replacing `blackjax`, `jax`, `jaxlib`, and `numpyro` with `bayeux`. However, as far as I know `bayeux` does not install these dependencies, so doing `pip install bambi[jax]` won't give users access to JAX-based samplers, right? (I'm not very familiar with `bayeux` so I may be wrong.)
bayeux will pull those in, but I agree it is better to be explicit and require the dependencies (in case bayeux makes weird decisions).
I'll double-check on the naming conventions!
Oh right, yes: `bayeux` has no concept of the dimensions from `pymc`. That would have to be implemented as a post-processing step to rename the `arviz` dimensions.
Thanks for the answer, it makes much more sense now!
@ColCarroll thanks for the information.
@tomicapretto I can apply this post-processing step on Bambi's side.
Sounds great, just let me know if you need help or a second opinion :)
Two updates:
- I added a processing step: when Bambi cleans the idata, it renames the idata dims and coordinates to match those of the underlying PyMC model.
- I explicitly added the JAX-based sampler dependencies.
Regarding:
> I'd be happy to either collaborate on this to (programmatically) get the bayeux inference methods in, or send a follow-up that generalizes it a bit, and allows optimization and VI.
@ColCarroll I'd be happy to collaborate and see how you would do this 👍🏼
@ColCarroll thanks a lot for the review! I will incorporate these in the coming days.
Thanks for the reviews and suggestions @ColCarroll 😄 I have taken your suggestions and implemented them in slightly different ways.
- Bambi has separate inference method calls for VI, MCMC, and Laplace approximation within the `PyMCModel` class. I think we should maintain this encapsulation and not be able to call all bayeux methods within a single method? This does result in duplicate code, though, e.g.

import operator

import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(self.model)
# <method> stands for the bayeux namespace being wrapped, e.g. mcmc, vi, or optimize
bx_method = operator.attrgetter(inference_method)(bx_model.<method>)

is written for MCMC, VI, Laplace approx., and optimize.
- When passing bayeux methods to `inference_method`, I have kept the naming convention as `tfp_hmc`, `blackjax_nuts`, etc., since this is also how the dict values are returned by `bx_model.methods`. (A sketch of this dispatch pattern follows below.)
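As a minimal sketch of that dispatch pattern: here `pymc_model` and `inference_method` are illustrative stand-ins for the compiled PyMC model and the user-supplied string, and the `mcmc` namespace plus the `seed` argument are assumptions about bayeux's interface rather than Bambi's actual code:

import operator

import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(pymc_model)

# bx_model.methods lists the available method names (e.g. "tfp_hmc", "blackjax_nuts"),
# which is why the same strings are accepted by `inference_method`.
print(bx_model.methods)

# Resolve the string to the corresponding sampler in the MCMC namespace and run it.
sampler = operator.attrgetter(inference_method)(bx_model.mcmc)
idata = sampler(seed=jax.random.PRNGKey(0))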
@tomicapretto should we allow optimization methods? I think it is really cool but `model.predict` will not work out of the box as optimization methods return `OptimizerResults` and not an idata object. I added it, but we can also delete it.
Update: In this PR I will leave optimization methods out. Then, in a separate PR, add optimization algorithms and get the `OptimizerResults` object to work with `model.predict`.
It would be cool if it could. It would allow non-Bayesians an entry point into Bayesian software/methods, with a simple change of sampler.
That's true. Maybe in this PR I will leave optimization methods out. Then, in a separate PR, add optimization algorithms and get the `OptimizerResults` object to work with `model.predict`. What do you think?
Keeping the encapsulation sounds good! Re: optimization, I often run with a bunch of particles, which ends up looking like draws from a posterior. They definitely aren't! But they ducktype like they are, and many methods might still mostly work -- in particular, you can store the results in `arviz`. This is a little dangerous because if a user gets an `arviz.InferenceData`, they might reasonably expect it contains (possibly correlated) draws from a posterior distribution.
So I guess I'm recommending not to do this, and do what you're suggesting instead!
Just for interest's sake, this colab will probably go into the bayeux docs in a few days, and shows how to work with a Bayesian neural network. I run Adam optimization with 128 particles. Most of the particles find different minima.
I agree it's good to have them but also good to leave them out for now.
I don't know all the details of the optimization methods. I guess we would get a single number for every parameter in the posterior, right? If that is the case, how do they differ from MAP? Is it possible to use those optimization results to get distributions for the parameters? I mention this because all the functions in the `interpret` submodule show both point estimates (i.e. a point or a line) plus uncertainty (i.e. a band or a range). If we can't get any measure of uncertainty, then the plots will look different and we'll need to account for that in the implementation.
Yeah, it really depends on the parameters of the optimization method called. If you use a lot of particles, then you would get more than one parameter estimate. So it "looks" like a posterior. Nonetheless, this is definitely one aspect we need to consider.
@GStechschulte I think this is a great addition and the PR is in great shape. I just want to know your opinion on my suggestion about handling, for now, the old argument values as well.
@tomicapretto @GStechschulte will the samplers work with Bambi's mixture models like the zero-inflated and hurdle models?
Z.
LGTM in terms of working with `bayeux`. It looks like CI needs to be fixed and @tomicapretto needs to give final approval for it being in the `bambi` style.
Thanks for taking this on!
Thanks for the reviews, patience, and bayeux 😄
> @GStechschulte I think this is a great addition and the PR is in great shape. I just want to know your opinion on my suggestion about handling, for now, the old argument values as well.
Thanks a lot! I replied above and made the change 👍🏼
> @tomicapretto @GStechschulte will the samplers work with Bambi's mixture models like the zero-inflated and hurdle models?
Yup! Though, some samplers are better suited for different problems, etc. 😄
Ugh. pylint is making the CI fail. It says it cannot import bayeux. However, when I check the logs of the step "Install Bambi and all its dependencies", I can see that bayeux was installed.