AdvancedHMC.jl
Save gradients
Hi @sethaxen!
This should do what you suggested if you use AHMC's internal sampling method and set `drop_warmup=false` and `keep_gradient=true`.
Let me know if it works!
Generally speaking, it's not great to have the return value change based on a keyword argument, as it leads to type instabilities :confused: In this scenario it's probably not the worst (`sample` is generally going to be the outermost caller anyway), but it's still not great style.
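To make the type-instability concern concrete, here is a minimal sketch in plain Julia (no AdvancedHMC involved). `draw` and `caller` are hypothetical stand-ins for a sampler whose output shape depends on a `keep_gradient`-style keyword; when the flag is only known at runtime, inference can only give the caller a non-concrete (`Union`) return type.

```julia
# Hypothetical stand-in for a sampler whose output shape depends on a keyword.
function draw(x::Float64; keep_gradient::Bool=false)
    grad = -x  # gradient of log p(x) = -x^2/2, i.e. a standard normal target
    return keep_gradient ? (x, grad) : x  # Float64 or Tuple{Float64,Float64}
end

# The keyword value is only known at runtime here, so inference cannot pick one branch.
caller(x::Float64, flag::Bool) = draw(x; keep_gradient=flag)

# Ask inference what `caller` returns for these argument types.
rt = only(Base.return_types(caller, (Float64, Bool)))
println(rt, "  concrete? ", isconcretetype(rt))
```

Running `@code_warntype caller(1.0, true)` should highlight the same non-concrete return type.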
IMO, what we should do here is:
1. Move AdvancedHMC.jl completely over to the AbstractMCMC.jl interface, i.e. we have a `step` function that takes in a `state` containing all the necessary information.
2. Then the gradients can easily be extracted through a `callback` which just extracts this information from the `state`.
(1) will also be a huge gain in general, as it will make things much more modular :)
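A rough, dependency-free sketch of the `step`/`state`/`callback` idea. Everything here is hypothetical: `ToyState`, `toy_step`, and `toy_sample` are illustrative names, not the AbstractMCMC.jl or AdvancedHMC.jl API, and the sampler is a plain unadjusted Langevin update on a standard normal target rather than HMC. The point is only that the gradient lives in the state, so a callback can collect it while `toy_sample` keeps one concrete return type regardless of what the user wants to record.

```julia
using Random

# Hypothetical state: everything needed to take the next step, including the gradient.
struct ToyState
    x::Float64     # current position
    grad::Float64  # gradient of log p at x
end

# Standard normal target: log p(x) = -x^2/2 + const, so the gradient is -x.
logdensity_grad(x) = -x

# Hypothetical initial step: no previous state yet.
function toy_step(rng::AbstractRNG, x0::Float64)
    state = ToyState(x0, logdensity_grad(x0))
    return state.x, state
end

# Hypothetical subsequent step: one unadjusted Langevin update using the cached gradient.
function toy_step(rng::AbstractRNG, state::ToyState; ϵ=0.1)
    x = state.x + ϵ * state.grad + sqrt(2ϵ) * randn(rng)
    newstate = ToyState(x, logdensity_grad(x))
    return newstate.x, newstate
end

# Hypothetical driver: calls `callback(sample, state, iteration)` after every step.
function toy_sample(rng::AbstractRNG, x0::Float64, n::Int; callback=(args...) -> nothing)
    samples = Vector{Float64}(undef, n)
    s, state = toy_step(rng, x0)
    samples[1] = s
    callback(s, state, 1)
    for i in 2:n
        s, state = toy_step(rng, state)
        samples[i] = s
        callback(s, state, i)
    end
    return samples  # always a Vector{Float64}, independent of any keyword
end

# The callback pulls the gradient out of the state.
grads = Float64[]
samples = toy_sample(MersenneTwister(1), 0.5, 100; callback=(s, st, i) -> push!(grads, st.grad))
```

With this split, saving gradients, warmup information, or anything else stored in the state becomes a matter of passing a different callback rather than adding keywords that change what `sample` returns.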