blackjax
Remove the `logprob_grad` keyword arguments
To use custom gradients, users currently need to pass the gradient function as a keyword argument. This complicates the API, because we need to carry this argument downstream.
We can avoid this because JAX lets us register a custom gradient for a function. This is related to #318 and #340, and probably #319 as well.
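A minimal sketch of the idea, assuming `jax.custom_jvp` is the registration mechanism (the issue does not say whether `custom_jvp` or `custom_vjp` would be used). The hand-written gradient is attached to the log-density itself, so a sampler only needs the log-density and can call `jax.value_and_grad` on it. Both `logdensity_fn` (a standard normal here) and `hypothetical_gradient_step` are illustrative names, not part of the blackjax API.

```python
import jax
import jax.numpy as jnp


# Hypothetical log-density: standard normal, log p(x) = -0.5 * sum(x**2).
@jax.custom_jvp
def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)


# Register the hand-written gradient with JAX instead of passing it around
# as a separate `logprob_grad` keyword argument.
@logdensity_fn.defjvp
def logdensity_fn_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    value = logdensity_fn(x)
    grad = -x  # analytic gradient of -0.5 * sum(x**2)
    # The tangent output is linear in x_dot, so reverse mode (jax.grad) works too.
    return value, jnp.dot(grad, x_dot)


# Hypothetical sampler-style step: it receives only the log-density, and
# jax.value_and_grad picks up the custom gradient automatically.
def hypothetical_gradient_step(logdensity_fn, position, step_size):
    _, grad = jax.value_and_grad(logdensity_fn)(position)
    return position + step_size * grad
```

With this approach, `jax.grad(logdensity_fn)(jnp.array([1.0, 2.0]))` uses the registered rule and yields `[-1.0, -2.0]`, and no extra argument has to be threaded through the sampler's signature.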