Yota.jl
Loop operator
The current approach to dynamic graphs is to trace the function on each execution and select the matching tape (with differentiation operations already added) from the cache. This approach has several disadvantages:
- Tracing takes time.
- Cached tapes occupy memory.
One way to avoid them is to add support for Loop and If operations. Loop looks like the harder one, so this issue is about it.
Representation
Loop operator can be represented with something like this:
    mutable struct Loop
        id::Int
        subtape::Tape
        exec_count::Int
    end
When executing, this operator would not only repeat its code, but also record the number of repetitions. This number can then be passed to the derivative of the operator for the reverse pass.
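To make the idea of recording the repetition count concrete, here is a minimal, self-contained sketch. It is not Yota's implementation: `ToyLoop`, `execute!`, and the use of plain functions for the body and the condition (instead of a subtape) are assumptions made purely for illustration.

```julia
# Hypothetical stand-in for the Loop operator; the body and the condition
# are plain functions here rather than a subtape.
mutable struct ToyLoop
    body::Function      # one iteration: state -> new state
    cond::Function      # keep looping while cond(state) is true
    exec_count::Int     # iterations performed by the last execution
end

ToyLoop(body, cond) = ToyLoop(body, cond, 0)

# Run the loop and remember how many times the body was executed.
function execute!(op::ToyLoop, state)
    op.exec_count = 0
    while op.cond(state)
        state = op.body(state)
        op.exec_count += 1
    end
    return state
end

# Double x three times, counting iterations in i.
loop = ToyLoop(s -> (i = s.i + 1, x = 2 * s.x), s -> s.i < 3)
final = execute!(loop, (i = 0, x = 1.0))
```

After this run, `loop.exec_count` is 3 and `final.x` is 8.0; the count is what the derivative of the operator would need in order to replay the same number of iterations backwards.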
Tracing
Currently we trace code by rewriting every call to f(args...) with a call to record_or_recurse!(..., fargs). Every time a primitive function is called, this call is recorded to the tape, leading to a fully unrolled trace.
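The unrolling effect can be illustrated with a toy tracer. This sketch does not use the real record_or_recurse! machinery; `Call`, `ToyTape`, `record!`, and `traced_pow` are hypothetical names, and recording is done by hand rather than by rewriting code.

```julia
# Toy tape entry: the primitive that was called, its arguments, its result.
struct Call
    fn::Function
    args::Tuple
    val::Any
end

const ToyTape = Vector{Call}

# Execute a primitive and record the call on the tape.
function record!(tape::ToyTape, fn, args...)
    val = fn(args...)
    push!(tape, Call(fn, args, val))
    return val
end

# A function with a loop; every multiplication goes through record!.
function traced_pow(tape::ToyTape, x, n)
    acc = one(x)
    for _ in 1:n
        acc = record!(tape, *, acc, x)
    end
    return acc
end

tape = ToyTape()
result = traced_pow(tape, 2.0, 4)
```

Here the tape ends up with four entries, one per iteration, so its length grows with the trip count and a different `n` produces a different tape.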
If we want to move loops into a separate operator, we need to treat them in some other way. Assuming we can detect the start and end of a loop (which doesn't seem hard in IRTools's code representation), I see two possible approaches so far:
- Trace every execution into the subtape as usual and then simplify this subtape.
- Stop tracing after the first execution.
Switching between the outer and inner (sub)tapes can be implemented similarly to switching between execution frames.
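One way to picture the frame-like switching is a stack of tapes, where entering a loop pushes a fresh subtape and leaving it pops the subtape and records it as a single operation on the parent. The names below (`ToyTracer`, `enter_loop!`, `exit_loop!`) are hypothetical, and symbols stand in for real tape operations.

```julia
# Tracer holding a stack of tapes; the last tape is the active one.
struct ToyTracer
    tapes::Vector{Vector{Any}}
end

ToyTracer() = ToyTracer([Any[]])

active(t::ToyTracer) = t.tapes[end]

# Record an operation on whichever tape is currently active.
record!(t::ToyTracer, op) = push!(active(t), op)

# Entering a loop: start recording into a fresh subtape.
enter_loop!(t::ToyTracer) = push!(t.tapes, Any[])

# Leaving a loop: pop the subtape and record it as one op on the parent.
function exit_loop!(t::ToyTracer)
    subtape = pop!(t.tapes)
    record!(t, (:loop, subtape))
    return subtape
end

t = ToyTracer()
record!(t, :a)        # goes to the outer tape
enter_loop!(t)
record!(t, :b)        # goes to the subtape
record!(t, :c)
exit_loop!(t)
record!(t, :d)        # back on the outer tape
```

The outer tape now holds three entries: `:a`, a single loop entry wrapping `[:b, :c]`, and `:d`.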
Reverse pass
The reverse pass should be similar to the generated forward pass code, but it also updates the derivatives.
As an option, we may generate a simple gradient function for each loop right while differentiating the outer tape. This would mix up the differentiation and compilation stages, though: a very Julian, but pretty dangerous, strategy.
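As a rough illustration of what a reverse pass over a loop has to do, the sketch below differentiates a loop whose body is `y = a * y`. It is hand-written rather than generated, and `loop_forward` and `loop_reverse` are hypothetical names; the point is only that the recorded iteration count drives the backward walk.

```julia
# Forward pass: repeat y = a * y, saving the input of every iteration.
function loop_forward(a, y0, n)
    saved = Vector{typeof(y0)}()
    y = y0
    for _ in 1:n
        push!(saved, y)
        y = a * y
    end
    return y, saved
end

# Reverse pass: replay the iterations backwards, exec_count times.
function loop_reverse(a, saved, exec_count, dy)
    da = zero(a)
    for k in exec_count:-1:1
        y_prev = saved[k]
        da += dy * y_prev   # derivative of a * y_prev w.r.t. a
        dy *= a             # derivative of a * y_prev w.r.t. y_prev
    end
    return da, dy
end

y, saved = loop_forward(2.0, 3.0, 3)               # y = 2^3 * 3
da, dy0 = loop_reverse(2.0, saved, length(saved), 1.0)
```

For this input the result agrees with the closed form: y = a^n * y0, so dy/da = n * a^(n-1) * y0 = 36 and dy/dy0 = a^n = 8.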
Loop operator outputs
The Loop operator can return a tuple of all changed variables. This tuple can then be destructured in the usual way.
Note that it shouldn't break CSE or any other optimizations.
Loop tracing is now supported via Ghost.jl. Differentiation is on the roadmap.