add a sampling class to support prior and posterior predictive sampling with HGF CustomDist
Hello @LegrandNico, I am working on applying HGFs to behavioural data and wanted to use the pyhgf package to sample some rather complex HGFs. It generally works really well; however, I am not sure whether (and how) classes/functions like HGFDistribution can be applied to such models.
If I am not mistaken, the HGFDistribution class (and, analogously, the Gradient and Pointwise classes) currently supports only 2- and 3-level ‘standard’ HGFs, since the parameters for the 2- and 3-level models are hardcoded. In particular, ‘general/arbitrary’ HGFs (e.g. with more nodes or multiple inputs) cannot be used directly with HGFDistribution and CustomDist, correct?
Was the purpose of this issue to extend the functionality?
I tried to achieve this by modifying the classes from HGFDistribution and adding functions that transform the HGF attributes into data structures compatible with JAX (and back), so that the classes accept the parameters to be sampled wrapped into tuples etc. This gives argument signatures like:
```python
class HGFDistributionDyn(Op):
    def __init__(
        self,
        new_attributes,
        input_data: ArrayLike = jnp.nan,
        time_steps: Optional[ArrayLike] = None,
        make_HGF=None,
        logp_function=dyn_logp,
        hgf_logp_function=dyn_hgf_logp,
        gradient_function=HGFLogpGradOp_Dyn,
        response_function: Optional[Callable] = None,
        response_function_inputs: Optional[ArrayLike] = None,
        observed_input=None,
    ):
        ...  # body not shown
```
or, for the log-probability function:

```python
def dyn_logp(
    new_attributes,
    response_function_parameters: ArrayLike,
    input_data: ArrayLike,
    time_steps: ArrayLike,
    make_HGF: Callable,
    response_function_inputs: Optional[Union[ArrayLike, tuple]],
    response_function: Callable = lambda: None,
    paths_attributes=None,
):
    ...  # body not shown
```
Within the modified logp functions, the HGF is ‘reconstructed’ from the tuples holding the variables to be sampled. Because the HGF object is no longer shared, this also makes it possible to use multiple cores and arbitrary inputs, e.g.:
```python
# create HGF
hgf = make_HGF()
paths_attributes_new = split_path_to_tuple(paths_attributes)

# modify
for i, path_tuple in enumerate(paths_attributes_new):
    new_value_to_set = new_attributes[i]
    set_nested_attr(hgf.attributes, path_tuple, new_value_to_set)
```
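The helper split_path_to_tuple is not shown in the thread. A minimal sketch, assuming paths_attributes is a list of dot-separated strings such as "1.tonic_volatility" (node index first, then parameter name), could look like this. The path format, the integer conversion of node keys, and the toy attributes dict are all assumptions for illustration; set_nested_attr uses the same logic as the version posted later in the thread.

```python
from typing import Any, Dict, List, Tuple, Union

Key = Union[int, str]


def split_path_to_tuple(paths: List[str], sep: str = ".") -> List[Tuple[Key, ...]]:
    """Hypothetical helper: turn "1.tonic_volatility" into (1, "tonic_volatility").

    Assumes node keys are integers and parameter names are strings.
    """
    result = []
    for path in paths:
        parts = path.split(sep)
        result.append(tuple(int(p) if p.isdigit() else p for p in parts))
    return result


def set_nested_attr(d: Dict[Any, Any], keys: Tuple[Key, ...], value: Any) -> None:
    # walk down to the innermost dict, then overwrite the last key in place
    for key in keys[:-1]:
        d = d[key]
    d[keys[-1]] = value


# toy stand-in for hgf.attributes: node index -> parameter dict
attributes = {
    0: {"mean": 0.0},
    1: {"tonic_volatility": -4.0},
    2: {"tonic_volatility": -4.0},
}
paths_attributes = ["1.tonic_volatility", "2.tonic_volatility"]
new_attributes = (-5.0, -6.0)

for i, path_tuple in enumerate(split_path_to_tuple(paths_attributes)):
    set_nested_attr(attributes, path_tuple, new_attributes[i])
```

Because each call to make_HGF() builds a fresh object, the in-place writes only affect that call's copy of the attributes.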
Finally, I rewrote the class methods without any hardcoded variables; they only pass on the values of the tuples containing the attributes to be sampled, e.g.:
```python
def grad(
    self, inputs: List[TensorVariable], output_gradients: List[TensorVariable]
) -> List[TensorVariable]:
    """Gradient of the function."""
    grads = self.hgf_logp_grad_op(inputs)
    output_gradient = output_gradients[0]
    return [output_gradient * g for g in grads]
```
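A gradient method like this can stay free of hardcoded names because jax.grad differentiates with respect to an arbitrary pytree: given a tuple of parameters, it returns a tuple with one gradient per element, however many there are. The sketch below uses a toy Gaussian log-likelihood as an illustrative stand-in for the HGF log-probability; toy_logp and its parameters are assumptions, not pyhgf code.

```python
import jax
import jax.numpy as jnp


def toy_logp(params, data):
    """Toy stand-in for the HGF log-probability: Gaussian log-likelihood.

    params is a tuple; here (mean, log_std), but any length works with jax.grad.
    """
    mean, log_std = params
    std = jnp.exp(log_std)
    return jnp.sum(
        -0.5 * ((data - mean) / std) ** 2 - log_std - 0.5 * jnp.log(2 * jnp.pi)
    )


data = jnp.array([0.1, -0.3, 0.5])
params = (jnp.array(0.0), jnp.array(0.0))

# jax.grad returns a tuple with the same structure as params
grads = jax.grad(toy_logp)(params, data)

# chain rule as in the Op's grad method: scale each gradient by the upstream gradient
output_gradient = 1.0
scaled = [output_gradient * g for g in grads]
```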
Now I would like to ask whether this is an appropriate (and working) way to modify the distribution classes so that they accept arbitrary networks and can be used with CustomDist, or whether I am missing something.
I checked it for the 2- and 3-level HGFs, as well as for the two-armed bandit from the tutorials, and it appears to give the same results. However, there may be reasons I am not aware of why I should not trust the results in more general cases.
Hi @jalhackl, thank you for pointing this out.
You are right that the distributions currently implemented in the toolbox only support standard HGFs. The point was to make it as easy as possible for people to sample these models.
However, there is no limit on the kind of models we can sample, and it is on my to-do list to write a tutorial on that at some point, showing how to sample any network/parameter.
The idea is to express your whole model as a JAX function returning the log probability, and to sample this function using a PyMC node, BlackJAX, or NumPyro. So my plan was to create a custom PyMC node that can sample any such function by automatically scaling to the number of parameters.
I have also tried an approach that was similar to the one you suggested, but for some reason, I couldn't make it work. Can you give more details on what set_nested_attr is doing?
set_nested_attr itself is very simple: all it does is set the value at the innermost level of the nested attributes dict (in which the outer dict usually indicates the node, and the inner one holds parameters such as tonic_volatility):
```python
def set_nested_attr(d, keys, value):
    for key in keys[:-1]:
        d = d[key]
    d[keys[-1]] = value
```
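Since set_nested_attr mutates the dict in place, every call changes the attributes of whatever object was passed in. If a template should stay untouched (for example to reuse one attributes dict across calls), a non-mutating variant that copies only the dicts along the path is one option. This is an alternative sketch; with_nested_attr is a hypothetical name, not code from the thread.

```python
from typing import Any, Dict, Hashable, Sequence


def set_nested_attr(d: Dict[Hashable, Any], keys: Sequence[Hashable], value: Any) -> None:
    """In-place update, as posted above."""
    for key in keys[:-1]:
        d = d[key]
    d[keys[-1]] = value


def with_nested_attr(
    d: Dict[Hashable, Any], keys: Sequence[Hashable], value: Any
) -> Dict[Hashable, Any]:
    """Return a copy of d with the value at the nested path replaced.

    Only the dicts on the path are copied; all other entries are shared with d.
    """
    if len(keys) == 1:
        return {**d, keys[0]: value}
    return {**d, keys[0]: with_nested_attr(d[keys[0]], keys[1:], value)}


attributes = {1: {"tonic_volatility": -4.0, "mean": 0.0}}
set_nested_attr(attributes, (1, "tonic_volatility"), -5.0)  # modifies attributes
updated = with_nested_attr(attributes, (1, "tonic_volatility"), -6.0)  # leaves it as is
```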
The more complex changes concern the preprocessing (and the distribution classes mentioned above). I start by creating dicts containing the values to sample, e.g.:
```python
vals_to_optimize = ["tonic_volatility_1", "tonic_volatility_2", "tonic_volatility_3", …]
param_dict = dict()
param_dict["tonic_volatility_1"] = {"name": "tonic_volatility_1", "mean": -5, "std": 2, "distro": "Normal"}
param_dict["tonic_volatility_2"] = {"name": "tonic_volatility_2", "mean": -5, "std": 2, "distro": "Normal"}
param_dict["mean_0"] = {"name": "mean_0", "min_value": 0, "max_value": 1, "distro": "Uniform"}
param_dict["mean_1"] = {"name": "mean_1", "min_value": 0, "max_value": 1, "distro": "Uniform"}

response_param_dict = dict()
response_vals_to_optimize = ["inverse_temperature_0"]
response_param_dict["inverse_temperature_0"] = {"name": "inverse_temperature_0", "min_value": 2, "max_value": 10, "distro": "Uniform"}
```
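Each entry above describes a prior by its "distro" name plus the parameters for that distribution. One way to turn such an entry into a PyMC prior without a growing if/elif chain is a small mapping from the entry's keys to the keyword arguments of the named PyMC distribution (pm.Normal takes mu and sigma, pm.Uniform takes lower and upper). prior_kwargs is a hypothetical helper, and only Normal and Uniform are covered.

```python
from typing import Any, Dict


def prior_kwargs(entry: Dict[str, Any]) -> Dict[str, Any]:
    """Map one param_dict entry to keyword arguments for its PyMC distribution."""
    distro = entry["distro"]
    if distro == "Normal":
        return {"mu": entry["mean"], "sigma": entry["std"]}
    if distro == "Uniform":
        return {"lower": entry["min_value"], "upper": entry["max_value"]}
    raise ValueError(f"Unsupported distribution: {distro}")


entry = {"name": "inverse_temperature_0", "min_value": 2, "max_value": 10, "distro": "Uniform"}
kwargs = prior_kwargs(entry)
```

Inside a pm.Model() context, a prior could then be created with getattr(pm, entry["distro"])(entry["name"], **prior_kwargs(entry)), which keeps the distribution name and the keys it reads in one place.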
Later on, I initialize the modified classes and define the logp functions:
```python
hgf_logp_op_dyn = HGFDistributionDynCustomDist(
    new_attributes=new_attributes_list_all,
    input_data=input_data,
    make_HGF=make_HGF,
    response_function=response_function,
    response_function_inputs=observed,
    logp_function=logp_function,
    observed_input=observed_rewards,
    response_function_params_length=response_function_parameters_length,
    masking_value=masking_value,
    time_steps=time_steps,
)

hgf_logp_op_dyn_pointwise = HGFDistributionDynCustomDistPointwiseDyn(
    new_attributes=new_attributes_list_all,
    input_data=input_data,
    make_HGF=make_HGF,
    response_function=response_function,
    response_function_inputs=observed,
    logp_function=logp_function,
    observed_input=observed_rewards,
    response_function_params_length=response_function_parameters_length,
    masking_value=masking_value,
    time_steps=time_steps,
)


def logp(value, new_attributes_list_all, response_function_parameters):
    return hgf_logp_op_dyn(
        new_attributes=new_attributes_list_all,
        response_function_parameters=response_function_parameters,
    )


def logp_pointwise(new_attributes_list_all, response_function_parameters):
    return hgf_logp_op_dyn_pointwise(
        new_attributes=new_attributes_list_all,
        response_function_parameters=response_function_parameters,
    )
```
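The logp / logp_pointwise pair relies on the total log-likelihood being the sum of per-trial terms. A minimal sketch of a softmax response function for a two-armed bandit (a hypothetical stand-in for response_function, with the inverse temperature as its only parameter and made-up beliefs) shows both quantities:

```python
import jax
import jax.numpy as jnp


def softmax_pointwise_logp(expected_values, choices, inverse_temperature):
    """Per-trial log-likelihood of the observed choices under a softmax rule.

    expected_values: (n_trials, n_arms) beliefs, e.g. the HGF means for each arm
    choices: (n_trials,) integer index of the chosen arm on each trial
    """
    log_probs = jax.nn.log_softmax(inverse_temperature * expected_values, axis=-1)
    # pick the log-probability of the arm that was actually chosen
    return jnp.take_along_axis(log_probs, choices[:, None], axis=-1)[:, 0]


expected_values = jnp.array([[0.5, 0.5], [0.8, 0.2], [0.3, 0.7]])
choices = jnp.array([0, 0, 1])

pointwise = softmax_pointwise_logp(expected_values, choices, 3.0)  # one value per trial
total = jnp.sum(pointwise)  # the scalar a total-logp function would return
```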
I then use the dicts in the sampling function, i.e. all variables in vals_to_optimize are sampled:
```python
with pm.Model() as a_hgf_model:
    # response parameters
    if response_param_dict:
        response_function_parameters = []
        for dict_entry in response_param_dict:
            if response_param_dict[dict_entry]["distro"] == "Normal":
                new_response_sample_param = pm.Normal(
                    response_param_dict[dict_entry]["name"],
                    mu=response_param_dict[dict_entry]["mean"],
                    sigma=response_param_dict[dict_entry]["std"],
                )
            elif response_param_dict[dict_entry]["distro"] == "Uniform":
                […]

    # sample parameters
    sample_params = []
    sample_params_list = []
    for dict_entry in param_dict.values():
        tuple_index = find_tuple_index(flat_tuples, dict_entry["name"], sep=".")
        if tuple_index > -1:
            # pad the list so that the entry can be stored at its tuple index
            while len(sample_params_list) <= tuple_index:
                sample_params_list.append(None)
            sample_params_list[tuple_index] = dict_entry
    sample_params_list = [v for v in sample_params_list if v is not None]
    […] sample_params.append([dict_entry["name"], new_sample_param])

    # write the sampled variables into the attribute list at their tuple index
    for entry in sample_params:
        tuple_index = find_tuple_index(flat_tuples, entry[0], sep=".")
        if tuple_index > -1:
            key, _ = new_attributes_list_all[tuple_index]
            new_attributes_list_all[tuple_index] = (key, entry[1])
            param_values[tuple_index] = entry[1]

    # sampling step
    log_likelihood = pm.CustomDist(
        "log_likelihood",
        param_values,
        response_function_parameters,
        logp=logp,
        observed=input_data,
    )
    pm.Deterministic(
        "pointwise_loglikelihood",
        logp_pointwise(
            new_attributes_list_all=param_values,
            response_function_parameters=response_function_parameters,
        ),
    )
    levels_idata = pm.sample(chains=chains, cores=cores)

levels_idata.add_groups(
    log_likelihood=levels_idata.posterior["pointwise_loglikelihood"]
)
```
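The None-padding loop in the listing orders the param_dict entries by their position in flat_tuples. The same ordering can be obtained by sorting on the index directly. In this sketch, find_tuple_index is replaced by a hypothetical name_to_index lookup (the real helper and the layout of flat_tuples are not shown in the thread), and each name is assumed to map to a distinct index.

```python
from typing import Any, Dict, List


def order_by_index(
    param_dict: Dict[str, Dict[str, Any]], name_to_index: Dict[str, int]
) -> List[Dict[str, Any]]:
    """Return param_dict entries sorted by their position in the flattened attributes.

    Entries whose name has no index (the -1 case in the listing) are skipped.
    """
    indexed = [
        (name_to_index[entry["name"]], entry)
        for entry in param_dict.values()
        if entry["name"] in name_to_index
    ]
    return [entry for _, entry in sorted(indexed, key=lambda pair: pair[0])]


param_dict = {
    "tonic_volatility_2": {"name": "tonic_volatility_2", "mean": -5, "std": 2, "distro": "Normal"},
    "mean_0": {"name": "mean_0", "min_value": 0, "max_value": 1, "distro": "Uniform"},
    "tonic_volatility_1": {"name": "tonic_volatility_1", "mean": -5, "std": 2, "distro": "Normal"},
    "mean_1": {"name": "mean_1", "min_value": 0, "max_value": 1, "distro": "Uniform"},
}
# hypothetical positions in the flattened attribute list; mean_1 is not sampled here
name_to_index = {"mean_0": 0, "tonic_volatility_1": 3, "tonic_volatility_2": 5}
ordered = order_by_index(param_dict, name_to_index)
```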
As far as I can tell, this set-up works with arbitrary HGFs and arbitrary variables to be sampled/optimized. If you like, I can give you access to my GitHub repository in which I use the modified functions; that is probably easier than going through the code snippets.