tch-rs
Supply custom gradient in the backward function
Recently I've been working on a project where I need to supply a custom gradient to the backpropagation process. For that I need a more flexible backward function that accepts gradient, retain_graph and create_graph as parameters, which is why this pull request introduces a new function, backward_with_grad.
I know that neither the documentation nor the function name is finalized in this pull request yet; I mainly want to submit it and open a discussion on whether this feature can be added to the code base. It would be of great help to many other projects that involve custom gradients!
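To make the gradient and retain_graph parameters concrete, here is a toy, self-contained Rust model of what a backward pass with a caller-supplied gradient computes. It does not use tch-rs at all: SquareNode and its backward_with_grad method are hypothetical and only borrow the PR's naming. For y = x * x applied elementwise, the supplied gradient g plays the role of dL/dy, so the result is dL/dx = 2x * g (a vector-Jacobian product); passing all ones gives the gradient of the sum of y. retain_graph is modelled by freeing the saved input when it is false, while create_graph (recording the backward computation itself for higher-order derivatives) is not modelled in this sketch.

```rust
/// Hypothetical autograd node for y = x * x (elementwise).
/// It saves its input during the forward pass because backward needs it.
struct SquareNode {
    saved_input: Option<Vec<f64>>,
}

impl SquareNode {
    /// Forward pass: returns the node (holding the saved input) and the output y.
    fn forward(x: &[f64]) -> (SquareNode, Vec<f64>) {
        let y = x.iter().map(|v| v * v).collect();
        (SquareNode { saved_input: Some(x.to_vec()) }, y)
    }

    /// Backward pass with a caller-supplied upstream gradient `grad` (= dL/dy).
    /// Returns dL/dx = 2x * grad. If `retain_graph` is false, the saved input
    /// is dropped, so any later backward call on this node returns an error.
    fn backward_with_grad(&mut self, grad: &[f64], retain_graph: bool) -> Result<Vec<f64>, String> {
        let saved = if retain_graph {
            self.saved_input.clone()
        } else {
            self.saved_input.take()
        };
        let x = saved.ok_or_else(|| "graph already freed; pass retain_graph = true".to_string())?;
        if grad.len() != x.len() {
            return Err(format!("gradient has {} elements, expected {}", grad.len(), x.len()));
        }
        Ok(x.iter().zip(grad).map(|(xi, gi)| 2.0 * xi * gi).collect())
    }
}

fn main() {
    let (mut node, y) = SquareNode::forward(&[1.0, 2.0, 3.0]);
    println!("y = {:?}", y);

    // Custom upstream gradient instead of an implicit all-ones seed.
    let dx = node.backward_with_grad(&[0.5, 0.0, -1.0], true).unwrap();
    println!("dx = {:?}", dx);

    // The graph was retained, so a second pass works; this one frees it.
    let dx_sum = node.backward_with_grad(&[1.0, 1.0, 1.0], false).unwrap();
    println!("dx of sum(y) = {:?}", dx_sum);

    // A third pass fails because the saved input is gone.
    assert!(node.backward_with_grad(&[1.0, 1.0, 1.0], false).is_err());
}
```

Running it, the first backward call scales each local derivative 2x by the supplied gradient, the second (all ones) reproduces the gradient of sum(y), and the third fails because retain_graph was false on the previous call.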
Thank you for your feedback; I'll work on this and get back to you in a few days.