pyhf

Support vanilla lists when the default backend is non-numpy

Open · kratsg opened this issue 2 years ago · 2 comments

(Make builders diffable)

#1646 will expose a latent bug in our code that stems from assumptions about the default backend; we'll need to fix this.

This just reveals an underlying feature we never actually supported:

>>> import json
>>> ws = pyhf.Workspace(json.load(open('mysigfit_brZ_100_brH_0_brW_0_bre_33_brm_33_brt_34_mass_100.json')))
>>> pdf = ws.model()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/workspace.py", line 425, in model
    return Model(modelspec, **config_kwargs)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 632, in __init__
    modifiers, _nominal_rates = _nominal_and_modifiers_from_spec(
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 129, in _nominal_and_modifiers_from_spec
    nominal_rates = nominal.finalize()
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 69, in finalize
    [
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/pdf.py", line 70, in <listcomp>
    pyhf.default_backend.concatenate(self.mega_samples[sample]['nom'])
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/pyhf/tensor/jax_backend.py", line 298, in concatenate
    return jnp.concatenate(sequence, axis=axis)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3382, in concatenate
    _check_arraylike("concatenate", *arrays)
  File "/Users/kratsg/.pyenv/versions/pyhf-dev/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 560, in _check_arraylike
    raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: concatenate requires ndarray or scalar arguments, got <class 'list'> at position 0.
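The traceback comes down to jax.numpy.concatenate receiving plain Python lists (the 'nom' entries), which NumPy silently accepts. A minimal sketch of that difference, assuming NumPy is installed and JAX optionally; the nominal values are made up for illustration:

```python
import numpy as np

# Hypothetical nominal rates for two channels, kept as plain Python lists,
# like the 'nom' entries that reach default_backend.concatenate above.
nominal = [[4.0, 5.0], [3.0]]

# NumPy converts each list element to an array itself, so this works.
numpy_result = np.concatenate(nominal)
print(numpy_result)  # [4. 5. 3.]

try:
    import jax.numpy as jnp
except ImportError:  # JAX not installed: only the NumPy half runs
    jnp = None

if jnp is not None:
    try:
        jnp.concatenate(nominal)
    except TypeError as err:
        # Expected: JAX rejects list arguments, as in the traceback above
        print(f"jax: {err}")
```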

Originally posted by @kratsg in https://github.com/scikit-hep/pyhf/pull/1646#issuecomment-947110609

kratsg, Oct 19 '21 21:10

@kratsg @phinate can you comment a bit more on the specifics of what is causing this problem and/or create a public workspace that will fail for this?

If I take the workspaces from Discussion #1695 and run

import json

import pyhf

if __name__ == "__main__":

    with open("NormalMeasurement_combined.json") as read_file:
        workspace_json = json.load(read_file)

    backends = ["numpy", "jax", "pytorch", "tensorflow"]
    for backend in backends:
        print(f"\n{backend}")
        pyhf.set_backend(backend)
        workspace = pyhf.Workspace(workspace_json)
        model = workspace.model()
        assert model is not None

then they all load fine, so can you summarize in this Issue the problems you're encountering in PR #1646 and PR #1655?

matthewfeickert, Nov 12 '21 07:11


No problem. The issue arises when using the new default=True argument, i.e. pyhf.set_backend(..., default=True). @kratsg enabled this functionality in #1646, but it exposed cases where using, e.g., jax as the default backend has unforeseen consequences resulting from differing behaviour between jax and numpy.

The specific problem I found has to do with tensorlib.concatenate. Details are also in #1655, but the bottom line is that jax.numpy.concatenate:

  • does not support lists within the iterable of arrays
  • does not support jagged concatenation

Code examples:

import numpy as np

np.concatenate([[4,5], [3]])
> array([4, 5, 3])

import jax.numpy as jnp

jnp.concatenate(jnp.array([[4,5], [3]]))
> /home/jovyan/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:476: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  arr = np.array(obj, dtype=dtype, **kwargs)

... (leaving out long stack trace)

TypeError: JAX only supports number and bool dtypes, got dtype object in array
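Judging from the warning and the error, the failure in this example appears to happen while jnp.array turns the ragged nested list into a single array, before any concatenation runs. A NumPy-only sketch of that step (variable names are illustrative):

```python
import numpy as np

ragged = [[4, 5], [3]]  # the jagged input from the example above

# Built as ONE array, a ragged nested list can only become a 1-D array of
# Python objects (newer NumPy requires dtype=object explicitly, as the
# deprecation warning above anticipates). JAX refuses object dtypes.
as_single_array = np.array(ragged, dtype=object)
print(as_single_array.dtype, as_single_array.shape)  # object (2,)

# Kept as separate 1-D numeric arrays, the pieces differ in length but each
# has a numeric dtype, so concatenating along axis 0 is well defined.
pieces = [np.asarray(part) for part in ragged]
print(np.concatenate(pieces))  # [4 5 3]
```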

This possibly happens in multiple places in pyhf, but the main occurrence I found was when using tensorlib.stitch. It's only possible to replicate the numpy behaviour here by iteratively casting each element of the iterable to an array in jax_backend.py, but I don't remember whether that fixed the problem or introduced some other pathology when I tried it...
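The per-element casting workaround can be sketched as a small backend-agnostic helper. concatenate_elementwise is a hypothetical name, not code from pyhf's jax_backend.py; it takes the backend's asarray and concatenate functions as arguments so the sketch runs with NumPy alone:

```python
import numpy as np


def concatenate_elementwise(sequence, asarray, concatenate, axis=0):
    """Hypothetical helper: cast each element to a backend array, then concatenate.

    `asarray`/`concatenate` are backend functions, e.g. np.asarray/np.concatenate
    or jax.numpy.asarray/jax.numpy.concatenate. Casting element by element keeps
    ragged inputs as separate 1-D arrays instead of one object-dtype array.
    """
    return concatenate([asarray(element) for element in sequence], axis=axis)


result = concatenate_elementwise([[4, 5], [3]], np.asarray, np.concatenate)
print(result)  # [4 5 3]

try:
    import jax.numpy as jnp
except ImportError:  # JAX not installed: skip the JAX variant
    jnp = None

if jnp is not None:
    # Same call with JAX's functions: each list becomes a jax array first
    print(concatenate_elementwise([[4, 5], [3]], jnp.asarray, jnp.concatenate))
```

Whether this casting causes side effects elsewhere in pyhf (the open question above) is not something this sketch checks.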

Hope this makes sense! @lukasheinrich was also encountering these things while working on #1676, so he may also have comments or a resolution :)

phinate, Nov 12 '21 14:11