[Feature] Allow parameter shifting on all parametric gates

Open • atiyo opened this issue

The previous PR (https://github.com/pasqal-io/horqrux/pull/27) implemented the parameter shift rule (PSR) for parameters defined in the values argument of expectations. However, it had two limitations:

  • It did not support the PSR for every parameter: parameters had to be defined via values, so a user could not pass a parameter value directly into a gate.
  • It had a bug with repeated values (https://github.com/pasqal-io/horqrux/issues/29).
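For context on the PSR itself, a minimal sketch follows, assuming a single-qubit circuit RX(θ)|0⟩ measured in Z. The helpers rx, expectation and psr_grad are hypothetical and not part of horqrux's API. For a gate of the form exp(-iθP/2) with P a Pauli operator, the two-term shift rule (f(θ + π/2) - f(θ - π/2)) / 2 gives the exact derivative; here f(θ) = cos θ, so both the rule and jax.grad should return -sin θ.

```python
import jax
import jax.numpy as jnp

# Pauli Z observable.
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)


def rx(theta):
    # RX(theta) = exp(-i * theta / 2 * X)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def expectation(theta):
    # <0| RX(theta)^dagger Z RX(theta) |0> = cos(theta)
    psi = rx(theta) @ jnp.array([1.0, 0.0], dtype=jnp.complex64)
    return jnp.real(jnp.vdot(psi, Z @ psi))


def psr_grad(f, theta, shift=jnp.pi / 2):
    # Two-term parameter shift rule for a gate generated by a Pauli operator.
    return (f(theta + shift) - f(theta - shift)) / 2


theta = 0.3
print(psr_grad(expectation, theta))  # should be close to -sin(0.3)
print(jax.grad(expectation)(theta))  # autodiff result, for comparison
```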

This MR addresses both of these points. It also adds tests checking that both cases can be jit-compiled and give the correct answers.
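As a hedged illustration of why repeated values need care (not necessarily the exact cause of #29): when one parameter feeds several gates, the chain rule requires shifting each occurrence on its own and summing the contributions, while shifting every occurrence at once does not give the derivative. The sketch below uses hypothetical names, wraps the per-occurrence rule in jax.custom_jvp so that jax.grad uses it, and runs it under jax.jit. For RX(θ) applied twice, ⟨Z⟩ = cos 2θ, so the correct gradient is -2 sin 2θ, whereas the naive joint shift gives 0.

```python
import jax
import jax.numpy as jnp

Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
SHIFT = jnp.pi / 2


def rx(theta):
    # RX(theta) = exp(-i * theta / 2 * X)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def circuit(t1, t2):
    # Apply RX(t1) then RX(t2) to |0> and return <Z>, which equals cos(t1 + t2).
    psi = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    psi = rx(t2) @ (rx(t1) @ psi)
    return jnp.real(jnp.vdot(psi, Z @ psi))


@jax.custom_jvp
def expectation(theta):
    # The same parameter value is used by both gates.
    return circuit(theta, theta)


@expectation.defjvp
def expectation_jvp(primals, tangents):
    (theta,), (dtheta,) = primals, tangents
    # Shift each occurrence of theta separately and sum the two contributions.
    d1 = (circuit(theta + SHIFT, theta) - circuit(theta - SHIFT, theta)) / 2
    d2 = (circuit(theta, theta + SHIFT) - circuit(theta, theta - SHIFT)) / 2
    return circuit(theta, theta), (d1 + d2) * dtheta


def naive_shift(theta):
    # Incorrect: shifts both occurrences at the same time.
    return (circuit(theta + SHIFT, theta + SHIFT)
            - circuit(theta - SHIFT, theta - SHIFT)) / 2


grad_fn = jax.jit(jax.grad(expectation))
theta = 0.3
print(grad_fn(theta))      # should be close to -2 * sin(0.6)
print(naive_shift(theta))  # close to 0, i.e. the wrong answer
```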

Some noteworthy points:

  • Jitting a function that contains checkify.check calls requires wrapping it with checkify.checkify, which changes its output type to (error, output_of_original_function). Because this is not ideal for end users, the checks have been removed. Such checks would still be valuable, so an issue investigating a promising alternative has been opened (https://github.com/pasqal-io/horqrux/issues/30).
  • Previously, the param attribute of a Parametric gate could be of type str | float. This was problematic when implementing custom JVP rules, since a float is a valid JAX type but a string is not (see e.g. https://github.com/google/jax/issues/3045). Consequently, param has been split into param_name: str and param_val: float, so that param_val is always a valid JAX type.
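The output-type change caused by checkify can be seen in a small example using JAX's jax.experimental.checkify module. The function normalise is hypothetical and unrelated to horqrux; it only shows that the jitted, checkified function returns an (error, output) pair rather than the plain output.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def normalise(x):
    # A user-facing runtime check that all inputs are non-negative.
    checkify.check(jnp.all(x >= 0), "inputs must be non-negative")
    return x / jnp.sum(x)


# checkify.checkify functionalises the check so the function can be jitted,
# but it also changes the return value to an (error, output) pair.
checked = jax.jit(checkify.checkify(normalise))

err, out = checked(jnp.array([1.0, 3.0]))
print(err.get())  # None: the check passed
print(out)

err, _ = checked(jnp.array([-1.0, 3.0]))
print(err.get())  # an error message mentioning "inputs must be non-negative"
```

Every caller of such a function has to unpack the pair (or call err.throw()), which is the burden on end users described above.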
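The param_name / param_val split can be sketched as follows, assuming a hypothetical RXGate dataclass (not horqrux's actual Parametric class). Only the float values are passed through jax.custom_jvp and jax.jit, where a parameter-shift JVP is applied to each entry; the string names stay in plain Python and are used afterwards to label the gradients. Passing the string itself into a jitted function would raise a TypeError, since str is not a valid JAX type.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

SHIFT = jnp.pi / 2
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)


@dataclass(frozen=True)
class RXGate:
    # Hypothetical gate type for illustration, not horqrux's API.
    param_name: str   # plain Python label; never passed through JAX
    param_val: float  # the only field that JAX ever traces


def rx(theta):
    # RX(theta) = exp(-i * theta / 2 * X)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]])


def run(vals):
    # Apply RX(vals[0]), RX(vals[1]), ... to |0> and return <Z>.
    psi = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    for i in range(vals.shape[0]):
        psi = rx(vals[i]) @ psi
    return jnp.real(jnp.vdot(psi, Z @ psi))


@jax.custom_jvp
def expectation(vals):
    return run(vals)


@expectation.defjvp
def expectation_jvp(primals, tangents):
    (vals,), (dvals,) = primals, tangents
    # Parameter shift rule applied to each entry of vals in turn.
    grad = jnp.stack([
        (run(vals.at[i].add(SHIFT)) - run(vals.at[i].add(-SHIFT))) / 2
        for i in range(vals.shape[0])
    ])
    return run(vals), jnp.dot(grad, dvals)


gates = [RXGate("theta", 0.3), RXGate("phi", 0.5)]
vals = jnp.array([g.param_val for g in gates])
grads = jax.jit(jax.grad(expectation))(vals)
named_grads = {g.param_name: float(gr) for g, gr in zip(gates, grads)}
print(named_grads)  # both entries should be close to -sin(0.8)
```

Because the names never enter traced code, JAX only ever sees float arrays, and the compiled gradient does not depend on what the parameters are called.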

Closes #29

atiyo, Sep 19 '24 10:09