Reverse AD: more nuanced handling of consumed arrays.
In each scope in which arrays are consumed, the reverse AD pass makes a copy of each consumed array just before entering the scope. In the forward pass, the arrays are substituted with their respective copies so that the original arrays (with their original values) remain available to the reverse pass. [1]
While this works, it's inefficient. For example, consider:
    let xs' = if it_is_raining
              then
                let z = xs[0] * xs[0]
                let res = xs with [0] = z
                in res
              else
                ...
After applying AD, we have:
    -- forward sweep
    let xs_copy = copy xs
    let xs' = if it_is_raining
              then
                let z = xs_copy[0] * xs_copy[0]
                let res = xs_copy with [0] = z
                in res
              else
                ...
    -- reverse sweep
    let xs_adj = if it_is_raining
                 then
                   let z = xs[0] * xs[0]
                   let res = xs with [0] = z
                   let z_adj += res_adj[0]
                   let xs_adj = xs_adj with [0] = 2 * xs[0] * z_adj
                   in xs_adj
                 else
                   ...
Instead, we can save just the individual element(s) that are about to be overwritten and avoid copying the entire array. The saved element(s) are then used to restore the array so that all intermediate variables can be reproduced.
    -- forward sweep
    let (xs', xs_0) = if it_is_raining
                      then
                        let z = xs[0] * xs[0]
                        let xs_0 = xs[0] -- save the overwritten element
                        let res = xs with [0] = z
                        in (res, xs_0)
                      else
                        ...
    -- reverse sweep
    let xs_adj = if it_is_raining
                 then
                   let xs_restore = xs' with [0] = xs_0 -- restore the overwritten element
                   let z = xs_restore[0] * xs_restore[0]
                   let res = xs_restore with [0] = z
                   let z_adj += res_adj[0]
                   let xs_adj = xs_adj with [0] = 2 * xs_restore[0] * z_adj
                   in xs_adj
                 else
                   ...
This is very similar to the saving/restoring we already do for scatter, just instrumented to work across scopes instead of only within a scope. The main distinction is that the forward re-execution sweep in each new scope must be modified to restore values appropriately: a restore must be placed before any statement that preceded the corresponding save in the forward sweep.
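To make the save/restore idea concrete, here is a hand-written sketch of it as an ordinary Futhark program. It is not what the AD pass generates. The names forward_update, reverse_update, and roundtrip are hypothetical and chosen only for illustration, and the sketch covers only the single-element update from the example above, with the incoming adjoint of the result passed in explicitly as res_adj.

```futhark
-- Hand-written illustration of saving and restoring one overwritten element.
-- All function names here are hypothetical; this is not compiler output.

-- Forward sweep: square element i in place and return the value that was
-- overwritten, so the reverse sweep can restore it later.
def forward_update [n] (xs: *[n]f64) (i: i64) : (*[n]f64, f64) =
  let saved = xs[i]            -- save the element about to be overwritten
  let z = saved * saved
  in (xs with [i] = z, saved)

-- Reverse sweep: restore the overwritten element first, then compute the
-- adjoint of the input from the adjoint of the result.
def reverse_update [n] (xs': *[n]f64) (res_adj: *[n]f64) (i: i64) (saved: f64)
                   : (*[n]f64, *[n]f64) =
  let xs = xs' with [i] = saved                 -- restore before anything reads xs[i]
  let z_adj = res_adj[i]                        -- adjoint flowing into z
  let xs_adj = res_adj with [i] = 2.0 * xs[i] * z_adj
  in (xs, xs_adj)

-- Runs both sweeps on element 0 and returns the restored input together
-- with its adjoint. Assumes the arrays are non-empty.
-- ==
-- entry: roundtrip
-- input { [3.0, 4.0] [1.0, 1.0] }
-- output { [3.0, 4.0] [6.0, 1.0] }
entry roundtrip [n] (xs: [n]f64) (res_adj: [n]f64) : ([n]f64, [n]f64) =
  let (xs', saved) = forward_update (copy xs) 0
  in reverse_update xs' (copy res_adj) 0 saved
```

Only one scalar is carried from the forward sweep to the reverse sweep, rather than a copy of the whole array. The two copy calls in roundtrip exist only because an entry point's parameters are not unique; they are not part of the technique.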
[1] returnSweepCode is responsible for substituting the names of the copies back to the originals in the reverse pass. There was a technical reason for using the originals in the reverse pass (instead of the copies), but unfortunately I forget it.