Implementing an adjoint calculation for backprop-ing through time

Open ianwilliamson opened this issue 5 years ago • 8 comments

We should consider the performance benefit of implementing an adjoint calculation for the backward pass through the forward() method in WaveCell. This could save us memory during gradient computation because PyTorch would not need to construct as large a graph.

The approach is described here: https://pytorch.org/docs/stable/notes/extending.html
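
For context, the pattern described in those docs is a `torch.autograd.Function` subclass with hand-written `forward`/`backward` static methods; only what is passed to `ctx.save_for_backward` is retained for the backward pass. A toy example of the API (not wavetorch code):

```python
import torch

class Square(torch.autograd.Function):
    # Only the tensors saved here are kept for backward, rather than the
    # full graph of sub-operations that autograd would otherwise record.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out  # d(x^2)/dx = 2x
```

Usage would be `y = Square.apply(x)`.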

ianwilliamson avatar Apr 24 '19 16:04 ianwilliamson

Sorry to pop in, but on the off chance you folks haven't seen this lib/paper:

https://github.com/rtqichen/torchdiffeq https://arxiv.org/pdf/1806.07366.pdf

It implements an ODE solver and uses adjoint methods for the backward pass. Is this what you need?
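
For reference, the usage pattern there looks roughly like the sketch below (illustrative only; the `ODEFunc` module is a stand-in, not WaveCell):

```python
import torch
from torchdiffeq import odeint_adjoint as odeint

class ODEFunc(torch.nn.Module):
    # dy/dt = f(t, y); odeint_adjoint backprops by solving an adjoint ODE
    # backwards in time instead of storing the forward computation graph.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(2, 2)

    def forward(self, t, y):
        return self.net(y)

y0 = torch.zeros(2)
t = torch.linspace(0.0, 1.0, 10)
ys = odeint(ODEFunc(), y0, t)  # solution at each time in t, shape (10, 2)
```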

I was already thinking about porting WaveCell to it for my own use. Collaborate?

parenthetical-e avatar Aug 25 '19 17:08 parenthetical-e

Thanks for your interest in this! We are aware of that paper, but unfortunately we can't apply the scheme they propose here because the wave equation with loss (from the absorbing layer) is not reversible.

The "adjoint calculation" I'm referring to here is basically just hard coding the gradient for the time step using the pytorch API documented here: https://pytorch.org/docs/stable/notes/extending.html The motivation for this is that we can potentially save a bunch of memory because pytorch doesn't need to store the fields at every sub-operation of each time step. However, it still needs to store the fields at each time step (there's no getting around this when the differential equation isn't reversible.) In contrast, the neural ODE paper reconstructs these fields by reversing the forward equation during backpropagation, thus, avoiding the need to store the fields from the forward pass.

We actually have this adjoint approach implemented, I just need to push the commits to this repository.

ianwilliamson avatar Aug 26 '19 16:08 ianwilliamson

I'm definitely interested to learn about your project and what you hope to do. We would certainly be open to collaboration if there's an opportunity.

ianwilliamson avatar Aug 26 '19 16:08 ianwilliamson

Ah. I understand. Thanks for the explanation.

parenthetical-e avatar Aug 26 '19 17:08 parenthetical-e

I sent you an email about the project I'm pondering. :)

parenthetical-e avatar Aug 26 '19 17:08 parenthetical-e

Hey Eric, could you forward that email to me as well please? I'm interested in what you have planned. Thanks!

twhughes avatar Aug 26 '19 19:08 twhughes

Done, @twhughes

parenthetical-e avatar Aug 27 '19 18:08 parenthetical-e

This is now partially implemented. Currently, the individual time step is a primitive. This seems to help with memory utilization during training, especially with the nonlinearity. Perhaps we could investigate whether there would be a significant performance benefit from adjoint-ing the time loop as well.
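
Continuing the hypothetical `TimeStep` sketch from above (and assuming initial fields `u0`, `u1` and the same `c`, `dt`), "the individual time step is a primitive" means the loop looks roughly like this; a sketch, not the actual wavetorch loop:

```python
# Each TimeStep.apply call is a single node in the autograd graph, so only
# the tensors saved inside that node are retained, not every intermediate
# of the update. Wrapping the whole loop in one Function would be the
# "adjoint-ing the time loop" idea mentioned above.
u_prev, u = u0, u1
for _ in range(num_steps):
    u_prev, u = u, TimeStep.apply(u, u_prev, c, dt)
```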

ianwilliamson avatar Sep 20 '19 21:09 ianwilliamson