cabinetry
Allowing gradients to pass through cabinetry to enable a differentiable analysis workflow
Outline
As I understand it, cabinetry has within its scope enabling and driving an analysis that is fully differentiable. For this to work in practice, one should be able to backpropagate through the calculations in cabinetry that influence the derivative of some loss function with respect to some parameter specified upstream.
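To make the idea concrete, here is a minimal jax sketch of backpropagating through a toy "selection, then yields, then loss" chain. None of this is cabinetry code: `soft_yield`, `loss`, the Gaussian toy samples and the sigmoid used in place of a hard cut are all assumptions for illustration, and the loss is simply the negative of s/sqrt(b).

```python
import jax
import jax.numpy as jnp

def soft_yield(cut, values, steepness=10.0):
    # Smooth stand-in for the hard selection `values > cut`:
    # each event gets a weight in (0, 1) instead of 0 or 1
    return jnp.sum(jax.nn.sigmoid(steepness * (values - cut)))

def loss(cut, sig, bkg):
    # Toy loss (hypothetical): negative s / sqrt(b) after the soft cut
    s = soft_yield(cut, sig)
    b = soft_yield(cut, bkg)
    return -s / jnp.sqrt(b)

# Toy samples: signal centred at +1, background centred at -1
key_sig, key_bkg = jax.random.split(jax.random.PRNGKey(0))
sig = jax.random.normal(key_sig, (1000,)) + 1.0
bkg = jax.random.normal(key_bkg, (1000,)) - 1.0

# d(loss)/d(cut), backpropagated through selection -> yields -> loss
dloss_dcut = jax.grad(loss)(0.0, sig, bkg)
```

For this toy the gradient should come out negative, i.e. moving the cut above 0 lowers the loss.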
There has been little discussion so far about which autodiff frameworks to support. Given cabinetry's heavy use of numpy, and since machinery to convert gradients between frameworks is in the works, I would think that jax is the path of least resistance, assuming that there are no plans to synchronise with the pyhf tensor backend.
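As an illustration of why jax looks like the path of least resistance for numpy-heavy code, the sketch below writes a Poisson negative log-likelihood once against an array module `xp` and runs it with both numpy and jax.numpy; only the jax version can also be differentiated. The `nll` function and the yields are made up for illustration and do not reflect how cabinetry or pyhf actually build their likelihoods.

```python
import numpy as np
import jax
import jax.numpy as jnp

def nll(xp, mu, signal, background, observed):
    # Poisson negative log-likelihood, dropping the constant log(n!) term;
    # `xp` is either numpy or jax.numpy, the body is identical for both
    expected = mu * xp.asarray(signal) + xp.asarray(background)
    return xp.sum(expected - xp.asarray(observed) * xp.log(expected))

# Hypothetical per-bin yields
signal = [5.0, 10.0, 3.0]
background = [50.0, 52.0, 48.0]
observed = [53.0, 65.0, 50.0]

nll_numpy = nll(np, 1.0, signal, background, observed)
nll_jax = nll(jnp, 1.0, signal, background, observed)

# Only the jax.numpy version can be differentiated w.r.t. mu
dnll_dmu = jax.grad(lambda mu: nll(jnp, mu, signal, background, observed))(1.0)
```

The port is not always this clean (for example, jax arrays are immutable, so an in-place update like `a[i] = x` becomes `a = a.at[i].set(x)`), but much of the numpy API carries over directly.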
There are, however, obstacles to consider:
- cabinetry would either have to use only the jax backend of pyhf for interoperability, or explicitly handle the gradients of other frameworks through custom gradient ops
- boost-histogram is used but is not differentiable (although awkward could soon be differentiable)
- Care would need to be taken to process things in a jax-friendly way (e.g. ensuring "pure" functions; see the jax docs for more)
- At runtime, the cabinetry workspace would need to be built dynamically, as the model specification changes during training (luckily this seems to be supported; it is just something to be mindful of)
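One possible way around the non-differentiable histogramming step is a "soft" histogram, where each event's hard bin assignment is replaced by a smooth weight built from sigmoids at the bin edges. This technique is an assumption swapped in for illustration (it is not something boost-histogram, awkward or cabinetry provide), and `soft_histogram`, `total_in_first_bin` and `bandwidth` are hypothetical names. The function is also pure (its output depends only on its inputs and it has no side effects), which is what jax.grad and jax.jit expect.

```python
import jax
import jax.numpy as jnp

def soft_histogram(values, edges, bandwidth=0.1):
    # Per-event, per-bin membership in (0, 1): a difference of sigmoids
    # at the upper and lower bin edges; shape (n_events, n_bins)
    upper = jax.nn.sigmoid((edges[1:] - values[:, None]) / bandwidth)
    lower = jax.nn.sigmoid((edges[:-1] - values[:, None]) / bandwidth)
    # Summing over events gives smooth bin counts, shape (n_bins,)
    return jnp.sum(upper - lower, axis=0)

def total_in_first_bin(shift, values, edges):
    # Hypothetical upstream parameter: a shift applied to every event
    return soft_histogram(values + shift, edges)[0]

edges = jnp.linspace(-3.0, 3.0, 7)  # 6 bins of width 1
values = jax.random.normal(jax.random.PRNGKey(1), (500,))

counts = soft_histogram(values, edges)
grad_shift = jax.grad(total_in_first_bin)(0.0, values, edges)
```

The bandwidth trades bias against gradient quality: as it shrinks towards 0 the counts approach those of an ordinary histogram, but the gradients become vanishingly small except very close to the bin edges.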
I think that using jax for array operations may even bring benefits irrespective of autodiff: jax.jit and jax.vmap are powerful idioms that can give you free speedups :)
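A small sketch of both idioms, using the single-bin Asimov discovery significance Z = sqrt(2((s + b) ln(1 + s/b) - s)) as a stand-in calculation (the choice of function and the names `asimov_z`, `asimov_z_jit` and `dz_ds` are assumptions for illustration): jax.jit compiles the function with XLA, and jax.vmap over jax.grad gives per-point derivatives for a batch of signal yields without a Python loop.

```python
import jax
import jax.numpy as jnp

def asimov_z(s, b):
    # Median discovery significance for one counting bin with known background
    return jnp.sqrt(2.0 * ((s + b) * jnp.log1p(s / b) - s))

# jit: traced and compiled on the first call, compiled kernel reused afterwards
asimov_z_jit = jax.jit(asimov_z)

# vmap + grad: dZ/ds for each signal yield in a batch, with the background
# held fixed (in_axes=None means that argument is not batched)
dz_ds = jax.jit(jax.vmap(jax.grad(asimov_z), in_axes=(0, None)))

signals = jnp.array([5.0, 10.0, 20.0])
background = 100.0

z = asimov_z_jit(signals, background)
slopes = dz_ds(signals, background)
```

Whether this actually beats the numpy equivalent depends on array sizes and hardware, and the first jitted call pays a one-off compilation cost.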
N.B. These are just some initial thoughts; I need to get much more hands-on before addressing these things!
(tagging @lukasheinrich for interest)