Using custom distributions in `bmb.Prior` for group-specific terms

Open digicosmos86 opened this issue 1 year ago • 10 comments

Hi,

I know that Bambi supports arbitrary distributions as priors through the dist argument. However, when I tried to use a custom distribution as the prior for a group-specific term, an error was thrown at model creation time, because Bambi checks Prior.args for at least one hyperprior. For bmb.Prior objects created with the dist argument, though, the args property needs to be empty (as I remember, keeping a non-empty args results in errors). Any idea how I can get around the check?

digicosmos86 avatar Nov 28 '23 18:11 digicosmos86

Do you have an example I can follow? It's OK if you don't have a fully working example, but I would like to see how you're passing the priors when creating the model.

tomicapretto avatar Nov 28 '23 19:11 tomicapretto

Sure! Here's my understanding of the issue: suppose I have a regression rt ~ 0 + (1|subj_idx). For 1|subj_idx, I want to assign a custom bounded prior with hyperpriors:

import pymc as pm

def CustomPrior(name, mu, sigma, lower, upper):
    dist = pm.Normal.dist(mu=mu, sigma=sigma)  # assume mu and/or sigma here is a hyperprior
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper)

Then I create a bmb.Prior with something like prior=bmb.Prior(dist=CustomPrior) and use it as the prior of 1|subj_idx. My understanding is that this will fail because Bambi checks the prior of each group-specific term (in this case 1|subj_idx). It iterates over the values of args of the Prior object and checks whether any of them is itself a Prior. However, for a Prior created this way, there is nothing in args, so the check won't pass.

Here's the code that raises the error in Bambi: https://github.com/bambinos/bambi/blob/2d4b260a3f1fc7691a3f7623971679efc253005e/bambi/terms/group_specific.py#L89-L91
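The check described above can be sketched in plain Python without installing Bambi. This is only an illustration: SketchPrior and check_group_specific_prior are hypothetical stand-ins introduced here, not Bambi classes or functions, and the sketch assumes args is a plain dict of keyword arguments.

```python
class SketchPrior:
    """Hypothetical stand-in for bmb.Prior: a name plus keyword arguments."""

    def __init__(self, name, **kwargs):
        self.name = name
        self.args = dict(kwargs)


def check_group_specific_prior(prior):
    """Mirror the check described above: at least one arg must be a prior."""
    any_hyperprior = any(isinstance(x, SketchPrior) for x in prior.args.values())
    if not any_hyperprior:
        raise ValueError("Prior for group-specific terms must have hyperpriors")
    return prior


# Passes: sigma is itself a prior, so the term has a hyperprior.
with_hyperprior = SketchPrior("Normal", mu=0, sigma=SketchPrior("Exponential", lam=1))
check_group_specific_prior(with_hyperprior)

# Fails: args is empty, so there is nothing for the check to find.
without_args = SketchPrior("CustomPrior")
try:
    check_group_specific_prior(without_args)
except ValueError as err:
    print(f"ValueError: {err}")
```

Under these assumptions, the outcome depends only on what ends up in args, not on whether a custom dist callable was supplied.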

digicosmos86 avatar Nov 28 '23 20:11 digicosmos86

Is this what you're trying to do?

import bambi as bmb
import numpy as np
import pandas as pd
import pymc as pm

n = 20
g = 8
rng = np.random.default_rng(1234)
outcomes = rng.normal(rng.normal(size=g), scale=0.5, size=(n, g))

data = pd.DataFrame(dict(y = outcomes.T.flatten(), g = np.repeat(list("abcdefgh"), 20)))
data

def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)


priors = {
    "1|g": bmb.Prior(
        "CustomPrior", 
        mu=0, 
        sigma=bmb.Prior("Exponential", lam=1), 
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()

It's still problematic because it's not using the custom prior, so I think there's a bug with custom priors and group-specific effects.

tomicapretto avatar Nov 29 '23 13:11 tomicapretto

Yeah, the latter is what I was trying. There does seem to be a bug. Is there a way to make bmb.Prior more general so that it can accept bounds? I have an implementation here with some hacks, but I am sure there is a better way to do it in Bambi.

digicosmos86 avatar Nov 29 '23 14:11 digicosmos86

@digicosmos86 I found the problem is here:

https://github.com/bambinos/bambi/blob/2d4b260a3f1fc7691a3f7623971679efc253005e/bambi/backend/terms.py#L145-L148

When we use a noncentered parametrization, Bambi is assuming you have a normal prior. If you pass noncentered=False it will use your custom prior.
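To illustrate why a noncentered parametrization is tied to the normal case, here is a minimal sketch using only the Python standard library. It shows the generic textbook construction (a standard-normal offset that is shifted and scaled), not Bambi's actual code; the values of mu, sigma, lower and upper are made up for the example.

```python
import random
import statistics

rng = random.Random(1234)
mu, sigma = 1.5, 0.5  # illustrative values, not taken from the issue
n_draws = 50_000

# Centered: draw the group effect directly from Normal(mu, sigma).
centered = [rng.gauss(mu, sigma) for _ in range(n_draws)]

# Noncentered: draw a standard-normal offset, then shift and scale it.
offsets = [rng.gauss(0.0, 1.0) for _ in range(n_draws)]
noncentered = [mu + sigma * z for z in offsets]

# Both constructions target the same normal distribution.
print(statistics.fmean(centered), statistics.fmean(noncentered))
print(statistics.stdev(centered), statistics.stdev(noncentered))

# The same shift-and-scale trick does not respect truncation bounds:
# many of these draws fall outside [lower, upper].
lower, upper = 1.0, 2.0
outside = sum(1 for v in noncentered if not lower <= v <= upper)
print(f"{outside} of {n_draws} noncentered draws fall outside the bounds")
```

The shift-and-scale identity holds for the normal family, which is why it cannot simply be reused for a truncated or otherwise custom distribution.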

Edit: I'm not saying Bambi should allow other distributions for a noncentered parametrization, as it would be non-trivial to determine how to use a noncentered parametrization for an arbitrary distribution. But I do think there should at least be a warning when you pass a non-normal prior and noncentered=True. What do you think?
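A warning of the kind proposed here might look roughly like the sketch below. The helper warn_if_prior_ignored is hypothetical and does not exist in Bambi; it also assumes the prior can be identified by a name string, which is a simplification.

```python
import warnings


def warn_if_prior_ignored(prior_name, noncentered):
    """Hypothetical helper: warn when a non-normal prior meets noncentered=True."""
    if noncentered and prior_name != "Normal":
        warnings.warn(
            f"Prior '{prior_name}' is not Normal, so it is not used when "
            "noncentered=True. Pass noncentered=False to use it.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


# Only the first call combines a non-normal prior with noncentered=True.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_prior_ignored("CustomPrior", noncentered=True)
    warn_if_prior_ignored("Normal", noncentered=True)
    warn_if_prior_ignored("CustomPrior", noncentered=False)

for warning in caught:
    print(warning.message)
```

A warning rather than an error keeps existing models running while making it visible that the custom prior is being replaced.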

tomicapretto avatar Nov 29 '23 16:11 tomicapretto

@tomicapretto Thank you for looking into this! I think this is a slightly different issue. Sorry, I should have included the error message to make it clearer. Here's the Bambi code and the error message I get when I run a test:

    @prior.setter
    def prior(self, value):
        # This does not check which argument has hyperprior (must be dispersion?)
        assert isinstance(value, VALID_PRIORS), f"Prior must be one of {VALID_PRIORS}"
        if isinstance(value, Prior):
            any_hyperprior = any(isinstance(x, Prior) for x in value.args.values())
            if not any_hyperprior:
>               raise ValueError("Prior for group-specific terms must have hyperpriors")
E               ValueError: Prior for group-specific terms must have hyperpriors

I think this is what happens when the prior of a group-specific term is specified like this:

def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)

priors = {
    "1|g": bmb.Prior(
        "CustomPrior", 
        mu=0, 
        sigma=bmb.Prior("Exponential", lam=1), 
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()

prior.args is empty when dist is specified. So when prior.args is checked for hyperpriors, an error is thrown there.

digicosmos86 avatar Nov 29 '23 19:11 digicosmos86

I'm sorry but I don't see the difference between the final block of code you shared

def custom_prior(name, mu, sigma, lower, upper, dims=None):
    print("is this called?")
    dist = pm.Normal.dist(mu=mu, sigma=sigma)
    return pm.Truncated(name=name, dist=dist, lower=lower, upper=upper, dims=dims)

priors = {
    "1|g": bmb.Prior(
        "CustomPrior", 
        mu=0, 
        sigma=bmb.Prior("Exponential", lam=1), 
        lower=-2,
        upper=2,
        dist=custom_prior
    )
}
model = bmb.Model("y ~ 1 + (1|g)", data, priors=priors)
model.build()

and the code that I shared in my previous comment. Is there anything I'm missing, or did you paste the wrong thing?

tomicapretto avatar Nov 30 '23 01:11 tomicapretto

No, I am just using your code to show that a different error can be raised. Please see the first block of code for the error.

digicosmos86 avatar Dec 01 '23 15:12 digicosmos86

Hi @digicosmos86, I'm reviewing issues and realized I never replied here. I see that a different error could be raised (i.e. ValueError("Prior for group-specific terms must have hyperpriors")). But that doesn't happen with the block of code that I shared in https://github.com/bambinos/bambi/issues/763#issuecomment-1831897363, so I'm not sure why you mentioned that error. If I understand correctly, this error shouldn't occur as long as we pass the hyperprior to the custom distribution. Nevertheless, the implementation still isn't free of flaws because of the issue with the noncentered parametrization noted above.

tomicapretto avatar Jan 05 '24 13:01 tomicapretto

Thanks @tomicapretto! Do you have a moment for a quick call? Please feel free to pm me on slack :)

digicosmos86 avatar Jan 08 '24 14:01 digicosmos86