`plot_predictions` breaks with HSGP

Open AlexAndorra opened this issue 1 year ago • 16 comments

The attached dataset will help reproduce the issue, but basically:

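import bambi as bmb
import numpy as np
import pandas as pd

data = pd.read_csv("data.csv")  # the dataset attached below

formula = "p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True) + hsgp(Fall_Year, by=Major, m=10, c=2, scale=True, centered=True)"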

m_bambi = bmb.Model(
    formula=formula,
    data=data,
    family="binomial",
    categorical=["Age_Group", "Major"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()

## this line fails
bmb.interpret.plot_predictions(
    m_bambi, 
    idata_bambi,
    conditional={"Fall_Year": np.linspace(2012, 2025)},
);

yields (full traceback at end of post):

IndexError: index -1 is out of bounds for axis 1 with size 0

If you do formula = "p(Campus_Living, N) ~ 1 + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)", the error becomes KeyError: 'Age_Group'.

Interestingly, if you change formula to "p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)" or "p(Campus_Living, N) ~ 1 + (1 | Age_Group) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)", the plot will display (although I'm not sure it's right)!

It looks like some indexing is happening in by_values in bambi/model_components.py:262, which yields the error.

Thanks a lot for your help, and LMK if anything is unclear 🙏

data.csv

Full error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[225], line 1
----> 1 bmb.interpret.plot_predictions(
      2     m_bambi, 
      3     idata_bambi,
      4     conditional={"Fall_Year": np.linspace(2012, 2025)},
      5 );

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:181, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    174 if average_by is True:
    175     raise ValueError(
    176         "Plotting when 'average_by = True' is not possible as 'True' marginalizes "
    177         "over all covariates resulting in a single prediction estimate. "
    178         "Please pass a covariate(s) to 'average_by'."
    179     )
--> 181 cap_data = predictions(
    182     model=model,
    183     idata=idata,
    184     conditional=conditional,
    185     average_by=average_by,
    186     target=target,
    187     pps=pps,
    188     use_hdi=use_hdi,
    189     prob=prob,
    190     transforms=transforms,
    191     sample_new_groups=sample_new_groups,
    192 )
    194 conditional_info = ConditionalInfo(model, conditional)
    195 transforms = transforms if transforms is not None else {}

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/effects.py:532, in predictions(model, idata, conditional, average_by, target, pps, use_hdi, prob, transforms, sample_new_groups)
    530     y_hat_mean = y_hat.mean(("chain", "draw"))
    531 else:
--> 532     idata = model.predict(
    533         idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
    534     )
    535     y_hat = response_transform(idata["posterior"][response.name_target])
    536     y_hat_mean = y_hat.mean(("chain", "draw"))

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/models.py:832, in Model.predict(self, idata, kind, data, inplace, include_group_specific, sample_new_groups)
    829     else:
    830         var_name = f"{response_aliased_name}_{name}"
--> 832 means_dict[var_name] = component.predict(
    833     idata, data, include_group_specific, hsgp_dict, sample_new_groups
    834 )
    836 # Drop var/dim if already present. Needed for out-of-sample predictions.
    837 if var_name in idata.posterior.data_vars:

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:185, in DistributionalComponent.predict(self, idata, data, include_group_specific, hsgp_dict, sample_new_groups)
    182     linear_predictor_dims = linear_predictor_dims + (response_levels_dim,)
    184 if self.design.common:
--> 185     linear_predictor += self.predict_common(
    186         posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
    187     )
    189 if self.design.group and include_group_specific:
    190     linear_predictor += self.predict_group_specific(
    191         posterior, data, in_sample, to_stack_dims, design_matrix_dims, sample_new_groups
    192     )

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:262, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
    256 # NOTE:
    257 # The approach here differs from the one in the PyMC implementation.
    258 # Here we have a single dot product with many zeros, while there we have many
    259 # smaller dot products.
    260 # It is subject to change here, but I don't want to mess up dims and coords.
    261 if term.by_levels is not None:
--> 262     by_values = x_slice[:, -1].astype(int)
    263     x_slice = x_slice[:, :-1]
    264     x_slice_centered = (x_slice.data - term.mean[by_values]) / maximum_distance

IndexError: index -1 is out of bounds for axis 1 with size 0

AlexAndorra avatar Feb 05 '24 22:02 AlexAndorra

Thanks for raising the issue and the detailed information!

GStechschulte avatar Feb 06 '24 06:02 GStechschulte

Note that this error also appears when sampling posterior predictions -- m_bambi.predict(idata_bambi, kind="pps")

AlexAndorra avatar Feb 06 '24 19:02 AlexAndorra

I've found where the bug is!

https://github.com/bambinos/bambi/blob/b5b9f093c623636ae3a8e2f0765b2f01ca26f4b0/bambi/model_components.py#L242-L246

In line 246 we're deleting columns of the design matrix based on the slice term_slice. This only happens for HSGP. The problem is that the slice for the second (and third, fourth, etc.) HSGP term does not reflect the shape of X at the time it's going to be used to slice X.

I think the fix is to delete the HSGP term data from the design matrix X only after having iterated over all the HSGP terms. I'll test it soon.
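
To make the failure mode concrete, here is a minimal NumPy sketch of the pattern described above. It is not Bambi's actual code: the column layout, `hsgp_slices`, and the helpers `split_term`, `delete_inside_loop`, and `delete_after_loop` are all made up for illustration. It only shows how slices computed against the full design matrix go stale once columns are deleted inside the loop, and how deferring the deletion avoids that.

```python
import numpy as np

# Hypothetical design matrix: column 0 is the intercept, columns 1-2 hold the
# first HSGP term (covariate, `by` index) and columns 3-4 hold the second one.
X = np.array(
    [
        [1.0, 2012.0, 0.0, 2012.0, 1.0],
        [1.0, 2013.0, 1.0, 2013.0, 0.0],
        [1.0, 2014.0, 0.0, 2014.0, 1.0],
    ]
)
# Slices computed once, against the full matrix.
hsgp_slices = [slice(1, 3), slice(3, 5)]


def split_term(x_slice):
    """Separate the covariate column(s) from the trailing `by` index column."""
    by_values = x_slice[:, -1].astype(int)
    return x_slice[:, :-1], by_values


def delete_inside_loop(X, slices):
    """Buggy pattern: X shrinks while the precomputed slices stay the same."""
    terms = []
    for s in slices:
        terms.append(split_term(X[:, s]))
        X = np.delete(X, s, axis=1)  # later slices now point past the end of X
    return terms, X


def delete_after_loop(X, slices):
    """Fixed pattern: read every term first, then drop all HSGP columns at once."""
    terms = [split_term(X[:, s]) for s in slices]
    columns = np.concatenate([np.arange(s.start, s.stop) for s in slices])
    return terms, np.delete(X, columns, axis=1)


try:
    delete_inside_loop(X, hsgp_slices)
except IndexError as err:
    print("buggy version:", err)

terms, X_rest = delete_after_loop(X, hsgp_slices)
print("fixed version leaves", X_rest.shape[1], "non-HSGP column(s)")
```

In the buggy version the second slice selects zero columns from the already-shrunken matrix, so indexing its last column raises an IndexError of the same kind as in the report.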

tomicapretto avatar Feb 06 '24 23:02 tomicapretto

@AlexAndorra could you install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction and try it again?

tomicapretto avatar Feb 06 '24 23:02 tomicapretto

In line 246 we're deleting columns of the design matrix based on the slice term_slice. This only happens for HSGP. The problem is that the slice for the second (and third, fourth, etc.) HSGP term does not reflect the shape of X at the time it's going to be used to slice X.

Makes sense indeed.

@AlexAndorra could you install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction and try it again?

Will do ASAP and of course keep you posted. Thanks for the lightning fast fix @tomicapretto 🙏

AlexAndorra avatar Feb 07 '24 00:02 AlexAndorra

This works great for "p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True) + hsgp(Fall_Year, by=Major, m=10, c=2, scale=True, centered=True)" @tomicapretto 🍾

However, the error is still here when removing the group-specific effects: "p(Campus_Living, N) ~ 1 + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)". Posterior predictive sampling now works, but plot_predictions still fails with KeyError: 'Age_Group'. This is triggered by:

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:233, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
    231     X = self.design.common.design_matrix
    232 else:
--> 233     X = self.design.common.evaluate_new_data(data).design_matrix
    235 # Add offset columns to their own design matrix and remove then from common matrix

So it still looks like a problem when updating the design matrix. LMK if that's unclear 🙏

Full traceback:

KeyError                                  Traceback (most recent call last)
Cell In[28], line 2
      1 ## this line fails
----> 2 bmb.interpret.plot_predictions(
      3     m_bambi, 
      4     idata_bambi,
      5     conditional={"Fall_Year": np.linspace(2012, 2025)},
      6 );

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:181, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    174 if average_by is True:
    175     raise ValueError(
    176         "Plotting when 'average_by = True' is not possible as 'True' marginalizes "
    177         "over all covariates resulting in a single prediction estimate. "
    178         "Please pass a covariate(s) to 'average_by'."
    179     )
--> 181 cap_data = predictions(
    182     model=model,
    183     idata=idata,
    184     conditional=conditional,
    185     average_by=average_by,
    186     target=target,
    187     pps=pps,
    188     use_hdi=use_hdi,
    189     prob=prob,
    190     transforms=transforms,
    191     sample_new_groups=sample_new_groups,
    192 )
    194 conditional_info = ConditionalInfo(model, conditional)
    195 transforms = transforms if transforms is not None else {}

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/effects.py:532, in predictions(model, idata, conditional, average_by, target, pps, use_hdi, prob, transforms, sample_new_groups)
    530     y_hat_mean = y_hat.mean(("chain", "draw"))
    531 else:
--> 532     idata = model.predict(
    533         idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
    534     )
    535     y_hat = response_transform(idata["posterior"][response.name_target])
    536     y_hat_mean = y_hat.mean(("chain", "draw"))

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/models.py:834, in Model.predict(self, idata, kind, data, inplace, include_group_specific, sample_new_groups)
    831     else:
    832         var_name = f"{response_aliased_name}_{name}"
--> 834 means_dict[var_name] = component.predict(
    835     idata, data, include_group_specific, hsgp_dict, sample_new_groups
    836 )
    838 # Drop var/dim if already present. Needed for out-of-sample predictions.
    839 if var_name in idata.posterior.data_vars:

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:185, in DistributionalComponent.predict(self, idata, data, include_group_specific, hsgp_dict, sample_new_groups)
    182     linear_predictor_dims = linear_predictor_dims + (response_levels_dim,)
    184 if self.design.common:
--> 185     linear_predictor += self.predict_common(
    186         posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
    187     )
    189 if self.design.group and include_group_specific:
    190     linear_predictor += self.predict_group_specific(
    191         posterior, data, in_sample, to_stack_dims, design_matrix_dims, sample_new_groups
    192     )

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:233, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
    231     X = self.design.common.design_matrix
    232 else:
--> 233     X = self.design.common.evaluate_new_data(data).design_matrix
    235 # Add offset columns to their own design matrix and remove then from common matrix
    236 for term in self.offset_terms:

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/matrices.py:259, in CommonEffectsMatrix.evaluate_new_data(self, data)
    256 new_instance.data = data
    257 new_instance.env = self.env
    258 new_instance.design_matrix = np.column_stack(
--> 259     [t.eval_new_data(data) for t in self.terms.values()]
    260 )
    261 new_instance.slices = self.slices
    262 new_instance.evaluated = True

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/matrices.py:259, in <listcomp>(.0)
    256 new_instance.data = data
    257 new_instance.env = self.env
    258 new_instance.design_matrix = np.column_stack(
--> 259     [t.eval_new_data(data) for t in self.terms.values()]
    260 )
    261 new_instance.slices = self.slices
    262 new_instance.evaluated = True

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/terms.py:496, in Term.eval_new_data(self, data)
    492     result = reduce(
    493         get_interaction_matrix, [c.eval_new_data(data) for c in self.components]
    494     )
    495 else:
--> 496     result = self.components[0].eval_new_data(data)
    497 return result

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call.py:274, in Call.eval_new_data(self, data_mask)
    255 """Evaluates the function call with new data.
    256 
    257 This method evaluates the function call within a new data mask. If the transformation
   (...)
    271     categoric ones.
    272 """
    273 if self.kind in ["numeric", "categoric"]:
--> 274     x = self.call.eval(data_mask, self.env)
    275     if self.kind == "numeric":
    276         result = self.eval_new_data_numeric(x)

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:270, in LazyCall.eval(self, data_mask, env)
    267     callee = self.stateful_transform
    269 args = [arg.eval(data_mask, env) for arg in self.args]
--> 270 kwargs = {name: arg.eval(data_mask, env) for name, arg in self.kwargs.items()}
    272 return callee(*args, **kwargs)

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:270, in <dictcomp>(.0)
    267     callee = self.stateful_transform
    269 args = [arg.eval(data_mask, env) for arg in self.args]
--> 270 kwargs = {name: arg.eval(data_mask, env) for name, arg in self.kwargs.items()}
    272 return callee(*args, **kwargs)

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:141, in LazyVariable.eval(self, data_mask, env)
    139             result = env.namespace[self.name]
    140         except KeyError as e:
--> 141             raise e
    142 return result

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:139, in LazyVariable.eval(self, data_mask, env)
    137 except KeyError:
    138     try:
--> 139         result = env.namespace[self.name]
    140     except KeyError as e:
    141         raise e

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/environment.py:17, in VarLookupDict.__getitem__(self, key)
     15     except KeyError:
     16         pass
---> 17 raise KeyError(key)

KeyError: 'Age_Group'

AlexAndorra avatar Feb 10 '24 20:02 AlexAndorra

Thanks a lot @AlexAndorra. formulae is raising the correct error.

The error lies in interpret. When creating the data for plot_predictions, we first collect all covariates specified in the model. However, there is a bug when a variable is passed to the by argument of an HSGP term: only Fall_Year is returned, whereas both Fall_Year and Age_Group should be. This isn't a problem with your first model formula because Age_Group also appears there as a group-specific term.

As a workaround until I fix the bug I suggest the following.

bmb.interpret.plot_predictions(
    m_bambi_2, 
    idata_bambi_2,
    conditional={
        "Fall_Year": np.linspace(2012, 2025),
        "Age_Group": data.Age_Group.unique()
    },
    subplot_kwargs={
        "main": "Fall_Year", 
        "group": "Age_Group", 
        "panel": "Age_Group"
    },
    fig_kwargs={"figsize": (18, 4), "sharey": True},
    legend=False
);

image

We are able to pass Age_Group to conditional because we know a priori that it should exist in the data grid. However, if you attempt to pass a variable to conditional that was not specified in the formula, then you will get a KeyError.
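
Why the workaround helps can be sketched with a toy data grid. This is not how bambi.interpret builds its grid internally; `build_grid` and `missing_variables` are hypothetical helpers, and the year values are placeholders. The idea is only that every variable the formula references, including the one passed to by, has to end up as a column of the grid, and naming it in conditional is one way to get it there.

```python
from itertools import product


def build_grid(conditional):
    """Cross every value of every conditional variable: one row per combination."""
    names = list(conditional)
    return [dict(zip(names, values)) for values in product(*conditional.values())]


def missing_variables(formula_variables, grid):
    """Names the formula needs that are not present as keys in the grid rows."""
    present = set(grid[0]) if grid else set()
    return sorted(set(formula_variables) - present)


# Variables referenced by the second formula, including the `by` variable.
formula_variables = ["Fall_Year", "Age_Group"]

grid_without_by = build_grid({"Fall_Year": [2012, 2013, 2014]})
grid_with_by = build_grid(
    {
        "Fall_Year": [2012, 2013, 2014],
        "Age_Group": ["<23", "23-26", ">26", "Unknown"],
    }
)

print(missing_variables(formula_variables, grid_without_by))
print(missing_variables(formula_variables, grid_with_by))
print(len(grid_with_by), "rows in the crossed grid")
```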

GStechschulte avatar Feb 10 '24 21:02 GStechschulte

Aaaaah, that makes sense, thanks a lot @GStechschulte ! Is that gonna work if I have more GPs though? Let's say I have 3 different GPs, on 3 different categorical variables. So now, in conditional I'm gonna have 4 keys (Fall_Year and the 3 covariates), so I have to specify average_by, which I think is gonna have the same issue. Correct?

AlexAndorra avatar Feb 10 '24 23:02 AlexAndorra

Yeah I can confirm. If you do:

grades = np.random.choice(["A", "B", "C"], size=len(fake_data))
fake_data["Grades"] = grades

formula = "p(Campus_Living, N) ~ 0 + hsgp(Fall_Year, by=Age_Group, m=10, c=2) + hsgp(Fall_Year, by=Major, m=10, c=2) + hsgp(Fall_Year, by=Grades, m=10, c=2)"

m_bambi = bmb.Model(
    formula=formula,
    data=fake_data,
    family="binomial",
    categorical=["Age_Group", "Major", "Grades"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()

# This fails
bmb.interpret.plot_predictions(
    m_bambi, 
    idata_bambi,
    conditional={
        "Fall_Year": np.linspace(2012, 2025),
        "Age_Group": fake_data.Age_Group.unique(),
        "Major": "Science",
        "Grades": "A",
    },
    average_by="Age_Group",
);

Now you get:

KeyError: 'Fall_Year'

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[57], line 1
----> 1 bmb.interpret.plot_predictions(
      2     m_bambi, 
      3     idata_bambi,
      4     conditional={
      5         "Fall_Year": np.linspace(2012, 2025),
      6         "Age_Group": fake_data.Age_Group.unique(),
      7         "Major": fake_data.Major.unique(),
      8         "Grades": fake_data.Grades.unique(),
      9     },
     10     average_by="Age_Group",
     11     subplot_kwargs={
     12         "main": "Fall_Year", 
     13         "group": "Age_Group", 
     14         "panel": "Age_Group"
     15     },
     16     fig_kwargs={"figsize": (18, 4), "sharey": True},
     17     legend=False
     18 );

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:225, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
    222     else:
    223         fig = axes[0].get_figure()
--> 225 if is_numeric_dtype(cap_data[covariates.main]):
    226     axes = plot_numeric(covariates, cap_data, transforms, legend, axes)
    227 elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype(
    228     cap_data[covariates.main]
    229 ):

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/pandas/core/frame.py:3761, in DataFrame.__getitem__(self, key)
   3759 if self.columns.nlevels > 1:
   3760     return self._getitem_multilevel(key)
-> 3761 indexer = self.columns.get_loc(key)
   3762 if is_integer(indexer):
   3763     indexer = [indexer]

File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/pandas/core/indexes/base.py:3655, in Index.get_loc(self, key)
   3653     return self._engine.get_loc(casted_key)
   3654 except KeyError as err:
-> 3655     raise KeyError(key) from err
   3656 except TypeError:
   3657     # If we have a listlike key, _check_indexing_error will raise
   3658     #  InvalidIndexError. Otherwise we fall through and re-raise
   3659     #  the TypeError.
   3660     self._check_indexing_error(key)

Hope that helps!

AlexAndorra avatar Feb 10 '24 23:02 AlexAndorra

@AlexAndorra Mmmm. For me, the model with three HSGP terms works with

grades = np.random.choice(["A", "B", "C"], size=len(fake_data))
fake_data["Grades"] = grades

formula = "p(Campus_Living, N) ~ 0 + hsgp(Fall_Year, by=Age_Group, m=10, c=2) + hsgp(Fall_Year, by=Major, m=10, c=2) + hsgp(Fall_Year, by=Grades, m=10, c=2)"

m_bambi = bmb.Model(
    formula=formula,
    data=fake_data,
    family="binomial",
    categorical=["Age_Group", "Major", "Grades"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()

# This works
bmb.interpret.plot_predictions(
    m_bambi, 
    idata_bambi,
    conditional={
        "Fall_Year": np.linspace(2012, 2025),
        "Age_Group": fake_data.Age_Group.unique(),
        "Major": "Science",
        "Grades": "A",
    },
    average_by="Age_Group",
);

image

In the error output you provided, it looks like you were calling

Cell In[57], line 1
----> 1 bmb.interpret.plot_predictions(
      2     m_bambi, 
      3     idata_bambi,
      4     conditional={
      5         "Fall_Year": np.linspace(2012, 2025),
      6         "Age_Group": fake_data.Age_Group.unique(),
      7         "Major": fake_data.Major.unique(),
      8         "Grades": fake_data.Grades.unique(),
      9     },
     10     average_by="Age_Group",
     11     subplot_kwargs={
     12         "main": "Fall_Year", 
     13         "group": "Age_Group", 
     14         "panel": "Age_Group"
     15     },
     16     fig_kwargs={"figsize": (18, 4), "sharey": True},
     17     legend=False
     18 );

This is failing because you are only averaging by Age_Group. That applies an aggregation function (the average) within each level of Age_Group and then combines the results, effectively marginalizing over Fall_Year, Major, and Grades to get one average per Age_Group level. The resulting data no longer has a Fall_Year column, so using it as "main" in subplot_kwargs raises KeyError: 'Fall_Year'.

For example:

bmb.interpret.predictions(
    m_bambi, 
    idata_bambi,
    conditional={
        "Fall_Year": np.linspace(2012, 2025),
        "Age_Group": fake_data.Age_Group.unique(),
        "Major": fake_data.Major.unique(),
        "Grades": fake_data.Grades.unique(),
    },
    average_by="Age_Group"
)
Age_Group  estimate  lower_3.0%  upper_97.0%
23-26      0.370167  0.089879    0.650803
<23        0.451753  0.149715    0.781734
>26        0.439085  0.138099    0.753850
Unknown    0.361037  0.078200    0.637989
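
The marginalization can be mimicked with a few lines of plain Python. This is a sketch under assumptions, not the code behind average_by: the helper `average_within` is hypothetical and the estimates are made-up numbers. It shows that only the columns you average by survive the aggregation.

```python
from collections import defaultdict
from statistics import mean

# Toy predictions on a crossed grid; the estimates are made-up numbers.
rows = [
    {"Fall_Year": 2012, "Age_Group": "<23", "Major": "Science", "estimate": 0.40},
    {"Fall_Year": 2013, "Age_Group": "<23", "Major": "Science", "estimate": 0.50},
    {"Fall_Year": 2012, "Age_Group": ">26", "Major": "Science", "estimate": 0.30},
    {"Fall_Year": 2013, "Age_Group": ">26", "Major": "Science", "estimate": 0.50},
]


def average_within(rows, keys):
    """Average `estimate` within each combination of `keys`; other columns are dropped."""
    groups = defaultdict(list)
    for row in rows:
        groups[tuple(row[k] for k in keys)].append(row["estimate"])
    return [
        {**dict(zip(keys, combo)), "estimate": mean(values)}
        for combo, values in groups.items()
    ]


by_age = average_within(rows, ["Age_Group"])
by_year_and_age = average_within(rows, ["Fall_Year", "Age_Group"])

print(by_age)
print("Fall_Year" in by_age[0], "Fall_Year" in by_year_and_age[0])
```

After averaging by Age_Group alone there is one row per level and no Fall_Year key left to plot against; adding Fall_Year to the keys keeps it.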

If you want to retain levels / values for some combination of the other covariates, you can pass more covariates to average_by. This is particularly useful for plotting:

bmb.interpret.plot_predictions(
    m_bambi, 
    idata_bambi,
    conditional={
        "Fall_Year": np.linspace(2012, 2025),
        "Age_Group": fake_data.Age_Group.unique(),
        "Major": fake_data.Major.unique(),
        "Grades": fake_data.Grades.unique(),
    },
    average_by=["Fall_Year", "Age_Group", "Major"],
    subplot_kwargs={
        "main": "Fall_Year", 
        "group": "Age_Group", 
        "panel": "Major"
    },
    fig_kwargs={"figsize": (18, 4), "sharey": True},
    legend=True
);

image

GStechschulte avatar Feb 11 '24 06:02 GStechschulte

@tomicapretto the HSGP term component is a function call, so we can access the argument name to get Fall_Year. However, this does not return the argument to by of the HSGP term.

The HSGP component also has a levels attribute

# Putting this print inside of `get_model_covariates` of bambi/interpret/utils.py
print(vars(component))
{'call': <formulae.terms.call_resolver.LazyCall at 0x288cf53c0>,
 'is_response': False,
 'name': 'hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)',
 'contrast_matrix': None,
 'env': <formulae.environment.Environment at 0x288cf7b20>,
 'kind': 'numeric',
 'levels': None,
 'spans_intercept': None,
 'value': array([[2011,    2],
        [2012,    0],
        [2012,    1],
        [2012,    2],
        [2014,    2],
        [2014,    2],
        [2015,    3],
        [2016,    0],
        [2016,    3],
        [2017,    0],
        [2018,    1],
        [2018,    3],
        [2019,    3],
        [2020,    0],
        [2020,    0],
        [2021,    0],
        [2021,    1],
        [2021,    2],
        [2021,    3],
        [2022,    0]]),
 '_intermediate_data': array([[2011,    2],
        [2012,    0],
        [2012,    1],
        [2012,    2],
        [2014,    2],
        [2014,    2],
        [2015,    3],
        [2016,    0],
        [2016,    3],
        [2017,    0],
        [2018,    1],
        [2018,    3],
        [2019,    3],
        [2020,    0],
        [2020,    0],
        [2021,    0],
        [2021,    1],
        [2021,    2],
        [2021,    3],
        [2022,    0]])}

which I had originally thought would contain the by argument. But here it says None. Is this a bug? Or do you know how we can access the arguments to by?

GStechschulte avatar Feb 11 '24 07:02 GStechschulte

If you want to retain levels / values for some combination of the other covariates, you can pass more covariates to average_by. This is particularly useful for plotting

Aaaaah ok @GStechschulte , thanks for the detailed answer and help! That's my mistake then -- just wasn't using the function properly.

You know why the years are plotted as floats though? Even though they are entered as int64

AlexAndorra avatar Feb 11 '24 20:02 AlexAndorra

@AlexAndorra anytime! Thanks for raising the questions and issues. No worries, the function signature is a bit overwhelming.

Yeah, I am looking into that. That should not be happening.
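
As a side note, and without claiming this is the cause: the conditional grids in the calls above are built with np.linspace(2012, 2025), which returns 50 evenly spaced floats by default rather than integer years. It may be worth ruling that out by passing integer years instead. A small check, with the variable names being my own:

```python
import numpy as np

# np.linspace defaults to num=50 and returns floats, so most points are
# fractional years.
years_linspace = np.linspace(2012, 2025)

# np.arange with integer bounds gives whole years; the stop value is exclusive.
years_integer = np.arange(2012, 2026)

print(years_linspace.dtype, years_linspace[:3])
print(years_integer.dtype, years_integer[:3])
```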

GStechschulte avatar Feb 12 '24 10:02 GStechschulte

@GStechschulte

which I had originally thought would contain the by argument. But here it says None. Is this a bug? Or do you know how we can access the arguments to by?

The LazyCall instance has several attributes, one is args, which is what we are already using, and the other is kwargs, which we are not using. See for example

vars(component.call)
{'callee': 'hsgp',
 'args': [<formulae.terms.call_resolver.LazyVariable at 0x7f056c63d840>],
 'kwargs': {'by': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63d8a0>,
  'm': <formulae.terms.call_resolver.LazyValue at 0x7f056c63d960>,
  'c': <formulae.terms.call_resolver.LazyValue at 0x7f056c63d9c0>,
  'scale': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63da20>,
  'centered': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63dae0>},
 'stateful_transform': <bambi.transformations.HSGP at 0x7f056c63ded0>}

And to access the name of the by variable we can do component.call.kwargs["by"].name.

So, we need to check kwargs in addition to args. However, I'm realizing this won't be that simple (for now) because:

  • Ideally, we only need to get the name of the kwargs that are LazyVariables. That's not hard, it's simply kwarg.name. However...
  • Formulae is using LazyVariable for things that are not variable names, but values, such as True. This is not a problem for Formulae because that "name" is in the namespace where formulas are evaluated. But...
  • If we do kwarg.name on the LazyVariable of, for example, centered, the program will think there's a variable called "True", and that's not the intended behavior.

I think we can do the following:

  • Update Formulae so it uses LazyValues for things like False, True, and None (I need to check if there are more)
  • Update Bambi so it also uses the names passed to kwargs when these are instances of LazyVariable.
  • In parallel, open a PR and merge the change that fixes the bug with the HSGP components
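
The plan above can be sketched with stand-in classes. These are deliberately simplified dataclasses that only borrow the names LazyVariable, LazyValue, and LazyCall; they are not the real formulae classes, and `covariate_names` is a hypothetical helper rather than the function Bambi uses. The sketch shows why reading kwargs naively picks up "True" as a variable name, and why representing literals as LazyValue resolves it.

```python
from dataclasses import dataclass, field


@dataclass
class LazyVariable:
    """Stand-in for a name that is looked up when the formula is evaluated."""

    name: str


@dataclass
class LazyValue:
    """Stand-in for a literal such as 10 or True."""

    value: object


@dataclass
class LazyCall:
    """Stand-in for a parsed function call like hsgp(...)."""

    callee: str
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)


def covariate_names(call):
    """Collect variable names from both positional and keyword arguments."""
    candidates = list(call.args) + list(call.kwargs.values())
    return [c.name for c in candidates if isinstance(c, LazyVariable)]


# Current behavior as described: True is represented as a LazyVariable.
before = LazyCall(
    "hsgp",
    args=[LazyVariable("Fall_Year")],
    kwargs={
        "by": LazyVariable("Age_Group"),
        "m": LazyValue(10),
        "c": LazyValue(2),
        "scale": LazyVariable("True"),
        "centered": LazyVariable("True"),
    },
)

# Proposed behavior: True, False, and None are represented as LazyValue.
after = LazyCall(
    "hsgp",
    args=[LazyVariable("Fall_Year")],
    kwargs={
        "by": LazyVariable("Age_Group"),
        "m": LazyValue(10),
        "c": LazyValue(2),
        "scale": LazyValue(True),
        "centered": LazyValue(True),
    },
)

print(covariate_names(before))
print(covariate_names(after))
```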

I'll work on these things now

tomicapretto avatar Feb 18 '24 18:02 tomicapretto

@AlexAndorra I think it should be fixed now.

For that:

  1. Upgrade formulae. I released version 0.5.2
  2. Re-install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction.

I haven't tested it on my end, but I think the changes I implemented were the required ones :)
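
For anyone following along, a quick way to confirm that the environment actually picked up the new formulae release is to check the installed version from Python. This is a standard-library sketch; `version_tuple` and `formulae_is_recent_enough` are hypothetical helper names, and the parsing is intentionally naive (it only reads the leading numeric components).

```python
import re
from importlib import metadata


def version_tuple(version):
    """Leading numeric components of a version string, e.g. '0.5.2' -> (0, 5, 2)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def formulae_is_recent_enough(minimum="0.5.2"):
    """True/False for the installed formulae, or None if it is not installed."""
    try:
        installed = metadata.version("formulae")
    except metadata.PackageNotFoundError:
        return None
    return version_tuple(installed) >= version_tuple(minimum)


print(formulae_is_recent_enough())
```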

tomicapretto avatar Feb 18 '24 19:02 tomicapretto

@tomicapretto ahh that makes sense. Thanks for the information and for fixing this in formulae and bambi already!

GStechschulte avatar Feb 19 '24 11:02 GStechschulte