Yota.jl

Loop operator

Open dfdx opened this issue 3 years ago • 1 comments

The current approach to dynamic graphs is to trace the function on each execution and select the matching tape (with the differentiation operations already added) from the cache. This approach has several disadvantages:

  1. Tracing takes time.
  2. Cached tapes occupy memory.

One way to avoid both is to add support for Loop and If operators. Loop looks like the harder of the two, so this issue is about it.

Representation

Loop operator can be represented with something like this:

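mutable struct Loop
    id::Int            # id of this operator (presumably its position on the outer tape)
    subtape::Tape      # operations that make up the loop body
    exec_count::Int    # number of iterations recorded during execution
end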

When executing, this operator would not only repeat its code, but also record the number of repetitions. This number can then be passed to the derivative of the operator for the reverse pass.
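A minimal sketch of that execution behaviour, under some loud assumptions: `ToyTape`, `ToyLoop` and `exec!` are hypothetical stand-ins, not Yota's real types, the subtape is modelled as a plain list of functions, and the loop condition is passed in as a separate function.

```julia
# Toy stand-in for a tape: an ordered list of functions, each mapping the
# loop state to a new state. This is NOT Yota's real Tape type.
struct ToyTape
    ops::Vector{Function}
end

mutable struct ToyLoop
    id::Int
    subtape::ToyTape
    exec_count::Int
end

# Run the subtape while `cond(state)` holds and record the iteration count.
function exec!(loop::ToyLoop, cond, state)
    loop.exec_count = 0
    while cond(state)
        for op in loop.subtape.ops
            state = op(state)
        end
        loop.exec_count += 1
    end
    return state
end

body = ToyTape(Function[x -> 2x, x -> x + 1])
loop = ToyLoop(1, body, 0)
result = exec!(loop, x -> x < 100, 1)
```

After the call, `loop.exec_count` holds the number of iterations actually performed, which is the value the reverse pass would need.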

Tracing

Currently we trace code by rewriting every call to f(args...) with a call to record_or_recurse!(..., fargs). Every time a primitive function is called, this call is recorded to the tape, leading to a fully unrolled trace.
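To illustrate what "fully unrolled" means, here is a toy tracer. Note the swap in technique: it uses operator overloading rather than the IR rewriting described above, and `Tracked` and `double_n_times` are made-up names. Only the shape of the resulting tape is the point.

```julia
# Toy tracer based on operator overloading. The real tracer rewrites IR
# instead; this only illustrates the resulting unrolled tape.
struct Tracked
    val::Float64
    tape::Vector{String}   # shared log of recorded primitive calls
end

# Every primitive call is recorded, then executed.
function Base.:*(a::Tracked, b::Real)
    push!(a.tape, "mul")
    return Tracked(a.val * b, a.tape)
end

function double_n_times(x, n)
    for _ in 1:n
        x = x * 2.0
    end
    return x
end

tape = String[]
y = double_n_times(Tracked(1.0, tape), 3)
```

The tape ends up with one `"mul"` entry per iteration, so its length grows with the iteration count, and a different `n` produces a different tape.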

If we want to move loops into a separate operator, we need to treat them in some other way. Assuming we can detect start and end of the loop (which doesn't seem hard in IRTools's code representation), I see 2 possible approaches so far:

  1. Trace every execution into a subtape as usual and then simplify this subtape.
  2. Stop tracing after the first execution.

Switching between the outer and inner (sub)tapes can be implemented similarly to switching between execution frames.
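The frame-like switching could be sketched as a stack of tapes, where records always go to the tape on top. Everything here (`TapeStack`, `record!`, `enter_loop!`, `exit_loop!`, and tapes as vectors of symbols) is hypothetical and only meant to show the idea.

```julia
# Hypothetical tracer state: a stack of tapes; the top one receives records.
struct TapeStack
    tapes::Vector{Vector{Symbol}}
end

TapeStack() = TapeStack([Symbol[]])      # start with the outer tape only

# Record an operation on whichever tape is currently active.
record!(s::TapeStack, op::Symbol) = push!(last(s.tapes), op)

# Enter a loop: push a fresh subtape, which becomes the active tape.
enter_loop!(s::TapeStack) = push!(s.tapes, Symbol[])

# Leave the loop: pop the subtape and record a single :loop op on the parent.
function exit_loop!(s::TapeStack)
    subtape = pop!(s.tapes)
    record!(s, :loop)
    return subtape
end

s = TapeStack()
record!(s, :mul)
enter_loop!(s)
record!(s, :add)
record!(s, :sin)
subtape = exit_loop!(s)
record!(s, :sum)
outer = s.tapes[1]
```

The outer tape then contains a single `:loop` entry in place of the body's operations, and the body lives on its own subtape.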

Reverse pass

The reverse pass should be similar to the generated forward pass code, but it updates the derivatives instead of the values.
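A hand-written sketch for the loop body `x = x * w`, assuming the forward pass saves each iteration's input (that storage scheme is an assumption of this example, and `loop_forward` / `loop_reverse` are hypothetical names). The number of reverse iterations equals the recorded iteration count.

```julia
# Forward pass: repeat x = x * w while keeping each iteration's input,
# which the reverse pass needs. Saving inputs is an assumption of this sketch.
function loop_forward(x, w, n)
    saved = typeof(x)[]
    for _ in 1:n
        push!(saved, x)
        x = x * w
    end
    return x, saved
end

# Reverse pass: apply the body's derivative once per recorded iteration,
# in reverse order, updating the derivatives dx and dw.
function loop_reverse(dy, w, saved)
    dx, dw = dy, zero(w)
    for x in reverse(saved)
        dw += dx * x     # contribution of this iteration to d/dw
        dx = dx * w      # propagate the derivative to this iteration's input
    end
    return dx, dw
end

y, saved = loop_forward(2.0, 3.0, 3)      # computes 2 * 3^3
dx, dw = loop_reverse(1.0, 3.0, saved)
```

For `y = x * w^3` the analytic derivatives are `w^3` and `3 * x * w^2`, which is what the reverse loop accumulates.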

As an option, we may generate a simple gradient function for each loop right while differentiating the outer tape. This would mix up the differentiation and compilation stages, though: a very Julian, but pretty dangerous, strategy.

Loop operator outputs

The Loop operator can return a tuple of all changed variables. This tuple can then be destructured in the usual way.
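As a plain-Julia analogy (not tape code), a loop turned into a function takes the variables it changes and returns their new values as a tuple. `run_loop` is a made-up name for illustration.

```julia
# Hypothetical loop extracted into a function: it takes the variables the
# loop changes and returns their final values as a tuple.
function run_loop(acc, i, n)
    while i <= n
        acc += i
        i += 1
    end
    return (acc, i)
end

# The caller destructures the tuple in the usual way.
acc, i = run_loop(0, 1, 4)
```

On a tape, the destructuring would presumably show up as ordinary indexing operations on the loop's output.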

Note that this shouldn't break CSE (common subexpression elimination) or any other optimizations.

dfdx, Jan 26 '21 22:01

Loop tracing is now supported via Ghost.jl. Differentiating is on the roadmap.

dfdx, Jul 03 '21 20:07