feat: custom modifiers for pure scalar functions with jax and sympy
Description
This is a PR to add a "pure function" modifier as an external / optional contribution under the contrib directory.
This should support modifiers of the form
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'theta_0 + theta_1'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'sqrt(theta_0)'}},
{'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu + theta_0'}},
etc
This is built on top of sympy and jax, and thus requires pyhf.set_backend('jax') as well as possibly a sympy extra (or we assume sympy is externally installed for the parsing of the expressions).
The supported formulas involve either pre-existing parameters (e.g. a mu coming from a normfactor) or new parameters that can be added.
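As a rough sketch of how such a formula string could be handled (the function name `parse_formula` here is illustrative, not the actual contrib API): sympy can parse the expression, report the free parameters, and lambdify it into a numerical function. A minimal example using the numpy backend for brevity:

```python
import numpy as np
import sympy

def parse_formula(formula):
    """Parse a purefunc-style formula string into (parameter names, callable).

    Illustrative sketch only; the actual contrib implementation may differ.
    """
    expr = sympy.sympify(formula)
    # The free symbols are the parameters the modifier depends on.
    params = sorted(expr.free_symbols, key=lambda s: s.name)
    # lambdify compiles the expression into a vectorized numerical function;
    # swapping in a jax-aware namespace would make it traceable by jax.jit.
    func = sympy.lambdify(params, expr, modules='numpy')
    return [p.name for p in params], func

names, func = parse_formula('mu + theta_0')
print(names)           # ['mu', 'theta_0']
print(func(2.0, 0.5))  # 2.5
```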
Basic Usage
spec = {
    'channels': [
        {
            'name': 'channel1',
            'samples': [
                {
                    'name': 'signal1',
                    'data': [10, 10, 10],
                    'modifiers': [
                        {'name': 'interpscale1', 'type': 'purefunc', 'data': {'formula': 'mu**2'}},
                    ],
                },
            ],
        }
    ]
}
import jax.numpy as jnp
import numpy as np
import pyhf
from pyhf.contrib.extended_modifiers import purefunc
pyhf.set_backend('jax')
modifier_set = purefunc.enable()
m = pyhf.Model(
spec,
modifier_set=modifier_set,
poi_name='mu',
validate=False,
)
pars = np.array(m.config.suggested_init())
pars[m.config.par_slice('mu')] = 2.0
pars = jnp.array(pars)
m.expected_actualdata(pars)
Note that this will rely heavily on jax.jit, so care should be taken (by us/the user) that the model is properly jitted before fitting.
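For example, one way to make sure compilation happens once and is reused across fit iterations (illustrative only; `m.expected_actualdata` is the pyhf model method from the usage example above, and the stand-in function below is a hypothetical substitute so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

# In practice one would wrap the real model evaluation once, e.g.:
#   expected = jax.jit(m.expected_actualdata)
# so repeated calls during a fit reuse the compiled function instead of
# re-tracing the sympy-derived expression each time.

# Self-contained stand-in showing the same pattern (hypothetical model):
@jax.jit
def expected(pars):
    # mimics a purefunc modifier 'mu**2' scaling a 3-bin sample of [10, 10, 10]
    return pars[0] ** 2 * jnp.array([10.0, 10.0, 10.0])

print(expected(jnp.array([2.0])))  # compiled on first call, cached afterwards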
Tagging @mswiatlo @alexander-held @matthewfeickert @kratsg @nhartman94 @malin-horstmann
This shouldn't really have to depend only on jax to work: sympy allows you to parse out to different "backends" as needed, although I think for the numpy case one needs to swap to numexpr just for that part (which is really not able to parse symbolically, just evaluate). For example, if you convert to pytensor, then it's a quick step over to switch to pytorch/jax/tensorflow: https://github.com/scipp-atlas/pyhs3/blob/207fffae43b69528c27b26684955787e29206120/src/pyhs3/parse.py#L42-L88
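To illustrate the backend point (a hedged sketch: the `modules` argument is standard sympy.lambdify usage, but availability of e.g. numexpr depends on the environment):

```python
import numpy as np
import sympy

theta = sympy.Symbol('theta_0')
expr = sympy.sympify('sqrt(theta_0) + 1')

# The same symbolic expression can be lowered to different numerical backends.
f_numpy = sympy.lambdify([theta], expr, modules='numpy')
# f_numexpr = sympy.lambdify([theta], expr, modules='numexpr')  # if installed
# A jax-aware printer/namespace would produce a jit-traceable callable.

print(f_numpy(np.array([4.0, 9.0])))  # [3. 4.]
```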
Looks good, but addressing @alexander-held's point about the nested dependencies requires building a tree and resolving the dependencies in order, which takes a bit more logic. That might need to be held off into a separate PR. In the meantime, you can always expand out the function calls yourself, or do a "pre-parsing" pass that expands out all nested parameter definitions first, to normalize/flatten things.
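A possible shape for that pre-parsing step (purely illustrative; `flatten_formulas` is a hypothetical helper, and it assumes the definitions contain no cycles): repeatedly substitute any symbol that is itself defined by another formula until only base parameters remain.

```python
import sympy

def flatten_formulas(formulas):
    """Expand nested formula definitions so each refers only to base parameters.

    `formulas` maps a derived name to its expression string; any symbol not in
    the mapping is treated as a base parameter. Illustrative sketch only;
    assumes the dependency graph is acyclic.
    """
    exprs = {name: sympy.sympify(src) for name, src in formulas.items()}
    flat = {}
    for name, expr in exprs.items():
        # Substitute derived symbols until none remain.
        while any(str(s) in exprs for s in expr.free_symbols):
            expr = expr.subs({s: exprs[str(s)] for s in expr.free_symbols
                              if str(s) in exprs})
        flat[name] = expr
    return flat

flat = flatten_formulas({
    'scale': 'mu * shift',
    'shift': 'theta_0 + 1',
})
print(flat['scale'])  # mu*(theta_0 + 1)
```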
In #1991 we had a conversation about pyhf.experimental vs pyhf.contrib, perhaps this one should go into the former to clearly signal its current status?