bambi
`plot_predictions` breaks with HSGP
The attached dataset will help reproduce the issue, but basically:
formula = "p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True) + hsgp(Fall_Year, by=Major, m=10, c=2, scale=True, centered=True)"
m_bambi = bmb.Model(
formula=formula,
data=data,
family="binomial",
categorical=["Age_Group", "Major"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()
## this line fails
bmb.interpret.plot_predictions(
m_bambi,
idata_bambi,
conditional={"Fall_Year": np.linspace(2012, 2025)},
);
yields (full traceback at end of post):
IndexError: index -1 is out of bounds for axis 1 with size 0
If you do `formula = "p(Campus_Living, N) ~ 1 + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)"`, the error becomes `KeyError: 'Age_Group'`.
Interestingly, if you change `formula` to `"p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)"` or `"p(Campus_Living, N) ~ 1 + (1 | Age_Group) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)"`, the plot will display (although I'm not sure it's right)!
It looks like some indexing is happening in `by_values` in `bambi/model_components.py:262`, which yields the error.
Thanks a lot for your help, and LMK if anything is unclear 🙏
Full error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[225], line 1
----> 1 bmb.interpret.plot_predictions(
2 m_bambi,
3 idata_bambi,
4 conditional={"Fall_Year": np.linspace(2012, 2025)},
5 );
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:181, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
174 if average_by is True:
175 raise ValueError(
176 "Plotting when 'average_by = True' is not possible as 'True' marginalizes "
177 "over all covariates resulting in a single prediction estimate. "
178 "Please pass a covariate(s) to 'average_by'."
179 )
--> 181 cap_data = predictions(
182 model=model,
183 idata=idata,
184 conditional=conditional,
185 average_by=average_by,
186 target=target,
187 pps=pps,
188 use_hdi=use_hdi,
189 prob=prob,
190 transforms=transforms,
191 sample_new_groups=sample_new_groups,
192 )
194 conditional_info = ConditionalInfo(model, conditional)
195 transforms = transforms if transforms is not None else {}
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/effects.py:532, in predictions(model, idata, conditional, average_by, target, pps, use_hdi, prob, transforms, sample_new_groups)
530 y_hat_mean = y_hat.mean(("chain", "draw"))
531 else:
--> 532 idata = model.predict(
533 idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
534 )
535 y_hat = response_transform(idata["posterior"][response.name_target])
536 y_hat_mean = y_hat.mean(("chain", "draw"))
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/models.py:832, in Model.predict(self, idata, kind, data, inplace, include_group_specific, sample_new_groups)
829 else:
830 var_name = f"{response_aliased_name}_{name}"
--> 832 means_dict[var_name] = component.predict(
833 idata, data, include_group_specific, hsgp_dict, sample_new_groups
834 )
836 # Drop var/dim if already present. Needed for out-of-sample predictions.
837 if var_name in idata.posterior.data_vars:
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:185, in DistributionalComponent.predict(self, idata, data, include_group_specific, hsgp_dict, sample_new_groups)
182 linear_predictor_dims = linear_predictor_dims + (response_levels_dim,)
184 if self.design.common:
--> 185 linear_predictor += self.predict_common(
186 posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
187 )
189 if self.design.group and include_group_specific:
190 linear_predictor += self.predict_group_specific(
191 posterior, data, in_sample, to_stack_dims, design_matrix_dims, sample_new_groups
192 )
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:262, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
256 # NOTE:
257 # The approach here differs from the one in the PyMC implementation.
258 # Here we have a single dot product with many zeros, while there we have many
259 # smaller dot products.
260 # It is subject to change here, but I don't want to mess up dims and coords.
261 if term.by_levels is not None:
--> 262 by_values = x_slice[:, -1].astype(int)
263 x_slice = x_slice[:, :-1]
264 x_slice_centered = (x_slice.data - term.mean[by_values]) / maximum_distance
IndexError: index -1 is out of bounds for axis 1 with size 0
Thanks for raising the issue and the detailed information!
Note that this error also appears when sampling posterior predictions -- m_bambi.predict(idata_bambi, kind="pps")
I've found where the bug is!
https://github.com/bambinos/bambi/blob/b5b9f093c623636ae3a8e2f0765b2f01ca26f4b0/bambi/model_components.py#L242-L246
In line 246 we're deleting columns of the design matrix based on the slice `term_slice`. This only happens for HSGP. The problem is that the slice for the second (and third, fourth, etc.) HSGP term does not reflect the shape of `X` at the time it's going to be used to slice `X`.
I think the fix is to delete the HSGP term data from the design matrix X only after having iterated over all the HSGP terms. I'll test it soon.
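A minimal sketch of the failure mode described above, using plain NumPy. The matrix, the slice dictionary, and the function names `extract_buggy` and `extract_fixed` are hypothetical stand-ins for illustration, not Bambi's actual code. It assumes a common design matrix with one intercept column followed by two HSGP terms of two columns each (covariate plus `by` index).

```python
import numpy as np

# Hypothetical design matrix: 1 intercept column + two HSGP terms (2 columns each).
X = np.arange(15, dtype=float).reshape(3, 5)
# Slices computed once, against the ORIGINAL column layout.
slices = {"hsgp_a": slice(1, 3), "hsgp_b": slice(3, 5)}


def extract_buggy(X, slices):
    """Delete each term's columns inside the loop: later slices become stale."""
    out = {}
    for name, term_slice in slices.items():
        out[name] = X[:, term_slice]
        X = np.delete(X, term_slice, axis=1)  # X shrinks here
    return out, X


def extract_fixed(X, slices):
    """Read every term first, then delete all HSGP columns in one go."""
    out = {name: X[:, term_slice] for name, term_slice in slices.items()}
    n_cols = X.shape[1]
    cols = np.concatenate([np.arange(n_cols)[s] for s in slices.values()])
    return out, np.delete(X, cols, axis=1)


buggy, _ = extract_buggy(X, slices)
fixed, X_rest = extract_fixed(X, slices)

print(buggy["hsgp_b"].shape)  # (3, 0): the second term's slice points past the shrunken matrix
print(fixed["hsgp_b"].shape)  # (3, 2)
print(X_rest.shape)           # (3, 1): only the intercept column is left

# Indexing the last column of the empty slice raises IndexError,
# the same kind of failure as `x_slice[:, -1]` in the traceback.
try:
    buggy["hsgp_b"][:, -1]
except IndexError as err:
    print(type(err).__name__)
```

With a single HSGP term the loop runs only once, so no stale slice is ever used, which would be consistent with the single-HSGP formulas not hitting the `IndexError`.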
@AlexAndorra could you install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction and try it again?
> In line 246 we're deleting columns of the design matrix based on the slice term_slice. This only happens for HSGP. The problem is that the slice for the second (and third, fourth, etc.) HSGP term does not reflect the shape of X at the time it's going to be used to slice X.
Makes sense indeed.
> @AlexAndorra could you install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction and try it again?
Will do ASAP and of course keep you posted. Thanks for the lightning fast fix @tomicapretto 🙏
This works great for `"p(Campus_Living, N) ~ 1 + (1 | Age_Group) + (1 | Major) + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True) + hsgp(Fall_Year, by=Major, m=10, c=2, scale=True, centered=True)"` @tomicapretto 🍾
However, the error is still here when removing the group-specific effects: `"p(Campus_Living, N) ~ 1 + hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)"`.
The posterior predictions sample now, but `plot_predictions` still fails: `KeyError: 'Age_Group'`.
This is triggered by:
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:233, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
231 X = self.design.common.design_matrix
232 else:
--> 233 X = self.design.common.evaluate_new_data(data).design_matrix
235 # Add offset columns to their own design matrix and remove then from common matrix
So it still looks like a problem when updating the design matrix. LMK if that's unclear 🙏
Full traceback:
KeyError Traceback (most recent call last)
Cell In[28], line 2
1 ## this line fails
----> 2 bmb.interpret.plot_predictions(
3 m_bambi,
4 idata_bambi,
5 conditional={"Fall_Year": np.linspace(2012, 2025)},
6 );
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:181, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
174 if average_by is True:
175 raise ValueError(
176 "Plotting when 'average_by = True' is not possible as 'True' marginalizes "
177 "over all covariates resulting in a single prediction estimate. "
178 "Please pass a covariate(s) to 'average_by'."
179 )
--> 181 cap_data = predictions(
182 model=model,
183 idata=idata,
184 conditional=conditional,
185 average_by=average_by,
186 target=target,
187 pps=pps,
188 use_hdi=use_hdi,
189 prob=prob,
190 transforms=transforms,
191 sample_new_groups=sample_new_groups,
192 )
194 conditional_info = ConditionalInfo(model, conditional)
195 transforms = transforms if transforms is not None else {}
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/effects.py:532, in predictions(model, idata, conditional, average_by, target, pps, use_hdi, prob, transforms, sample_new_groups)
530 y_hat_mean = y_hat.mean(("chain", "draw"))
531 else:
--> 532 idata = model.predict(
533 idata, data=cap_data, sample_new_groups=sample_new_groups, inplace=False
534 )
535 y_hat = response_transform(idata["posterior"][response.name_target])
536 y_hat_mean = y_hat.mean(("chain", "draw"))
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/models.py:834, in Model.predict(self, idata, kind, data, inplace, include_group_specific, sample_new_groups)
831 else:
832 var_name = f"{response_aliased_name}_{name}"
--> 834 means_dict[var_name] = component.predict(
835 idata, data, include_group_specific, hsgp_dict, sample_new_groups
836 )
838 # Drop var/dim if already present. Needed for out-of-sample predictions.
839 if var_name in idata.posterior.data_vars:
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:185, in DistributionalComponent.predict(self, idata, data, include_group_specific, hsgp_dict, sample_new_groups)
182 linear_predictor_dims = linear_predictor_dims + (response_levels_dim,)
184 if self.design.common:
--> 185 linear_predictor += self.predict_common(
186 posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict
187 )
189 if self.design.group and include_group_specific:
190 linear_predictor += self.predict_group_specific(
191 posterior, data, in_sample, to_stack_dims, design_matrix_dims, sample_new_groups
192 )
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/model_components.py:233, in DistributionalComponent.predict_common(self, posterior, data, in_sample, to_stack_dims, design_matrix_dims, hsgp_dict)
231 X = self.design.common.design_matrix
232 else:
--> 233 X = self.design.common.evaluate_new_data(data).design_matrix
235 # Add offset columns to their own design matrix and remove then from common matrix
236 for term in self.offset_terms:
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/matrices.py:259, in CommonEffectsMatrix.evaluate_new_data(self, data)
256 new_instance.data = data
257 new_instance.env = self.env
258 new_instance.design_matrix = np.column_stack(
--> 259 [t.eval_new_data(data) for t in self.terms.values()]
260 )
261 new_instance.slices = self.slices
262 new_instance.evaluated = True
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/matrices.py:259, in <listcomp>(.0)
256 new_instance.data = data
257 new_instance.env = self.env
258 new_instance.design_matrix = np.column_stack(
--> 259 [t.eval_new_data(data) for t in self.terms.values()]
260 )
261 new_instance.slices = self.slices
262 new_instance.evaluated = True
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/terms.py:496, in Term.eval_new_data(self, data)
492 result = reduce(
493 get_interaction_matrix, [c.eval_new_data(data) for c in self.components]
494 )
495 else:
--> 496 result = self.components[0].eval_new_data(data)
497 return result
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call.py:274, in Call.eval_new_data(self, data_mask)
255 """Evaluates the function call with new data.
256
257 This method evaluates the function call within a new data mask. If the transformation
(...)
271 categoric ones.
272 """
273 if self.kind in ["numeric", "categoric"]:
--> 274 x = self.call.eval(data_mask, self.env)
275 if self.kind == "numeric":
276 result = self.eval_new_data_numeric(x)
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:270, in LazyCall.eval(self, data_mask, env)
267 callee = self.stateful_transform
269 args = [arg.eval(data_mask, env) for arg in self.args]
--> 270 kwargs = {name: arg.eval(data_mask, env) for name, arg in self.kwargs.items()}
272 return callee(*args, **kwargs)
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:270, in <dictcomp>(.0)
267 callee = self.stateful_transform
269 args = [arg.eval(data_mask, env) for arg in self.args]
--> 270 kwargs = {name: arg.eval(data_mask, env) for name, arg in self.kwargs.items()}
272 return callee(*args, **kwargs)
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:141, in LazyVariable.eval(self, data_mask, env)
139 result = env.namespace[self.name]
140 except KeyError as e:
--> 141 raise e
142 return result
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/terms/call_resolver.py:139, in LazyVariable.eval(self, data_mask, env)
137 except KeyError:
138 try:
--> 139 result = env.namespace[self.name]
140 except KeyError as e:
141 raise e
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/formulae/environment.py:17, in VarLookupDict.__getitem__(self, key)
15 except KeyError:
16 pass
---> 17 raise KeyError(key)
KeyError: 'Age_Group'
Thanks a lot @AlexAndorra. `formulae` is raising the correct error.
The error lies in `interpret`. It happens because, when creating the data for `plot_predictions`, we first collect all covariates specified in the model. However, we have a bug when an argument is passed to `by` of the HSGP term: only `Fall_Year` is returned, whereas both `Fall_Year` and `Age_Group` should be. This isn't a problem with your first model formula because `Age_Group` is a group-specific term.
As a workaround until I fix the bug I suggest the following.
bmb.interpret.plot_predictions(
m_bambi_2,
idata_bambi_2,
conditional={
"Fall_Year": np.linspace(2012, 2025),
"Age_Group": data.Age_Group.unique()
},
subplot_kwargs={
"main": "Fall_Year",
"group": "Age_Group",
"panel": "Age_Group"
},
fig_kwargs={"figsize": (18, 4), "sharey": True},
legend=False
);
We are able to pass `Age_Group` to `conditional` because we know a priori that it should exist in the data grid. However, if you attempt to pass a variable to `conditional` that was not specified in the formula, then you will get a `KeyError`.
Aaaaah, that makes sense, thanks a lot @GStechschulte !
Is that gonna work if I have more GPs though? Let's say I have 3 different GPs, on 3 different categorical variables. So now, in `conditional` I'm gonna have 4 keys (`Fall_Year` and the 3 covariates), so I have to specify `average_by`, which I think is gonna have the same issue. Correct?
Yeah I can confirm. If you do:
grades = np.random.choice(["A", "B", "C"], size=len(fake_data))
fake_data["Grades"] = grades
formula = "p(Campus_Living, N) ~ 0 + hsgp(Fall_Year, by=Age_Group, m=10, c=2) + hsgp(Fall_Year, by=Major, m=10, c=2) + hsgp(Fall_Year, by=Grades, m=10, c=2)"
m_bambi = bmb.Model(
formula=formula,
data=fake_data,
family="binomial",
categorical=["Age_Group", "Major", "Grades"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()
# This fails
bmb.interpret.plot_predictions(
m_bambi,
idata_bambi,
conditional={
"Fall_Year": np.linspace(2012, 2025),
"Age_Group": fake_data.Age_Group.unique(),
"Major": "Science",
"Grades": "A",
},
average_by="Age_Group",
);
Now you get:
KeyError: 'Fall_Year'
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
Cell In[57], line 1
----> 1 bmb.interpret.plot_predictions(
2 m_bambi,
3 idata_bambi,
4 conditional={
5 "Fall_Year": np.linspace(2012, 2025),
6 "Age_Group": fake_data.Age_Group.unique(),
7 "Major": fake_data.Major.unique(),
8 "Grades": fake_data.Grades.unique(),
9 },
10 average_by="Age_Group",
11 subplot_kwargs={
12 "main": "Fall_Year",
13 "group": "Age_Group",
14 "panel": "Age_Group"
15 },
16 fig_kwargs={"figsize": (18, 4), "sharey": True},
17 legend=False
18 );
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/bambi/interpret/plotting.py:225, in plot_predictions(model, idata, conditional, average_by, target, sample_new_groups, pps, use_hdi, prob, transforms, legend, ax, fig_kwargs, subplot_kwargs)
222 else:
223 fig = axes[0].get_figure()
--> 225 if is_numeric_dtype(cap_data[covariates.main]):
226 axes = plot_numeric(covariates, cap_data, transforms, legend, axes)
227 elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype(
228 cap_data[covariates.main]
229 ):
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/pandas/core/frame.py:3761, in DataFrame.__getitem__(self, key)
3759 if self.columns.nlevels > 1:
3760 return self._getitem_multilevel(key)
-> 3761 indexer = self.columns.get_loc(key)
3762 if is_integer(indexer):
3763 indexer = [indexer]
File ~/mambaforge/envs/bayes-workshop/lib/python3.11/site-packages/pandas/core/indexes/base.py:3655, in Index.get_loc(self, key)
3653 return self._engine.get_loc(casted_key)
3654 except KeyError as err:
-> 3655 raise KeyError(key) from err
3656 except TypeError:
3657 # If we have a listlike key, _check_indexing_error will raise
3658 # InvalidIndexError. Otherwise we fall through and re-raise
3659 # the TypeError.
3660 self._check_indexing_error(key)
Hope that helps!
@AlexAndorra Mmmm. For me, the three HSGP terms work with
grades = np.random.choice(["A", "B", "C"], size=len(fake_data))
fake_data["Grades"] = grades
formula = "p(Campus_Living, N) ~ 0 + hsgp(Fall_Year, by=Age_Group, m=10, c=2) + hsgp(Fall_Year, by=Major, m=10, c=2) + hsgp(Fall_Year, by=Grades, m=10, c=2)"
m_bambi = bmb.Model(
formula=formula,
data=fake_data,
family="binomial",
categorical=["Age_Group", "Major", "Grades"],
)
m_bambi.build()
idata_bambi = m_bambi.fit()
# This works
bmb.interpret.plot_predictions(
m_bambi,
idata_bambi,
conditional={
"Fall_Year": np.linspace(2012, 2025),
"Age_Group": fake_data.Age_Group.unique(),
"Major": "Science",
"Grades": "A",
},
average_by="Age_Group",
);
In the error output you provided, it looks like you were calling
Cell In[57], line 1
----> 1 bmb.interpret.plot_predictions(
2 m_bambi,
3 idata_bambi,
4 conditional={
5 "Fall_Year": np.linspace(2012, 2025),
6 "Age_Group": fake_data.Age_Group.unique(),
7 "Major": fake_data.Major.unique(),
8 "Grades": fake_data.Grades.unique(),
9 },
10 average_by="Age_Group",
11 subplot_kwargs={
12 "main": "Fall_Year",
13 "group": "Age_Group",
14 "panel": "Age_Group"
15 },
16 fig_kwargs={"figsize": (18, 4), "sharey": True},
17 legend=False
18 );
This is failing because you are only averaging by `Age_Group`. This applies an aggregation function (average) for each level of `Age_Group` and then combines the results, effectively marginalizing over fall year, major, and grades to get the average for each age group level. The result no longer has a `Fall_Year` column, so looking up the main covariate for plotting raises `KeyError: 'Fall_Year'`.
For example:
bmb.interpret.predictions(
m_bambi,
idata_bambi,
conditional={
"Fall_Year": np.linspace(2012, 2025),
"Age_Group": fake_data.Age_Group.unique(),
"Major": fake_data.Major.unique(),
"Grades": fake_data.Grades.unique(),
},
average_by="Age_Group"
)
| Age_Group | estimate | lower_3.0% | upper_97.0% |
|---|---|---|---|
| 23-26 | 0.370167 | 0.089879 | 0.650803 |
| <23 | 0.451753 | 0.149715 | 0.781734 |
| >26 | 0.439085 | 0.138099 | 0.753850 |
| Unknown | 0.361037 | 0.078200 | 0.637989 |
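As a rough analogy for what averaging by a single covariate does to the prediction grid, here is a small pandas sketch. The grid and the estimate values are made up for illustration, and this is not Bambi's internal implementation; it only shows why the column for the main plotting covariate disappears after the aggregation.

```python
import itertools

import pandas as pd

# Hypothetical prediction grid: every combination of three covariates.
grid = pd.DataFrame(
    list(itertools.product([2012, 2013], ["<23", ">26"], ["Science", "Arts"])),
    columns=["Fall_Year", "Age_Group", "Major"],
)
# Made-up estimates, one per grid row.
grid["estimate"] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Averaging by Age_Group only: Fall_Year and Major are marginalized out.
by_age = grid.groupby("Age_Group", as_index=False)["estimate"].mean()
print(by_age.columns.tolist())  # ['Age_Group', 'estimate']

# Keeping Fall_Year in the grouping preserves it for use as the x-axis.
by_year_age = grid.groupby(["Fall_Year", "Age_Group"], as_index=False)["estimate"].mean()
print(by_year_age.columns.tolist())  # ['Fall_Year', 'Age_Group', 'estimate']
```

In the first result, `by_age["Fall_Year"]` would raise a `KeyError`, which mirrors the failure in the traceback when the main covariate is not among the `average_by` variables.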
If you want to retain levels / values for some combination of the other covariates, you can pass more covariates to `average_by`. This is particularly useful for plotting:
bmb.interpret.plot_predictions(
m_bambi,
idata_bambi,
conditional={
"Fall_Year": np.linspace(2012, 2025),
"Age_Group": fake_data.Age_Group.unique(),
"Major": fake_data.Major.unique(),
"Grades": fake_data.Grades.unique(),
},
average_by=["Fall_Year", "Age_Group", "Major"],
subplot_kwargs={
"main": "Fall_Year",
"group": "Age_Group",
"panel": "Major"
},
fig_kwargs={"figsize": (18, 4), "sharey": True},
legend=True
);
@tomicapretto the HSGP term component is a function call, so we can access the argument name to get `Fall_Year`. However, this does not return the argument to `by` of the HSGP term.
The HSGP component also has a `levels` attribute
# Putting this print inside of `get_model_covariates` of bambi/interpret/utils.py
print(vars(component))
{'call': <formulae.terms.call_resolver.LazyCall at 0x288cf53c0>,
'is_response': False,
'name': 'hsgp(Fall_Year, by=Age_Group, m=10, c=2, scale=True, centered=True)',
'contrast_matrix': None,
'env': <formulae.environment.Environment at 0x288cf7b20>,
'kind': 'numeric',
'levels': None,
'spans_intercept': None,
'value': array([[2011, 2],
[2012, 0],
[2012, 1],
[2012, 2],
[2014, 2],
[2014, 2],
[2015, 3],
[2016, 0],
[2016, 3],
[2017, 0],
[2018, 1],
[2018, 3],
[2019, 3],
[2020, 0],
[2020, 0],
[2021, 0],
[2021, 1],
[2021, 2],
[2021, 3],
[2022, 0]]),
'_intermediate_data': array([[2011, 2],
[2012, 0],
[2012, 1],
[2012, 2],
[2014, 2],
[2014, 2],
[2015, 3],
[2016, 0],
[2016, 3],
[2017, 0],
[2018, 1],
[2018, 3],
[2019, 3],
[2020, 0],
[2020, 0],
[2021, 0],
[2021, 1],
[2021, 2],
[2021, 3],
[2022, 0]])}
which I had originally thought would contain the `by` argument. But here it says `None`. Is this a bug? Or do you know how we can access the arguments to `by`?
> If you want to retain levels / values for some combination of the other covariates, you can pass more covariates to average_by. This is particular useful for plotting
Aaaaah ok @GStechschulte , thanks for the detailed answer and help! That's my mistake then -- just wasn't using the function properly.
Do you know why the years are plotted as floats though, even though they are entered as `int64`?
@AlexAndorra anytime! Thanks for raising the questions and issues. No worries, the function signature is a bit overwhelming.
Yeah, I am looking into that. That should not be happening.
@GStechschulte
> which I had originally thought would contain the by argument. But here it says None. Is this a bug? Or do you know how we can access the arguments to by?
The `LazyCall` instance has several attributes: one is `args`, which is what we are already using, and the other is `kwargs`, which we are not using. See for example
vars(component.call)
{'callee': 'hsgp',
'args': [<formulae.terms.call_resolver.LazyVariable at 0x7f056c63d840>],
'kwargs': {'by': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63d8a0>,
'm': <formulae.terms.call_resolver.LazyValue at 0x7f056c63d960>,
'c': <formulae.terms.call_resolver.LazyValue at 0x7f056c63d9c0>,
'scale': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63da20>,
'centered': <formulae.terms.call_resolver.LazyVariable at 0x7f056c63dae0>},
'stateful_transform': <bambi.transformations.HSGP at 0x7f056c63ded0>}
And to access the name of the `by` variable we can do `component.call.kwargs["by"].name`.
So, we need to check `kwargs` in addition to `args`. However, I'm realizing this won't be that simple (for now) because:
- Ideally, we only need to get the name of the kwargs that are `LazyVariable`s. That's not hard, it's simply `kwarg.name`. However...
- Formulae is using `LazyVariable` for things that are not variable names, but values, such as `True`. This is not a problem for Formulae because that "name" is in the namespace where formulas are evaluated. But...
- If we do `kwarg.name` on the `LazyVariable` of, for example, `centered`, the program will think there's a variable called `"True"`, and that's not the intended behavior.
I think we can do the following:
- Update Formulae so it uses `LazyValue`s for things like `False`, `True`, and `None` (I need to check if there are more)
- Update Bambi so it also uses the names passed to `kwargs` when these are instances of `LazyVariable`.
- In parallel, open a PR and merge the change that fixes the bug with the HSGP components
I'll work on these things now
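The plan above can be sketched roughly as follows. The three dataclasses are simplified stand-ins for formulae's `LazyVariable`, `LazyValue`, and `LazyCall` (only the attributes shown in the `vars(component.call)` output are modelled), and `call_covariates` is a hypothetical helper name, not a function from Bambi or formulae.

```python
from dataclasses import dataclass, field


# Simplified stand-ins for the formulae classes; assumptions for illustration only.
@dataclass
class LazyVariable:
    name: str


@dataclass
class LazyValue:
    value: object


@dataclass
class LazyCall:
    callee: str
    args: list = field(default_factory=list)
    kwargs: dict = field(default_factory=dict)


def call_covariates(call):
    """Collect variable names from positional AND keyword arguments of a call."""
    candidates = list(call.args) + list(call.kwargs.values())
    return [c.name for c in candidates if isinstance(c, LazyVariable)]


# Before the Formulae change: True is parsed as a LazyVariable named "True".
old_call = LazyCall(
    "hsgp",
    [LazyVariable("Fall_Year")],
    {
        "by": LazyVariable("Age_Group"),
        "m": LazyValue(10),
        "c": LazyValue(2),
        "scale": LazyVariable("True"),
        "centered": LazyVariable("True"),
    },
)

# After the proposed change: True/False/None become LazyValue instances.
new_call = LazyCall(
    "hsgp",
    [LazyVariable("Fall_Year")],
    {
        "by": LazyVariable("Age_Group"),
        "m": LazyValue(10),
        "c": LazyValue(2),
        "scale": LazyValue(True),
        "centered": LazyValue(True),
    },
)

print(call_covariates(old_call))  # ['Fall_Year', 'Age_Group', 'True', 'True']
print(call_covariates(new_call))  # ['Fall_Year', 'Age_Group']
```

The first result shows why reading `kwargs` alone is not enough: without the Formulae-side change, `"True"` would be treated as a covariate name.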
@AlexAndorra I think it should be fixed now.
For that:
- Upgrade formulae. I released version 0.5.2
- Re-install Bambi from https://github.com/tomicapretto/bambi/tree/fix_hsgp_prediction.
I haven't tested it on my end, but I think the changes I implemented were the required ones :)
@tomicapretto ahh that makes sense. Thanks for the information and for fixing this in `formulae` and `bambi` already!