Support vanilla lists when the default backend is non-numpy
(Make builders diffable)
#1646 surfaces a bug that has been lurking in our code around assumptions about the default backend, and we'll need to fix it. This is just revealing an underlying feature we never actually supported:
>>> import json
>>> import pyhf
>>> ws = pyhf.Workspace(json.load(open('mysigfit_brZ_100_brH_0_brW_0_bre_33_brm_33_brt_34_mass_100.json')))
>>> pdf = ws.model()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/workspace.py", line 425, in model
return Model(modelspec, **config_kwargs)
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 632, in __init__
modifiers, _nominal_rates = _nominal_and_modifiers_from_spec(
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 129, in _nominal_and_modifiers_from_spec
nominal_rates = nominal.finalize()
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 69, in finalize
[
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 70, in <listcomp>
pyhf.default_backend.concatenate(self.mega_samples[sample]['nom'])
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/tensor/jax_backend.py", line 298, in concatenate
return jnp.concatenate(sequence, axis=axis)
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3382, in concatenate
_check_arraylike("concatenate", *arrays)
File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 560, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.
Originally posted by @kratsg in https://github.com/scikit-hep/pyhf/pull/1646#issuecomment-947110609
@kratsg @phinate can you comment a bit more on the specifics of what is causing this problem and/or create a public workspace that will fail for this?
If I take the workspaces from Discussion #1695
import json

import pyhf

if __name__ == "__main__":
    with open("NormalMeasurement_combined.json") as read_file:
        workspace_json = json.load(read_file)

    backends = ["numpy", "jax", "pytorch", "tensorflow"]
    for backend in backends:
        print(f"\n{backend}")
        pyhf.set_backend(backend)
        workspace = pyhf.Workspace(workspace_json)
        model = workspace.model()
        assert model is not None
those are fine, so can you summarize in the Issue what the problems are that you're encountering in Issue #1646 and PR #1655?
No problem -- the issue arises when using the new pyhf.set_backend(..., default=True) arg. @kratsg enabled this functionality in #1646, but it exposed issues where using e.g. jax as the default backend has unforeseen consequences that result from differing behaviour between jax and numpy.
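For context, a minimal reproduction sketch, assuming a workspace whose builders hand plain Python lists to the default backend (the filename below is just a placeholder):

import json

import pyhf

# Sketch only: make jax both the computational *and* the default backend,
# using the new default= arg from #1646.
pyhf.set_backend("jax", default=True)

# Placeholder filename; any workspace whose nominal rates end up as plain
# Python lists is assumed to hit the jnp.concatenate TypeError shown above.
with open("workspace.json") as read_file:
    workspace = pyhf.Workspace(json.load(read_file))

model = workspace.model()  # expected to raise TypeError in jax_backend.concatenate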
The specific problem I found was to do with tensorlib.concatenate -- details are also in #1655, but the bottom line is that jax.numpy.concatenate:
- does not support lists within the iterable of arrays
- does not support jagged concatenation
Code examples:
import numpy as np
np.concatenate([[4,5], [3]])
> array([4, 5, 3])
import jax.numpy as jnp
jnp.concatenate(jnp.array([[4,5], [3]]))
> /home/jovyan/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:476: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
arr = np.array(obj, dtype=dtype, **kwargs)
... (leaving out long stack trace)
TypeError: JAX only supports number and bool dtypes, got dtype object in array
This happens in pyhf possibly in multiple places, but the main occurrence I found was when using tensorlib.stitch. It's only possible to replicate the numpy behaviour here by iteratively casting to arrays within the iterable in jax_backend.py, but I also don't remember if this fixed the problem or maybe introduced some other pathology when I tried it...
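As an illustration of that casting idea, a minimal sketch (the helper name is hypothetical and not pyhf API):

import jax.numpy as jnp

# Hypothetical helper, not part of pyhf: cast each element of the sequence to
# a jax array before concatenating, mimicking numpy's tolerance for plain
# (and differently-sized) Python lists along the concatenation axis.
def concatenate_with_casting(sequence, axis=0):
    return jnp.concatenate([jnp.asarray(x) for x in sequence], axis=axis)

print(concatenate_with_casting([[4, 5], [3]]))  # [4 5 3], matching the numpy result above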
Hope this makes sense! @lukasheinrich was also encountering these things while working on #1676, so he may also have comments or a resolution :)