pyhgf

add a sampling class to support prior and posterior predictive sampling with HGF CustomDist

Open · LegrandNico opened this issue on Jul 04 '24 11:07 · 3 comments

Hello @LegrandNico, I am applying HGFs to behavioural data and wanted to use the pyhgf package to sample some rather complex HGFs. It generally works really well; however, I am not sure whether (and how) classes/functions like HGFDistribution can be applied to such models.

If I am not mistaken, the HGFDistribution class (and, analogously, the Gradient and Pointwise classes) currently only supports 2- and 3-level ‘standard’ HGFs, since the parameters for the 2- and 3-level models are hardcoded? In particular, ‘general/arbitrary’ HGFs (e.g. containing more nodes or multiple inputs) cannot be used directly with HGFDistribution and CustomDist?

Was the purpose of this issue to extend the functionality?

I tried to achieve this by modifying the HGFDistribution classes and adding functions that transform the HGF attributes into data structures compatible with JAX (and back), so that the classes accept the parameters to be sampled wrapped in tuples, etc. This gives argument signatures like:

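from typing import Callable, Optional

import jax.numpy as jnp
from jax.typing import ArrayLike
from pytensor.graph.op import Op

# dyn_logp, dyn_hgf_logp and HGFLogpGradOp_Dyn are the modified functions/classes
# described in this comment; they must be defined before this class.


class HGFDistributionDyn(Op):
    def __init__(
        self,
        new_attributes,  # values of the HGF attributes to be sampled
        input_data: ArrayLike = jnp.nan,
        time_steps: Optional[ArrayLike] = None,
        make_HGF: Optional[Callable] = None,  # factory that builds the HGF network
        logp_function: Callable = dyn_logp,
        hgf_logp_function: Callable = dyn_hgf_logp,
        gradient_function=HGFLogpGradOp_Dyn,
        response_function: Optional[Callable] = None,
        response_function_inputs: Optional[ArrayLike] = None,
        observed_input=None,
    ):
        ...  # body omitted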

or:


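from typing import Callable, Optional, Union

from jax.typing import ArrayLike


def dyn_logp(
    new_attributes,  # values to write into the freshly built HGF
    response_function_parameters: ArrayLike,
    input_data: ArrayLike,
    time_steps: ArrayLike,
    make_HGF: Callable,  # called inside the function to build a new HGF
    response_function_inputs: Optional[Union[ArrayLike, tuple]],
    response_function: Callable = lambda: None,
    paths_attributes=None,  # attribute paths matching new_attributes
):
    ...  # body omitted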
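The conversion between the nested HGF attributes and flat tuples mentioned above can be sketched as a round trip. This is a minimal pure-Python sketch, assuming (as described later in this thread) that the attributes are a nested dict of the form `{node_index: {parameter_name: value}}`; `flatten_attributes` and `unflatten_attributes` are hypothetical names, not pyhgf functions or the helpers used in this issue. Sorting the keys fixes the order, so a given position in `values` always refers to the same parameter.

```python
from typing import Any, Dict, Hashable, Tuple

Path = Tuple[Hashable, str]


def flatten_attributes(
    attributes: Dict[Hashable, Dict[str, Any]],
) -> Tuple[Tuple[Path, ...], Tuple[Any, ...]]:
    """Nested {node: {param: value}} -> (paths, values), two parallel tuples."""
    paths, values = [], []
    for node in sorted(attributes):  # deterministic node order
        for param in sorted(attributes[node]):  # deterministic parameter order
            paths.append((node, param))
            values.append(attributes[node][param])
    return tuple(paths), tuple(values)


def unflatten_attributes(paths, values) -> Dict[Hashable, Dict[str, Any]]:
    """Inverse of flatten_attributes: rebuild the nested dict from the two tuples."""
    attributes: Dict[Hashable, Dict[str, Any]] = {}
    for (node, param), value in zip(paths, values):
        attributes.setdefault(node, {})[param] = value
    return attributes


attrs = {1: {"tonic_volatility": -4.0, "mean": 0.0}, 0: {"mean": 0.5}}
paths, values = flatten_attributes(attrs)
```

For arbitrary JAX pytrees, `jax.tree_util.tree_flatten` and `jax.tree_util.tree_unflatten` provide the same kind of round trip.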

Within the modified logp functions, the HGF is ‘reconstructed’ from tuples holding the variables to be sampled. This also makes it possible to use multiple cores (because the HGF object is no longer shared) and arbitrary inputs, e.g.:


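    # create a fresh HGF network for this call
    hgf = make_HGF()
    # split the attribute paths into key tuples (helper not shown)
    paths_attributes_new = split_path_to_tuple(paths_attributes)
    # overwrite each sampled value at its (node, parameter) path
    for i, path_tuple in enumerate(paths_attributes_new):
        new_value_to_set = new_attributes[i]
        set_nested_attr(hgf.attributes, path_tuple, new_value_to_set)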
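A self-contained sketch of this reconstruct-per-call pattern, with a plain dict standing in for `hgf.attributes`. It assumes the paths are dotted strings such as `"1.tonic_volatility"`; `split_path_to_tuple` below is a hypothetical reimplementation (the original helper is not shown), and `make_attributes` is a hypothetical stand-in for `make_HGF()`. Pairing paths and values with `zip` is equivalent to the `enumerate` loop above.

```python
from typing import Any, Dict, Sequence, Tuple


def split_path_to_tuple(paths: Sequence[str], sep: str = ".") -> Tuple[tuple, ...]:
    """Hypothetical helper: "1.tonic_volatility" -> (1, "tonic_volatility")."""
    return tuple(
        # numeric components are treated as node indices, the rest as names
        tuple(int(part) if part.lstrip("-").isdigit() else part for part in path.split(sep))
        for path in paths
    )


def make_attributes() -> Dict[int, Dict[str, Any]]:
    """Stand-in for make_HGF(): returns a brand-new attributes dict on every call."""
    return {0: {"mean": 0.0}, 1: {"tonic_volatility": -4.0, "mean": 0.0}}


def reconstruct(new_attributes: Sequence[Any], paths_attributes: Sequence[str]):
    """Rebuild the model state from flat values, as in the modified logp above."""
    attributes = make_attributes()  # fresh object: no state shared between calls
    for path_tuple, value in zip(split_path_to_tuple(paths_attributes), new_attributes):
        node = attributes
        for key in path_tuple[:-1]:  # walk down to the innermost dict
            node = node[key]
        node[path_tuple[-1]] = value
    return attributes


# two calls proposing different values never see each other's changes
a = reconstruct((-6.0, 0.3), ["1.tonic_volatility", "0.mean"])
b = reconstruct((-2.0, 0.7), ["1.tonic_volatility", "0.mean"])
```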


Finally, I rewrote the class methods without any hardcoded variables, passing only the values of the tuples that contain the attributes to be sampled, e.g.:

    def grad(
        self, inputs: List[TensorVariable], output_gradients: List[TensorVariable]
    ) -> List[TensorVariable]:
        """Gradient of the function."""
        grads = self.hgf_logp_grad_op(inputs)

        output_gradient = output_gradients[0]

        return [output_gradient * g for g in grads]
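Why `[output_gradient * g for g in grads]` works for any number of parameters: PyTensor's `Op.grad` receives the gradient of the downstream cost with respect to the Op's (scalar) output, and by the chain rule the gradient for each input is that value times the input's partial derivative. A small check, with a hypothetical quadratic `toy_logp` standing in for the HGF log-probability, compared against central finite differences:

```python
from typing import Callable, List, Sequence


def toy_logp(params: Sequence[float]) -> float:
    """Scalar function of any number of inputs (stand-in for the HGF log-probability)."""
    return -0.5 * sum(p ** 2 for p in params)


def toy_logp_grad(params: Sequence[float]) -> List[float]:
    """Analytic partial derivative of toy_logp with respect to each input."""
    return [-p for p in params]


def grad(params: Sequence[float], output_gradient: float) -> List[float]:
    """Same rule as Op.grad above: one gradient per input, scaled by the upstream gradient."""
    return [output_gradient * g for g in toy_logp_grad(params)]


def finite_difference(
    f: Callable[[List[float]], float], params: Sequence[float], eps: float = 1e-5
) -> List[float]:
    """Central-difference estimate of each partial derivative of f."""
    estimates = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        estimates.append((f(up) - f(down)) / (2 * eps))
    return estimates


# downstream cost 3 * logp, so the upstream gradient is 3.0
params = [0.5, -1.0, 2.0]
analytic = grad(params, 3.0)
numeric = finite_difference(lambda p: 3.0 * toy_logp(p), params)
```

Nothing in `grad` depends on how many parameters there are, which is what lets the hardcoded argument lists be dropped.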

Now I would like to ask: is this an appropriate (and working) way to modify the distribution classes so that they accept arbitrary networks and can be used with CustomDist, or am I missing something?
I checked it for the 2- and 3-level HGFs, as well as for the two-armed bandit from the tutorials, and it seems to give the same results; but perhaps there are reasons unknown to me why I should not trust the results in more general cases...

jalhackl, Sep 04 '25 12:09

Hi @jalhackl, thank you for pointing this out.

You are right that the distributions currently implemented in the toolbox only support standard HGFs. The point was to make it as easy as possible for people to sample these models.

However, there is no limit on the kind of models we can sample, and it is on my to-do list to write a tutorial at some point showing how to sample any network/parameter.

The idea is to express your whole model as a JAX function returning the log probability, and to sample this function using a PyMC node, BlackJAX, or NumPyro. So my plan was to create a custom PyMC node that can sample any such function, automatically scaling to the number of parameters.
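To make the idea concrete, here is a minimal, dependency-free sketch. `log_prob` is a toy model (a Normal likelihood with a Normal prior) standing in for an HGF log-probability written in JAX, and `random_walk_metropolis` is a simple random-walk Metropolis sampler standing in for PyMC, BlackJAX, or NumPyro, which would typically use gradient-based NUTS instead. Both names are hypothetical. The sampler only sees a function from a dict of named parameters to a scalar, so it works unchanged for any number of parameters.

```python
import math
import random
from typing import Callable, Dict, List


def log_prob(params: Dict[str, float], data: List[float]) -> float:
    """Toy log joint density (up to a constant): Normal(0, 10) prior, Normal(mu, 1) likelihood."""
    mu = params["mu"]
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = sum(-0.5 * (x - mu) ** 2 for x in data)
    return log_prior + log_lik


def random_walk_metropolis(
    logp: Callable[[Dict[str, float]], float],
    init: Dict[str, float],
    n_steps: int = 5000,
    step_size: float = 0.5,
    seed: int = 0,
) -> List[Dict[str, float]]:
    """Sample any logp defined over a dict of parameters, whatever its size."""
    rng = random.Random(seed)
    current = dict(init)
    current_logp = logp(current)
    samples = []
    for _ in range(n_steps):
        # propose a Gaussian step for every parameter at once
        proposal = {k: v + rng.gauss(0.0, step_size) for k, v in current.items()}
        proposal_logp = logp(proposal)
        # Metropolis acceptance rule (symmetric proposal); 1 - random() lies in (0, 1]
        if math.log(1.0 - rng.random()) < proposal_logp - current_logp:
            current, current_logp = proposal, proposal_logp
        samples.append(dict(current))
    return samples


data = [1.8, 2.2, 2.0, 1.9, 2.1]
samples = random_walk_metropolis(lambda p: log_prob(p, data), init={"mu": 0.0})
kept = samples[1000:]  # discard burn-in
posterior_mean = sum(s["mu"] for s in kept) / len(kept)
```

For this conjugate toy model the posterior mean is known analytically (10.0 / 5.01, about 1.996), so `posterior_mean` should land close to it.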

I have also tried an approach similar to the one you suggested, but for some reason I couldn't make it work. Can you give more details on what set_nested_attr is doing?

LegrandNico, Sep 11 '25 12:09

set_nested_attr itself is very simple: all it does is set the value at the innermost level of the nested attributes dict (in which the outer keys usually indicate the node, and the inner keys parameters such as tonic_volatility):


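def set_nested_attr(d, keys, value):
    """Set d[keys[0]][keys[1]]...[keys[-1]] = value, modifying d in place."""
    for key in keys[:-1]:
        d = d[key]  # descend one level per key
    d[keys[-1]] = value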
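`set_nested_attr` modifies the dict in place, which is safe here because a fresh HGF is built for every logp call. If a single attributes dict were ever reused (for example across chains or cores), a non-mutating variant in the functional style JAX favours would avoid shared state. A minimal recursive sketch; `with_nested_attr` is a hypothetical name, not part of pyhgf:

```python
from typing import Any, Hashable, Mapping, Sequence


def with_nested_attr(d: Mapping, keys: Sequence[Hashable], value: Any) -> dict:
    """Return a copy of d with d[keys[0]]...[keys[-1]] replaced; d itself is untouched."""
    updated = dict(d)  # shallow copy of this level only
    if len(keys) == 1:
        updated[keys[0]] = value
    else:
        # copy only the branch along the path; other branches are shared
        updated[keys[0]] = with_nested_attr(d[keys[0]], keys[1:], value)
    return updated


template = {0: {"mean": 0.5}, 1: {"tonic_volatility": -4.0, "mean": 0.0}}
updated = with_nested_attr(template, (1, "tonic_volatility"), -6.0)
```

Only the dicts along the updated path are copied, so the cost stays proportional to the nesting depth rather than to the size of the network.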

The more complex changes concern the preprocessing (and the distribution classes mentioned above). I start by creating dicts describing the values to sample, e.g.:

vals_to_optimize = ["tonic_volatility_1", "tonic_volatility_2", "tonic_volatility_3", …]

param_dict = dict()
param_dict["tonic_volatility_1"] = {"name": "tonic_volatility_1", "mean": -5, "std": 2, "distro": "Normal"}
param_dict["tonic_volatility_2"] = {"name": "tonic_volatility_2", "mean": -5, "std": 2, "distro": "Normal"}
param_dict["mean_0"] = {"name": "mean_0", "min_value": 0, "max_value": 1, "distro": "Uniform"}
param_dict["mean_1"] = {"name": "mean_1", "min_value": 0, "max_value": 1, "distro": "Uniform"}
response_param_dict = dict()
response_vals_to_optimize = ["inverse_temperature_0"]
response_param_dict["inverse_temperature_0"] = {"name": "inverse_temperature_0", "min_value": 2, "max_value": 10, "distro": "Uniform"}
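One way to see what each entry encodes is to evaluate the log-density it describes: `"Normal"` entries carry `mean`/`std`, `"Uniform"` entries carry `min_value`/`max_value`. A minimal plain-Python sketch; `prior_logp` is a hypothetical helper for illustration, whereas in the model below these entries are turned into PyMC priors instead:

```python
import math
from typing import Any, Dict


def prior_logp(spec: Dict[str, Any], x: float) -> float:
    """Log-density of x under the prior described by a param_dict entry."""
    if spec["distro"] == "Normal":
        z = (x - spec["mean"]) / spec["std"]
        return -0.5 * z * z - math.log(spec["std"] * math.sqrt(2.0 * math.pi))
    if spec["distro"] == "Uniform":
        lo, hi = spec["min_value"], spec["max_value"]
        # constant density inside the bounds, zero (log = -inf) outside
        return -math.log(hi - lo) if lo <= x <= hi else -math.inf
    raise ValueError(f"unsupported distro: {spec['distro']!r}")


tonic_volatility_prior = {"name": "tonic_volatility_1", "mean": -5, "std": 2, "distro": "Normal"}
inverse_temperature_prior = {"name": "inverse_temperature_0", "min_value": 2, "max_value": 10, "distro": "Uniform"}
```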

Later on, I initialize the modified classes:


    hgf_logp_op_dyn = HGFDistributionDynCustomDist(
        new_attributes=new_attributes_list_all,
        input_data=input_data,
        make_HGF=make_HGF,
        response_function=response_function,
        response_function_inputs=observed,
        logp_function=logp_function,
        observed_input=observed_rewards,
        response_function_params_length=response_function_parameters_length,
        masking_value=masking_value,
        time_steps=time_steps,
    )

    hgf_logp_op_dyn_pointwise = HGFDistributionDynCustomDistPointwiseDyn(
        new_attributes=new_attributes_list_all,
        input_data=input_data,
        make_HGF=make_HGF,
        response_function=response_function,
        response_function_inputs=observed,
        logp_function=logp_function,
        observed_input=observed_rewards,
        response_function_params_length=response_function_parameters_length,
        masking_value=masking_value,
        time_steps=time_steps,
    )

    def logp(value, new_attributes_list_all, response_function_parameters):
        return hgf_logp_op_dyn(
            new_attributes=new_attributes_list_all,
            response_function_parameters=response_function_parameters,
        )

    def logp_pointwise(new_attributes_list_all, response_function_parameters):
        return hgf_logp_op_dyn_pointwise(
            new_attributes=new_attributes_list_all,
            response_function_parameters=response_function_parameters,
        )
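In `logp` above, the first argument `value` is never used. PyMC calls a `CustomDist` logp as `logp(value, *dist_params)`, but in this setup the observed data were already handed to the Op when it was constructed, so `value` can be ignored. A plain-Python sketch of that closure, where `ToyLogpOp` is a hypothetical stand-in for the modified distribution class that computes a Gaussian log-likelihood up to a constant:

```python
from typing import Sequence


class ToyLogpOp:
    """Stand-in for the modified distribution Op: data are fixed at construction."""

    def __init__(self, input_data: Sequence[float]):
        self.input_data = list(input_data)

    def __call__(self, new_attributes, response_function_parameters) -> float:
        (mu,) = new_attributes
        (scale,) = response_function_parameters
        return sum(-0.5 * ((x - mu) / scale) ** 2 for x in self.input_data)


op = ToyLogpOp([1.0, 2.0, 3.0])


def logp(value, new_attributes_list_all, response_function_parameters):
    # `value` (the observed data passed by CustomDist) is unused:
    # the Op already holds the data it was built with
    return op(
        new_attributes=new_attributes_list_all,
        response_function_parameters=response_function_parameters,
    )
```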

Then I use the dicts in a function for sampling, so that all variables in vals_to_optimize are sampled:


with pm.Model() as a_hgf_model:
    # response parameters
    if response_param_dict:
        response_function_parameters = []
        for dict_entry in response_param_dict:
            if response_param_dict[dict_entry]["distro"] == "Normal":
                new_response_sample_param = pm.Normal(
                    response_param_dict[dict_entry]["name"],
                    mu=response_param_dict[dict_entry]["mean"],
                    sigma=response_param_dict[dict_entry]["std"],
                )
            elif response_param_dict[dict_entry]["distro"] == "Uniform":
[…]

    # sample parameters
    sample_params = []
    sample_params_list = []
    for dict_entry in param_dict.values():
        # place each entry at the position of its attribute path, then drop the gaps
        tuple_index = find_tuple_index(flat_tuples, dict_entry["name"], sep=".")
        if tuple_index > -1:
            while len(sample_params_list) <= tuple_index:
                sample_params_list.append(None)
            sample_params_list[tuple_index] = dict_entry
    sample_params_list = [v for v in sample_params_list if v is not None]
[…]
        sample_params.append([dict_entry["name"], new_sample_param])

    # write the sampled variables into the attribute list at their path positions
    for entry in sample_params:
        tuple_index = find_tuple_index(flat_tuples, entry[0], sep=".")
        if tuple_index > -1:
            key, _ = new_attributes_list_all[tuple_index]
            new_attributes_list_all[tuple_index] = (key, entry[1])
            param_values[tuple_index] = entry[1]

    # sampling step
    log_likelihood = pm.CustomDist(
        "log_likelihood",
        param_values,
        response_function_parameters,
        logp=logp,
        observed=input_data,
    )
    pm.Deterministic(
        "pointwise_loglikelihood",
        logp_pointwise(
            new_attributes_list_all=param_values,
            response_function_parameters=response_function_parameters,
        ),
    )
    levels_idata = pm.sample(chains=chains, cores=cores)

# store the pointwise log-likelihood for model comparison (e.g. LOO)
levels_idata.add_groups(
    log_likelihood=levels_idata.posterior["pointwise_loglikelihood"]
)
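The loop that builds `sample_params_list` (pad with `None`, assign by index, filter out the gaps) effectively orders the parameter entries by the position of their attribute in `flat_tuples`. Assuming no two entries map to the same index, the same order can be obtained with a single `sorted` call. In the sketch below, `find_tuple_index` is a hypothetical reimplementation that simply looks a name up in a flat list of names (the original helper, which also takes `sep="."`, is not shown):

```python
from typing import Any, Dict, List


def find_tuple_index(flat_names: List[str], name: str) -> int:
    """Hypothetical stand-in: position of `name` in flat_names, or -1 if absent."""
    return flat_names.index(name) if name in flat_names else -1


def order_by_padding(param_dict: Dict[str, Dict[str, Any]], flat_names: List[str]):
    """The pad-with-None-then-filter approach used in the model code."""
    ordered: List[Any] = []
    for entry in param_dict.values():
        idx = find_tuple_index(flat_names, entry["name"])
        if idx > -1:
            while len(ordered) <= idx:
                ordered.append(None)
            ordered[idx] = entry
    return [entry for entry in ordered if entry is not None]


def order_by_sorting(param_dict: Dict[str, Dict[str, Any]], flat_names: List[str]):
    """Same result: drop unmatched entries, then sort by path position."""
    matched = [e for e in param_dict.values() if find_tuple_index(flat_names, e["name"]) > -1]
    return sorted(matched, key=lambda e: find_tuple_index(flat_names, e["name"]))


flat_names = ["mean_0", "tonic_volatility_1", "mean_1", "tonic_volatility_2"]
param_dict = {
    name: {"name": name}
    for name in ["tonic_volatility_1", "tonic_volatility_2", "mean_0", "mean_1", "not_in_model"]
}
```

One behavioural difference: if two entries ever mapped to the same index, the padding version would silently keep only the last one, whereas sorting keeps both.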

As far as I can tell, this setup works with arbitrary HGFs and arbitrary variables to be sampled/optimized. If you want, I can give you access to my GitHub repository in which I use the modified functions; that is probably easier than following these code snippets.

jalhackl, Oct 03 '25 19:10