Reverse AD: more nuanced handling of consumed arrays.

Open zfnmxt opened this issue 2 years ago • 0 comments

In each scope in which arrays are consumed, the reverse AD pass makes a copy of each consumed array just before entering the scope. In the forward pass, the arrays are substituted with their respective copies so that the original arrays (with their original values) remain available to the reverse pass. [1]

While this works, it's inefficient: the entire array is copied even when only a few elements are overwritten. For example, consider:

let xs' = if it_is_raining
          then
            let z = xs[0] * xs[0]
            let res = xs with [0] = z
            in res
          else
            ...

After applying AD, we have

-- forward sweep
let xs_copy = copy xs
let xs' = if it_is_raining
          then
            let z = xs_copy[0] * xs_copy[0]
            let res = xs_copy with [0] = z
            in res
          else
            ...
     
-- reverse sweep
let xs_adj = if it_is_raining
             then
               let z = xs[0] * xs[0]
               let res = xs with [0] = z
               let z_adj += res_adj[0]
               let xs_adj = xs_adj with [0] = 2 * xs[0] * z_adj
               in xs_adj
             else
               ...

Instead, we can just save the individual updated element(s) and avoid copying an entire new array. The saved element(s) are then used to restore the array so that all intermediate variables can be reproduced.

-- forward sweep
let (xs', xs_0) = if it_is_raining
                  then
                    let z = xs[0] * xs[0]
                    let xs_0 = xs[0] -- save the overwritten element
                    let res = xs with [0] = z
                    in (res, xs_0)
                  else
                    ...
     
-- reverse sweep
let xs_adj = if it_is_raining
             then
               let xs_restore = xs' with [0] = xs_0  -- restore the overwritten element
               let z = xs_restore[0] * xs_restore[0]
               let res = xs_restore with [0] = z
               let z_adj += res_adj[0]
               let xs_adj = xs_adj with [0] = 2 * xs_restore[0] * z_adj
               in xs_adj
             else
               ...

This is very similar to the saving/restoring we already do for scatter, except that it must work across scopes instead of only within a single scope. The main distinction is that the forward re-execution sweep in each new scope must be modified to restore values appropriately: a restore must be placed before any statement that preceded the corresponding save in the forward sweep.
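
To make the save/restore idea concrete, here is a minimal hand-written sketch in plain Futhark. It is not what the AD pass generates: `fwd_sweep`, `rev_sweep`, `roundtrip`, and `res_adj` (the incoming adjoint of the result) are hypothetical names. It assumes a non-empty array and that `xs` has no other uses, so its adjoint starts at zero.

```futhark
-- Sketch only: all names here are hypothetical, not compiler internals.

-- Forward sweep: square element 0 in place and save the value it overwrites.
def fwd_sweep [n] (xs: *[n]f64) : (*[n]f64, f64) =
  let xs_0 = xs[0]                      -- save the overwritten element
  let res = xs with [0] = xs_0 * xs_0   -- consumes xs
  in (res, xs_0)

-- Reverse sweep: restore the original array, then propagate adjoints.
-- res_adj is the (assumed) incoming adjoint of the forward result.
def rev_sweep [n] (xs': *[n]f64) (xs_0: f64) (res_adj: [n]f64) : ([n]f64, [n]f64) =
  let xs = xs' with [0] = xs_0          -- restore the overwritten element
  let z_adj = res_adj[0]                -- adjoint of z = xs[0] * xs[0]
  let xs_adj = copy res_adj             -- elements other than 0 pass through
  let xs_adj = xs_adj with [0] = 2 * xs[0] * z_adj
  in (xs, xs_adj)

-- Round trip: returns the restored array and the adjoint of the input.
entry roundtrip [n] (xs: [n]f64) (res_adj: [n]f64) : ([n]f64, [n]f64) =
  let (xs', xs_0) = fwd_sweep (copy xs)
  in rev_sweep xs' xs_0 res_adj
```

Only a single `f64` crosses from the forward sweep to the reverse sweep here, rather than a full copy of the array. The `copy` in `roundtrip` exists only so the caller keeps `xs` for comparison against the restored array.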

[1] returnSweepCode is responsible for substituting the names of the copies back to the originals in the reverse pass. There was a technical reason for choosing the originals in the return pass (instead of the copies), and unfortunately I forget it.

zfnmxt · Jul 13 '22 14:07