
Allowing gradients to pass through cabinetry to enable a differentiable analysis workflow

phinate opened this issue on Jun 09, 2021

Outline

As I understand it, cabinetry's scope includes enabling and driving an analysis that is fully differentiable. For this to work in practice, one should be able to backpropagate through the calculations in cabinetry, so that the derivative of some downstream loss function with respect to a parameter specified upstream can be computed.
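To make that concrete, here is a toy sketch (pure jax, no cabinetry; all function names are made up for illustration) of what "gradients passing through" the pipeline would mean: an upstream parameter shapes some intermediate yields, a loss is evaluated downstream, and jax.grad gives the end-to-end derivative.

```python
import jax
import jax.numpy as jnp


def expected_yields(cut_value):
    # stand-in for the cabinetry part of the pipeline: some calculation that
    # turns an upstream (e.g. selection) parameter into per-bin yields
    return jnp.array([5.0, 3.0]) * jax.nn.sigmoid(cut_value)


def loss(cut_value):
    # stand-in for a downstream figure of merit, e.g. a negative log-likelihood
    yields = expected_yields(cut_value)
    return -jnp.sum(jnp.log(yields))


# derivative of the downstream loss w.r.t. the upstream parameter,
# obtained by backpropagating through the intermediate calculation
print(jax.grad(loss)(0.5))
```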

There has been little discussion about which autodiff frameworks to support; given cabinetry's heavy use of numpy, and since machinery to convert gradients between frameworks is in the works, I would think that jax is the path of least resistance, assuming there are no plans to synchronise with the pyhf tensor backend.
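On the pyhf side, the jax backend already lets gradients flow through the likelihood; something along these lines (an untested sketch, assuming a recent pyhf where pyhf.simplemodels.uncorrelated_background exists) shows the kind of interoperability I have in mind:

```python
import jax
import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax")  # tensorlib becomes jax, so model outputs are jax arrays

model = pyhf.simplemodels.uncorrelated_background(
    signal=[5.0, 10.0], bkg=[50.0, 60.0], bkg_uncertainty=[5.0, 12.0]
)
pars = jnp.asarray(model.config.suggested_init())
data = model.expected_data(pars)  # Asimov-like data, including auxdata

# gradient of the log-likelihood w.r.t. all model parameters
grad_logpdf = jax.grad(lambda p: model.logpdf(p, data)[0])
print(grad_logpdf(pars))
```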

There are, however, obstacles to consider:

  • cabinetry would either have to use only the jax backend of pyhf for interoperability, or specifically handle the gradients of other frameworks through custom gradient ops (see the toy sketch after this list)
  • boost-histogram is used but not differentiable (although awkward could soon be)
  • Care would need to be taken to process things in a jax-friendly way (e.g. ensuring "pure" functions, see the jax docs for more)
  • At runtime, the cabinetry workspace would need to be built dynamically, as the model specification changes during training (which it seems is luckily already supported; just something to be mindful of)
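On the first two points, jax.custom_vjp is the mechanism I would reach for: the forward pass can be something jax cannot differentiate on its own (another framework's result, a histogram fill, a hard cut), and the gradient is supplied by hand. A toy sketch with made-up names, using a sigmoid surrogate gradient for a hard selection cut:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def hard_cut(x, threshold):
    # non-differentiable step: an event either passes the cut or it doesn't
    return (x > threshold).astype(x.dtype)


def hard_cut_fwd(x, threshold):
    # forward pass; save what the backward pass needs
    return hard_cut(x, threshold), (x, threshold)


def hard_cut_bwd(residuals, g):
    x, threshold = residuals
    # hand-written surrogate gradient (of a sigmoid relaxation), i.e. a
    # "custom gradient op" replacing the zero/undefined true gradient
    s = jax.nn.sigmoid(x - threshold)
    surrogate = s * (1.0 - s)
    return (g * surrogate, -jnp.sum(g * surrogate))


hard_cut.defvjp(hard_cut_fwd, hard_cut_bwd)

# toy use: differentiate the expected yield after the cut w.r.t. the threshold
x = jnp.linspace(-2.0, 2.0, 101)
yield_after_cut = lambda thr: jnp.sum(hard_cut(x, thr))
print(jax.grad(yield_after_cut)(0.0))
```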

I think that using jax for array operations may even lead to benefits irrespective of autodiff -- jax.jit and jax.vmap are powerful idioms that can give you free speedups :)
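As a trivial illustration (made-up per-event function), vectorising over events and compiling with XLA is one call each:

```python
import jax
import jax.numpy as jnp


def event_weight(features, coeff):
    # toy calculation written for a single event's feature vector
    return coeff * jnp.log1p(jnp.dot(features, features))


events = jnp.ones((100_000, 4))  # 100k events, 4 features each

# vmap vectorises the per-event function over the batch; jit compiles it
batched = jax.jit(jax.vmap(event_weight, in_axes=(0, None)))
print(batched(events, 0.5).shape)  # (100000,)
```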

N.B. these are just some initial thoughts, I need to get much more practical before addressing these things!

(tagging @lukasheinrich for interest)
