nnabla
[WIP] Implement backward_all as a counterpart of forward_all
Hi, @TE-AkioHayakawa san, @TE-TakuyaNarihira san!
I've implemented nn.backward_all as a counterpart of nn.forward_all.
It seems to work correctly, but I have some concerns:
- the implementation design looks a bit ugly because there are two backward_all functions, one in CgVariable and one in computation_graph.cpp.
- all unit tests pass except the test for rewire_on.
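To make the idea behind backward_all concrete, here is a minimal pure-Python sketch. It uses toy Variable and Function classes that are hypothetical and are not nnabla's C++ CgVariable or computation_graph.cpp. The sketch traverses the union of the graphs under several outputs once, in reverse topological order. As a result, a function shared by several outputs runs its backward pass only once, after the gradients from all of its consumers have been accumulated. visit_function_backward is written here as a free function taking multiple roots, which is one way to avoid keeping two traversal implementations. The "rank" ordering is an assumption of this sketch, not nnabla's scheduling.

```python
from typing import Optional


class Variable:
    """Toy variable: a scalar value, an accumulated gradient and its creator."""

    def __init__(self, data: float, parent: Optional["Function"] = None):
        self.data = data
        self.grad = 0.0
        self.parent = parent  # Function that produced this variable, or None


class Function:
    """Toy single-output function node with hand-written local derivatives."""

    def __init__(self, inputs, forward, local_grads):
        self.inputs = list(inputs)
        self.local_grads = local_grads  # returns d(out)/d(in_i) for each input
        # rank = longest path from a leaf, so consumers always rank higher.
        self.rank = 1 + max(
            (v.parent.rank for v in self.inputs if v.parent is not None), default=0
        )
        self.output = Variable(forward(*[v.data for v in self.inputs]), parent=self)


def visit_function_backward(roots, visit):
    """Hypothetical free function: call visit(f) once for every function
    reachable from any root, consumers before producers."""
    seen, funcs = set(), []
    stack = [v.parent for v in roots if v.parent is not None]
    while stack:
        f = stack.pop()
        if id(f) in seen:
            continue  # shared subgraph: collect each function only once
        seen.add(id(f))
        funcs.append(f)
        stack.extend(v.parent for v in f.inputs if v.parent is not None)
    for f in sorted(funcs, key=lambda f: f.rank, reverse=True):
        visit(f)


def backward_all(outputs, grads):
    """Seed every output gradient, then backprop through all graphs in one pass."""
    for out, g in zip(outputs, grads):
        out.grad += g

    def visit(f):
        xs = [v.data for v in f.inputs]
        for v, d in zip(f.inputs, f.local_grads(*xs)):
            v.grad += f.output.grad * d  # chain rule, accumulated

    visit_function_backward(outputs, visit)


# Usage: two outputs that share the subexpression h = x^2.
x = Variable(3.0)
h = Function([x], lambda a: a * a, lambda a: (2 * a,)).output
y1 = Function([h], lambda a: 2 * a, lambda a: (2.0,)).output
y2 = Function([h, x], lambda a, b: a + b, lambda a, b: (1.0, 1.0)).output
backward_all([y1, y2], [1.0, 1.0])
print(x.grad)  # 19.0, i.e. d(y1 + y2)/dx = 4x + 2x + 1 at x = 3
```

Calling a single-output backward twice would also give 19.0 for x. However, it would run the x^2 node twice, and the gradient buffers would have to be preserved between the two calls. The single shared traversal avoids both.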
This is just a reminder.
- [ ] move visit_function_backward out of the CgVariable class (nbla::visitfunction_backward)
- [ ] add "grads" argument to backward_all
  - In the Python API, check the lengths of grads and variables and make them equal (by padding or slicing grads).
- [ ] check the test code for rewire_on again (I will look into it as well)
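The length check for the "grads" argument could look roughly like the following pure-Python sketch. align_grads is a hypothetical helper, not part of nnabla. The sketch pads with None and assumes that None means "use the default seed gradient" for that variable; the real semantics are for the PR to decide.

```python
def align_grads(variables, grads, fill=None):
    """Return a list with exactly len(variables) gradient entries.

    Extra grads are sliced off; missing ones are padded with `fill`
    (assumed here to mean "use the default seed gradient").
    """
    n = len(variables)
    grads = [] if grads is None else list(grads)
    grads = grads[:n]                       # slice if too many
    grads.extend([fill] * (n - len(grads)))  # pad if too few
    return grads


print(align_grads(["y1", "y2", "y3"], [1.0]))  # [1.0, None, None]
print(align_grads(["y1"], [1.0, 2.0, 3.0]))    # [1.0]
```

Silently slicing discards gradients the caller passed explicitly. Raising a ValueError when len(grads) > len(variables), and padding only the short case, may be the less surprising choice.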
Anyway, almost all parts of your PR look good to me! Thank you for your contribution!