horqrux
[Feature] Allow parameter shifting on all parametric gates
The previous PR (https://github.com/pasqal-io/horqrux/pull/27) implemented the parameter shift rule (PSR) for parameters defined in the `values` argument of expectations. However, it had some limitations:
- It did not allow the PSR for every parameter: parameters had to be defined via `values`, and a user could not pass a parameter directly into a gate.
- It had a bug (https://github.com/pasqal-io/horqrux/issues/29) for repeated values.
This PR addresses both points above. It also adds tests checking that these cases can be jit-compiled and give the correct answers.
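The rule being generalized here can be sketched in plain JAX. This is a minimal illustration, not horqrux's implementation: the single-qubit simulator and the names `rx`, `expectation_z`, `psr_expectation`, `gate_param_names` and `circuit` are all hypothetical. Each gate occurrence gets its own angle, the custom JVP applies the two-term shift rule `df/dθ = [f(θ + π/2) - f(θ - π/2)] / 2` to one occurrence at a time, and the chain rule sums the contributions when one named parameter feeds several gates. The whole gradient is jit-compiled.

```python
import jax
import jax.numpy as jnp


def rx(theta):
    """Single-qubit RX(theta) = exp(-i * theta * X / 2)."""
    c = jnp.cos(theta / 2)
    s = jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def expectation_z(angles):
    """<Z> after applying RX(angles[0]), RX(angles[1]), ... to |0>."""
    state = jnp.array([1.0 + 0j, 0.0 + 0j])
    for k in range(angles.shape[0]):  # static loop: shape is known at trace time
        state = rx(angles[k]) @ state
    probs = jnp.abs(state) ** 2
    return probs[0] - probs[1]


@jax.custom_jvp
def psr_expectation(angles):
    return expectation_z(angles)


@psr_expectation.defjvp
def _psr_expectation_jvp(primals, tangents):
    (angles,), (angles_dot,) = primals, tangents
    shift = jnp.pi / 2

    def shifted_partial(k):
        # Shift only gate k; every other gate keeps its angle.
        e_k = jnp.zeros_like(angles).at[k].set(shift)
        return (expectation_z(angles + e_k) - expectation_z(angles - e_k)) / 2

    grads = jnp.stack([shifted_partial(k) for k in range(angles.shape[0])])
    return expectation_z(angles), jnp.dot(grads, angles_dot)


# Hypothetical mapping: both gates read the same user parameter "theta".
gate_param_names = ["theta", "theta"]


def circuit(values):
    angles = jnp.stack([values[name] for name in gate_param_names])
    return psr_expectation(angles)


values = {"theta": jnp.array(0.3)}
g = jax.jit(jax.grad(circuit))(values)["theta"]

# Naive alternative: shift the shared parameter once, moving both gates at once.
theta = values["theta"]
naive = (
    circuit({"theta": theta + jnp.pi / 2}) - circuit({"theta": theta - jnp.pi / 2})
) / 2

print(g, -2 * jnp.sin(2 * theta))  # PSR gradient vs. analytic value
print(naive)  # ~0: wrong for a repeated parameter
```

For two `RX(θ)` gates on |0⟩, `<Z> = cos(2θ)`, so the gradient should be `-2 sin(2θ)`. Shifting the shared `θ` once for both gates simultaneously instead gives `(cos(2θ + π) - cos(2θ - π)) / 2 = 0`, which illustrates why a repeated parameter needs one shift per occurrence (this is not necessarily the exact cause of #29).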
Some noteworthy points:
- When jitting functions containing `checkify.check` points, the output type of the original function changes to `(error, output_of_original_function)`. This is not ideal for end users, so these checks have been removed. It would be great to have such checks in the code, so an issue investigating a promising alternative has been raised (https://github.com/pasqal-io/horqrux/issues/30).
- Previously, the `param` attribute of a `Parametric` gate could be of type `str | float`. This was problematic when implementing custom JVP rules, since a `float` is a valid JAX type but a string is not (e.g. https://github.com/google/jax/issues/3045). Consequently, `param` has been explicitly split into `param_name: str` and `param_val: float`, so that `param_val` is always a valid JAX type.
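The return-type change described in the first point can be reproduced in a few lines. This is a hedged sketch using JAX's `jax.experimental.checkify` module with a hypothetical `safe_log` function rather than any horqrux code: under `jax.jit`, a `checkify.check` has to be functionalized with `checkify.checkify`, and the wrapped function then returns an `(error, output)` pair that every caller must unpack.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # Hypothetical example function with a runtime check on a traced value.
    checkify.check(x > 0, "x must be positive")
    return jnp.log(x)


# Functionalize the check, then jit: the output becomes (error, result).
checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(2.0)
print(err.get())  # None: the check passed
print(out)

bad_err, bad_out = checked_log(-1.0)
print(bad_err.get())  # an error message mentioning "x must be positive"
# bad_err.throw() would raise the error on the Python side instead.
```

Dropping the checks keeps the public signatures unchanged, at the cost of losing these runtime guards.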
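The `param_name`/`param_val` split can be illustrated with a toy gate registered as a JAX pytree. This is a minimal sketch under assumptions: `HypotheticalRX`, `z_expectation` and `energy` are illustrative names, not horqrux's `Parametric` class, and the expectation is simply `<Z> = cos(θ)` for `RX(θ)` applied to |0⟩. The string is stored as the pytree's static auxiliary data, so only the float reaches the custom JVP rule, which applies the parameter shift rule.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass
class HypotheticalRX:
    """Toy parametric gate (illustrative only, not horqrux's class)."""

    param_name: str  # static metadata: never traced or differentiated
    param_val: float  # numeric leaf: a valid JAX type

    def tree_flatten(self):
        # Children are traced by JAX; aux data stays static Python data.
        return (self.param_val,), self.param_name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(param_name=aux_data, param_val=children[0])


@jax.custom_jvp
def z_expectation(gate):
    # <Z> after RX(theta) on |0> equals cos(theta).
    return jnp.cos(gate.param_val)


@z_expectation.defjvp
def _z_expectation_jvp(primals, tangents):
    (gate,), (gate_dot,) = primals, tangents
    shift = jnp.pi / 2
    plus = HypotheticalRX(gate.param_name, gate.param_val + shift)
    minus = HypotheticalRX(gate.param_name, gate.param_val - shift)
    deriv = (z_expectation(plus) - z_expectation(minus)) / 2  # parameter shift rule
    return z_expectation(gate), deriv * gate_dot.param_val


def energy(theta):
    return z_expectation(HypotheticalRX("theta", theta))


g = jax.jit(jax.grad(energy))(jnp.float32(0.4))
print(g, -jnp.sin(0.4))  # PSR gradient vs. analytic derivative
```

Because `param_name` is part of the pytree structure rather than a leaf, `jax.jit` treats it as static: passing a `HypotheticalRX` directly to a jitted function would retrace for a new name but not for a new value.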
Closes #29