
feat: custom modifiers for pure scalar functions with jax and sympy

Open lukasheinrich opened this issue 9 months ago • 2 comments

Description

This is a PR to add a "pure function" modifier as an external / optional contribution under the contrib directory.

This should support modifiers of the form

{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'theta_0 + theta_1'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'sqrt(theta_0)'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu + theta_0'}},

etc.

This is built on top of sympy and jax and thus requires pyhf.set_backend('jax'), as well as possibly a sympy extra (or we assume sympy is externally installed for parsing the expressions).

The supported formulas can involve either pre-existing parameters (e.g. a mu coming from a normfactor) or new parameters that are added by the modifier.
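
For reference, a minimal sketch of how such a formula string could be turned into a jax-compatible callable (the helper and variable names are illustrative and not the PR's actual implementation; it assumes a sympy version whose lambdify supports the 'jax' target):

import sympy
import jax.numpy as jnp

def compile_formula(formula, parameter_names):
    # parse the formula string into a sympy expression
    expr = sympy.sympify(formula)
    symbols = [sympy.Symbol(name) for name in parameter_names]
    # generate a callable from the expression
    # (assumes sympy's lambdify supports the 'jax' module target)
    return sympy.lambdify(symbols, expr, modules='jax')

scale = compile_formula('mu**2', ['mu'])
scale(jnp.array(2.0))  # evaluates to 4.0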

Basic Usage


spec = {
    'channels': [
        {
            'name': 'channel1',
            'samples': [
                {
                    'name': 'signal1',
                    'data': [10, 10, 10],
                    'modifiers': [
                        {'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu**2'}},
                    ],
                },
            ],
        }
    ]
}

import jax.numpy as jnp
import numpy as np
import pyhf
from pyhf.contrib.extended_modifiers import purefunc

# the purefunc modifier requires the jax backend
pyhf.set_backend('jax')

# enable the custom modifier set and pass it to the model
modifier_set = purefunc.enable()
m = pyhf.Model(
    spec,
    modifier_set=modifier_set,
    poi_name='mu',
    validate=False,
)

# set mu = 2.0 and evaluate the expected data
pars = np.array(m.config.suggested_init())
pars[m.config.par_slice('mu')] = 2.0
pars = jnp.array(pars)
m.expected_actualdata(pars)

Note that this will heavily rely on jax.jit, so care should be taken (by us/the user) that the model is properly jitted before fitting.
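
For instance, a minimal sketch of what "properly jitted" could look like on the user side (an assumption about usage, not a prescribed API, given the jax backend is active and pars is a jax array as above):

import jax

# jit-compile the model evaluation once; subsequent calls with
# same-shaped parameter arrays reuse the compiled function
expected = jax.jit(m.expected_actualdata)
expected(pars)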

Tagging @mswiatlo @alexander-held @matthewfeickert @kratsg @nhartman94 @malin-horstmann

lukasheinrich avatar Mar 17 '25 10:03 lukasheinrich

This shouldn't really depend only on jax to work. sympy allows you to parse out to different "backends" as needed, although I think for the numpy case one needs to swap to numexpr just for that part (which can't really parse symbolically, just evaluate). For example, if you convert to pytensor, then it's a quick step over to pytorch/jax/tensorflow: https://github.com/scipp-atlas/pyhs3/blob/207fffae43b69528c27b26684955787e29206120/src/pyhs3/parse.py#L42-L88
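
A rough sketch of that direction, picking the lambdify target from pyhf's active backend (the mapping and helper are illustrative assumptions; sympy's lambdify natively knows 'numpy' and 'tensorflow' and, in recent versions, 'jax', while a pytorch route would need an extra step such as the pytensor conversion linked above):

import sympy
import pyhf

# illustrative mapping from pyhf backend names to lambdify targets
_LAMBDIFY_MODULES = {'numpy': 'numpy', 'jax': 'jax', 'tensorflow': 'tensorflow'}

def compile_for_active_backend(formula, parameter_names):
    # assumes the active backend object exposes a 'name' attribute
    module = _LAMBDIFY_MODULES[pyhf.tensorlib.name]
    symbols = [sympy.Symbol(name) for name in parameter_names]
    return sympy.lambdify(symbols, sympy.sympify(formula), modules=module)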

Looks good, but addressing @alexander-held's point about nested dependencies requires building a tree and resolving the dependencies in order (which takes a bit more logic). That might need to be deferred to a separate PR. In the meantime, you can always just expand out the function calls yourself, or do a "pre-parsing" pass that expands out all nested parameter definitions first, to normalize/flatten things.
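
A rough sketch of that pre-parsing idea (a hypothetical helper, not part of this PR): expand nested formula definitions in dependency order via sympy substitution, so each resulting formula refers only to real parameters.

from graphlib import TopologicalSorter
import sympy

def flatten_formulas(formulas):
    # formulas: dict mapping a derived name to its formula string,
    # e.g. {'a': 'mu**2', 'b': 'a + theta_0'}
    exprs = {name: sympy.sympify(f) for name, f in formulas.items()}
    # dependencies of each derived quantity on other derived quantities
    deps = {
        name: {str(s) for s in expr.free_symbols if str(s) in exprs}
        for name, expr in exprs.items()
    }
    # substitute in topological order so nested references are fully expanded
    for name in TopologicalSorter(deps).static_order():
        exprs[name] = exprs[name].subs(
            {sympy.Symbol(d): exprs[d] for d in deps[name]}
        )
    return {name: str(expr) for name, expr in exprs.items()}

flatten_formulas({'a': 'mu**2', 'b': 'a + theta_0'})
# 'b' becomes 'mu**2 + theta_0' (up to term ordering)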

kratsg avatar Mar 17 '25 16:03 kratsg

In #1991 we had a conversation about pyhf.experimental vs pyhf.contrib; perhaps this one should go into the former to clearly signal its current status?

alexander-held avatar Mar 17 '25 19:03 alexander-held