[Differentiability, Refactoring] Rethink parameter dictionaries in backends / Introduce hybrid differentiation modes
Issue:
Right now, when we do:

```python
quantum_backend = SomeBackend()
conv = quantum_backend.convert(circuit, obs)
conv_circ, conv_obs, embedding_fn, params = conv
```

we store all of the following in the initial `params` dict:

(a) all variational, user-facing parameters of the circuit AND observable
(b) all fixed parameters in both circuit AND observable
Issue 1: when using torch, the (torch-based) backend knows which params to compute gradients for via the `requires_grad` flag. However, this doesn't work for JAX.
Issue 2: neither of the diff_modes ADJOINT and GPSR supports parametric observables.
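To make Issue 1 concrete, here is a minimal illustration (the parameter names and the toy loss are made up) of why the torch mechanism does not translate: torch attaches differentiability to the tensor itself, while `jax.grad` differentiates with respect to whole arguments.

```python
import torch
import jax
import jax.numpy as jnp

# Torch: differentiability is a property of the tensor, so a flat params
# dict is enough -- the backend just checks requires_grad.
params = {
    "theta": torch.tensor(0.5, requires_grad=True),  # variational
    "x": torch.tensor(1.0),                          # fixed (feature)
}

# JAX: there is no requires_grad flag. jax.grad differentiates with respect
# to whole arguments, so a flat dict mixing fixed and variational entries
# yields (unwanted) gradients for the fixed ones too.
def loss(flat_params):
    return jnp.sin(flat_params["theta"]) * flat_params["x"]

grads = jax.grad(loss)({"theta": jnp.array(0.5), "x": jnp.array(1.0)})
# grads also contains an entry for "x" -- nothing marks it as fixed.
```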
Possible Solution:
Introduce separate parameter dicts for the initial fixed and variational params in both circuit and observable:

```python
initial_params = {
    'circuit_vparams': ...,
    'circuit_fixedparams': ...,
    'obs_vparams': ...,
    'obs_fixedparams': ...,
}
```

- This way, we can easily distinguish variational from fixed params in JAX.
- We can start thinking about introducing hybrid diff_modes, e.g. `circuit_diffmode = GPSR / ADJOINT` and `observable_diffmode = AD`, and use a specific differentiation routine on each subset of the parameters (see the sketch after this list).
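A rough sketch of how such a split could be consumed on the JAX side (all keys and the toy `expectation` function are illustrative, not qadence code): with separate groups, `jax.grad` can be pointed at the variational dicts only, and the fixed dicts are simply passed through.

```python
import jax
import jax.numpy as jnp

# Hypothetical split returned by convert(); names are placeholders.
initial_params = {
    "circuit_vparams": {"theta": jnp.array(0.5)},
    "circuit_fixedparams": {"x": jnp.array(1.0)},
    "obs_vparams": {"w": jnp.array(2.0)},
    "obs_fixedparams": {},
}

def expectation(circ_vparams, obs_vparams, circ_fixed, obs_fixed):
    # stand-in for the real expectation-value computation
    return obs_vparams["w"] * jnp.cos(circ_vparams["theta"] * circ_fixed["x"])

# Differentiate only with respect to the two variational groups:
circ_grads, obs_grads = jax.grad(expectation, argnums=(0, 1))(
    initial_params["circuit_vparams"],
    initial_params["obs_vparams"],
    initial_params["circuit_fixedparams"],
    initial_params["obs_fixedparams"],
)
```

This is also the natural seam for hybrid diff_modes: the routine producing `circ_grads` could be swapped for GPSR or ADJOINT while `obs_grads` keeps using plain AD.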
If I understand correctly, the problem is that `conv.params` returns a dict that contains both fixed and variational parameters, right? Would it be easier/more elegant to introduce a `conv.vparams` and a `conv.circuit.vparams` (plus the same for the observable) that return only the variational parameters? Then we don't have to change all the code that assumes `conv.params` to be one non-nested dict.
Yes, great idea, but I would try to avoid changing the low-level interface, so I would be inclined to keep `conv.params` and let it just return the composition of `conv.circuit.vparams`, etc. (a sketch of this follows below).
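A hypothetical sketch of what that could look like (all class and attribute names here are made up for illustration, not taken from the qadence codebase):

```python
from dataclasses import dataclass
from typing import Callable

@dataclass
class ConvertedCircuit:
    vparams: dict
    fixedparams: dict

@dataclass
class ConversionResult:
    circuit: ConvertedCircuit
    observable: ConvertedCircuit
    embedding_fn: Callable

    @property
    def vparams(self) -> dict:
        # Only the variational parameters, e.g. for JAX gradients.
        return {**self.circuit.vparams, **self.observable.vparams}

    @property
    def params(self) -> dict:
        # Unchanged low-level interface: still one flat dict, now
        # composed from the separately stored groups.
        return {
            **self.circuit.vparams,
            **self.circuit.fixedparams,
            **self.observable.vparams,
            **self.observable.fixedparams,
        }
```

This would keep `conv.params` backwards-compatible while exposing the split wherever it is needed.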
Next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?
If I remember correctly, I originally designed this with @awennersteen. Indeed, we didn't consider fixing some parameters or having different grad backends. There are two options, both without changing the API, I think:

- Have the parameters in the returned dict always be trainable, and let the `embedding_fn` take care of adding in the non-trainable params. In my opinion not a great option, for various reasons, but possible.
- Instead of the params being a basic `dict`, make it a slightly more involved object with, for example, `trainable_params`, `fixed_params`, and the corresponding AD rules for each group. This is very JAX-style (have a look at optax). I think this might cover everything we need. This is my preferred option, and I think it is in line with @dominikandreasseitz's idea, if I understand it correctly? (See the sketch after this list.)
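For what it's worth, a minimal JAX-style sketch of option 2, assuming a custom container registered as a pytree whose differentiable leaves are only the trainable entries; the fixed entries ride along as auxiliary data, so `jax.grad` never touches them (all names are illustrative):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
    def __init__(self, trainable_params, fixed_params):
        self.trainable_params = trainable_params
        self.fixed_params = fixed_params

def _flatten(p):
    # children = differentiable leaves, aux = opaque to AD
    return (p.trainable_params,), p.fixed_params

def _unflatten(aux, children):
    return Params(children[0], aux)

tree_util.register_pytree_node(Params, _flatten, _unflatten)

def loss(p: Params):
    return jnp.sin(p.trainable_params["theta"]) * p.fixed_params["x"]

p = Params({"theta": jnp.array(0.5)}, {"x": 2.0})
grads = jax.grad(loss)(p)  # grads.trainable_params holds "theta" only
```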
> Next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?

IMHO, this sounds dangerous. But I guess it makes a lot of sense if we consider, for example, a hybrid model where we have a classical NN composed with a QNN. I think that it should be strictly defined.
I, like @GJBoth, have no recollection of why we did this or what we may have considered :p The one thing I do remember is that after the initial design, over the next month or so, there were many hacks and patches to make it actually work...
Since @nmheim was asking about namedtuples the other day, maybe this is another place to use them? So that we have a more solid object, we keep all the different data in there (Gert-Jax' option number 2), and then go for it?
My only concern is how this might behave together with the idea of different diff-modes for different parts. But maybe this is the best way of achieving that too? Suppose we end up saying that, in order to use different diff modes, you would compose multiple QuantumModels (or maybe DifferentiableBackends, or whatever the current name is). Then this namedtuple keeping track of parameters could also keep track of which model they belong to. By using Gert-Jax' idea of "looking up the AD rules for each group", I guess that could be achieved arbitrarily (see the sketch below). This is quickly over-engineering, though, and we should think about that before implementing it.
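Purely as a speculative sketch of that last idea (all names and the string-valued diff modes are placeholders), a namedtuple could carry both the parameter groups and the differentiation rule to use for each; as a bonus, JAX treats namedtuples as pytrees out of the box:

```python
from typing import NamedTuple

class ParamGroup(NamedTuple):
    values: dict        # name -> tensor/array
    trainable: bool
    diff_mode: str      # e.g. "ad", "gpsr", "adjoint"

class ConvertedParams(NamedTuple):
    circuit: ParamGroup
    observable: ParamGroup

params = ConvertedParams(
    circuit=ParamGroup({"theta": 0.5}, trainable=True, diff_mode="gpsr"),
    observable=ParamGroup({"w": 2.0}, trainable=True, diff_mode="ad"),
)

# A differentiation dispatcher could then look up the rule per group:
for name, group in params._asdict().items():
    if group.trainable:
        print(f"{name}: differentiate via {group.diff_mode}")
```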