dni-pytorch
How to train the synthetic NN without applying its outputs to the base NN?
As the issue title says.
My base module is a 3-layer GRU, and the synthetic module is another RNN.
I want to train the base module in BPTT mode, without synthetic gradients, for the first few epochs.
From my understanding, the `make_trigger()` function creates a trigger so that the synthetic module can be trained from the true gradients, and the `backward_interface.backward()` function applies the synthetic gradients to the final hidden states and the base module at the same time.
So I tried running the forward pass of the base module with `make_trigger()` and without `interface.backward()`, but the gradients of the base module are still different from those of the model without `make_trigger()`.
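One way to narrow down a discrepancy like this is to compare the gradients of two identical copies of the base module on the same input, one with a trigger in the path and one without. The following is a minimal sketch in plain PyTorch that does not use dni-pytorch at all. `PassThroughTrigger` and `base_gradients` are hypothetical names invented for illustration, and the trigger here only models the assumption that a trigger observes the true gradient and returns it unchanged; it is not the library's implementation.

```python
import copy

import torch
import torch.nn as nn


class PassThroughTrigger(torch.autograd.Function):
    """Hypothetical stand-in for a trigger (not dni-pytorch code).

    Identity in the forward pass; in the backward pass it records the
    true gradient and passes it on unchanged.
    """

    captured = []  # true gradients seen so far, e.g. targets for a synthesizer

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        PassThroughTrigger.captured.append(grad_output.detach().clone())
        return grad_output


def base_gradients(model, inputs, use_trigger):
    """Run one forward/backward pass and return a copy of every parameter gradient."""
    model.zero_grad()
    output, hidden = model(inputs)
    if use_trigger:
        hidden = PassThroughTrigger.apply(hidden)
    # Toy loss that depends on both the outputs and the final hidden states.
    loss = output.pow(2).sum() + hidden.pow(2).sum()
    loss.backward()
    return [p.grad.detach().clone() for p in model.parameters()]


torch.manual_seed(0)
base = nn.GRU(input_size=4, hidden_size=8, num_layers=3)
twin = copy.deepcopy(base)  # identical weights, so gradients are comparable
inputs = torch.randn(5, 2, 4)  # (seq_len, batch, input_size)

plain = base_gradients(base, inputs, use_trigger=False)
triggered = base_gradients(twin, inputs, use_trigger=True)
max_diff = max((a - b).abs().max().item() for a, b in zip(plain, triggered))
print(f"largest gradient difference: {max_diff:.3e}")
```

If a trigger only observes the true gradient, the difference should be at or near zero. A clearly nonzero difference in the real setup would suggest either that something in the path alters the backward pass, or that the two runs differ in another way, such as the random seed, dropout, or the initial hidden state.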
Any help?
Thanks.