Using custom distributions in `bmb.Prior` for group-specific terms
Hi,
I know that `bambi` supports arbitrary distributions as priors through the `dist` argument. However, when I use a custom distribution as the prior for a group-specific term, an error is thrown at model creation time, because `bambi` checks `Prior.args` for at least one hyperprior. For `bmb.Prior` objects created with the `dist` argument, though, the `args` property needs to be empty (I remember that keeping a non-empty `args` results in errors). Any idea how I can get around this check?
Do you have an example I can follow? It's OK if you don't have a fully working example, but I would like to see how you're passing the priors when creating the model.
Sure! Here's my understanding of the issue: suppose I have a regression `rt ~ 0 + (1|subj_idx)`. For `1|subj_idx`, I want to assign a custom bounded prior with hyperpriors:
import pymc as pm

def CustomPrior(name, mu, sigma, lower, upper):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)  # assume mu and/or sigma here is a hyperprior
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper)
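For reference, truncating a normal to `[lower, upper]` keeps its shape inside the bounds and renormalizes by the probability mass that falls there. The sketch below computes that density with only the Python standard library, as an illustration of the math behind wrapping `pm.Normal.dist` in `pm.Truncated`; it is not PyMC's implementation, and the helper names are hypothetical.

```python
import math


def normal_cdf(z: float) -> float:
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def truncated_normal_pdf(x, mu, sigma, lower, upper):
    """Density of Normal(mu, sigma) truncated to [lower, upper]."""
    if x < lower or x > upper:
        return 0.0
    z = (x - mu) / sigma
    phi = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)
    # Renormalize by the mass the untruncated normal puts on [lower, upper]
    mass = normal_cdf((upper - mu) / sigma) - normal_cdf((lower - mu) / sigma)
    return phi / (sigma * mass)


def integrate(f, a, b, n=10_000):
    """Midpoint-rule integral of f over [a, b]."""
    h = (b - a) / n
    return h * sum(f(a + (i + 0.5) * h) for i in range(n))


# Sanity check: the truncated density should integrate to ~1 over the bounds
area = integrate(lambda x: truncated_normal_pdf(x, 0.0, 1.5, -2.0, 2.0), -2.0, 2.0)
print(f"mass inside bounds: {area:.6f}")
```

Inside the bounds the truncated density is higher than the untruncated one, because the mass cut off in the tails is redistributed over `[lower, upper]`.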
Then I create a `bmb.Prior` with something like `prior=bmb.Prior(dist=CustomPrior)` and use it as the prior for `1|subj_idx`. My understanding is that this will fail because `bambi` checks the prior of each group-specific term (in this case `1|subj_idx`): it iterates over the values of the `Prior` object's `args` and checks whether any of them is itself a `Prior`. However, for a `Prior` created this way, `args` is empty, so this check won't pass.
Here's the code that raises the error in Bambi: https://github.com/bambinos/bambi/blob/2d4b260a3f1fc7691a3f7623971679efc253005e/bambi/terms/group_specific.py#L89-L91
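To make the check concrete, here is a standalone sketch that mimics it. `FakePrior` is a hypothetical stand-in, not Bambi's `Prior` class; it assumes keyword arguments are stored in `.args` next to an optional `dist` callable. Whether a given Bambi version keeps those keyword arguments in `args` when `dist` is passed is exactly what the thread goes back and forth on, so the sketch simply models both situations.

```python
class FakePrior:
    """Hypothetical stand-in for bmb.Prior: a name, an optional dist callable, and kwargs."""

    def __init__(self, name, dist=None, **kwargs):
        self.name = name
        self.dist = dist
        self.args = kwargs  # assumption: keyword arguments land in .args


def has_hyperprior(prior):
    """Mirror of the check: is any argument itself a prior?"""
    return any(isinstance(x, FakePrior) for x in prior.args.values())


def validate_group_specific(prior):
    if not has_hyperprior(prior):
        raise ValueError("Prior for group-specific terms must have hyperpriors")
    return prior


def passes_check(prior):
    try:
        validate_group_specific(prior)
    except ValueError:
        return False
    return True


def custom_prior(name, mu, sigma, lower, upper, dims=None):
    # Placeholder for the PyMC-based callable; never called in this sketch
    raise NotImplementedError


# Only a dist callable and no arguments: args is empty, so the check fails
bare = FakePrior("CustomPrior", dist=custom_prior)

# Arguments (including a hyperprior) passed alongside dist: the check passes
with_hyper = FakePrior(
    "CustomPrior",
    dist=custom_prior,
    mu=0,
    sigma=FakePrior("Exponential", lam=1),
    lower=-2,
    upper=2,
)

print(passes_check(bare))        # False
print(passes_check(with_hyper))  # True
```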
Is this what you're trying to do?
import bambi as bmb
import numpy as np
import pandas as pd
import pymc as pm

n = 20
g = 8
rng = np.random.default_rng(1234)
outcomes = rng.normal(rng.normal(size=g), scale=0.5, size=(n, g))
data = pd.DataFrame(dict(y=outcomes.T.flatten(), g=np.repeat(list("abcdefgh"), n)))
data

def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)

priors = {
    "1|g": bmb.Prior(
        "CustomPrior",
        mu=0,
        sigma=bmb.Prior("Exponential", lam=1),
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()
This is still problematic because the model doesn't actually use `custom_prior`, so I think there's a bug with custom priors and group-specific effects.
Yeah, the latter is what I was trying. There does seem to be a bug. Is there a way to make `bmb.Prior` more general so that it can accept bounds? I have an implementation here with some hacks, but I'm sure there is a better way to do it in Bambi.
@digicosmos86 I found the problem here:
https://github.com/bambinos/bambi/blob/2d4b260a3f1fc7691a3f7623971679efc253005e/bambi/backend/terms.py#L145-L148
When we use a noncentered parametrization, Bambi assumes you have a normal prior. If you pass `noncentered=False`, it will use your custom prior.
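A standard-library sketch (no PyMC or Bambi) of why the noncentered trick is tied to the normal: it writes each group effect as `mu + sigma * offset` with a standard normal `offset`, which matches drawing from Normal(mu, sigma) directly, but the shifted and scaled draws are unbounded, so they cannot reproduce a prior truncated to `[lower, upper]`. The bounds and scale below are illustrative assumptions taken from the example above; this is not Bambi's code at the linked lines.

```python
import random
import statistics


def centered_draws(rng, mu, sigma, n):
    # Draw group effects directly: effect ~ Normal(mu, sigma)
    return [rng.gauss(mu, sigma) for _ in range(n)]


def noncentered_draws(rng, mu, sigma, n):
    # Draw standardized offsets, then shift and scale: effect = mu + sigma * offset
    return [mu + sigma * rng.gauss(0.0, 1.0) for _ in range(n)]


rng = random.Random(1234)
mu, sigma, lower, upper = 0.0, 1.5, -2.0, 2.0
nc = noncentered_draws(rng, mu, sigma, 20_000)
c = centered_draws(rng, mu, sigma, 20_000)

# Both parametrizations describe the same Normal(mu, sigma) ...
print(statistics.fmean(nc), statistics.stdev(nc))
print(statistics.fmean(c), statistics.stdev(c))

# ... but mu + sigma * offset ignores the bounds a truncated prior would enforce
outside = sum(not (lower <= x <= upper) for x in nc) / len(nc)
print(f"share of noncentered draws outside [{lower}, {upper}]: {outside:.3f}")
```

Supporting a bounded prior in the noncentered form would require a different transform of the offsets, which is the non-trivial part mentioned in the next comment.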
Edit: I'm not saying Bambi should allow other distributions for a noncentered parametrization, as it would be non-trivial to determine how to apply a noncentered parametrization to an arbitrary distribution. But I do think there should at least be a warning when you pass a non-normal prior with `noncentered=True`. What do you think?
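The proposed warning could look roughly like the helper below. This is a hypothetical sketch: the function name, its arguments, and the rule "a plain `Normal` without a custom `dist`" are assumptions for illustration, not part of Bambi's API.

```python
import warnings


def check_noncentered_prior(prior_name, has_custom_dist, noncentered):
    """Hypothetical helper: warn when a noncentered parametrization would ignore the prior."""
    is_normal = prior_name == "Normal" and not has_custom_dist
    if noncentered and not is_normal:
        warnings.warn(
            f"Group-specific prior '{prior_name}' is not a plain Normal, but noncentered=True "
            "assumes one; pass noncentered=False to use the prior as given.",
            UserWarning,
            stacklevel=2,
        )
        return False
    return True


# A custom prior combined with noncentered=True triggers the warning
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ok = check_noncentered_prior("CustomPrior", has_custom_dist=True, noncentered=True)

print(ok, len(caught))  # False 1
```

Warning rather than raising keeps existing models working while making the silent fallback to a normal prior visible.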
@tomicapretto Thank you for looking into this! I think this is a slightly different issue. Sorry, I should have included the error message to make it clearer. Here's the Bambi code and the error message when I run a test:
    @prior.setter
    def prior(self, value):
        # This does not check which argument has hyperprior (must be dispersion?)
        assert isinstance(value, VALID_PRIORS), f"Prior must be one of {VALID_PRIORS}"
        if isinstance(value, Prior):
            any_hyperprior = any(isinstance(x, Prior) for x in value.args.values())
            if not any_hyperprior:
>               raise ValueError("Prior for group-specific terms must have hyperpriors")
E               ValueError: Prior for group-specific terms must have hyperpriors
I think what happens is that when the prior of a group-specific term is specified like this:
def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)

priors = {
    "1|g": bmb.Prior(
        "CustomPrior",
        mu=0,
        sigma=bmb.Prior("Exponential", lam=1),
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()
`prior.args` is empty when `dist` is specified, so when `prior.args` is checked for hyperpriors, an error is thrown there.
I'm sorry, but I don't see the difference between the final block of code you shared
def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)

priors = {
    "1|g": bmb.Prior(
        "CustomPrior",
        mu=0,
        sigma=bmb.Prior("Exponential", lam=1),
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()
and the code that I shared in my previous comment. Am I missing something, or did you paste the wrong thing?
No, I'm just using your code to show that it can trigger a different error. Please see the first block of code for the error.
Hi @digicosmos86, I'm reviewing issues and realized I never replied here. I see there could be a different error raised (i.e., `ValueError("Prior for group-specific terms must have hyperpriors")`). But that doesn't happen with the block of code that I shared in https://github.com/bambinos/bambi/issues/763#issuecomment-1831897363, so I'm not sure why you mentioned that error. If I understand correctly, this error shouldn't occur as long as we pass the hyperprior to the custom distribution. Nevertheless, the implementation still isn't free of flaws because of the comment about the noncentered parametrization.
Thanks @tomicapretto! Do you have a moment for a quick call? Please feel free to PM me on Slack :)