Zygote.jl
Some limited definitions for nested AD
This adds a bunch of definitions to make nested AD work and adds tests for second-order AD. Unfortunately by the time we get to third-order AD, the types get so large that base decides to go on vacation while it thinks about whether or not it might be willing to compile a function with a type of such complexity. Additionally, Zygote introduces some unnecessary stacks, which then prevent higher order AD. I plan to work on both of those issues, but in the meantime, here are the changes to Zygote required to make this work.
The Zygote parts that break nested AD are addressed by #78. It still takes Inf time to compile on current Julia master, so I'm working on that.
In general it'd be cleaner if we could avoid defining second order gradients in favour of e.g. having dgetindex use setindex (out of place) and adding a first-order adjoint for that. Seems like getfield should also just work since it doesn't use mutation, is there some other issue with that (performance?)
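To make the out-of-place suggestion concrete, here is a minimal sketch in plain Julia. All three function names are hypothetical and are not the definitions in this PR or in Zygote; the sketch only assumes that the gradient of `x[i]` is a zero array with the upstream gradient placed at index `i`. Because nothing is mutated, the gradient is itself an ordinary function that a first-order adjoint can cover.

```julia
# Hypothetical helper names; these are not Zygote's actual definitions.

# Out-of-place setindex: return a copy of `x` with `x[i]` replaced by `v`,
# leaving `x` itself untouched.
function setindex_outofplace(x::AbstractVector, v, i::Integer)
    return [j == i ? v : x[j] for j in eachindex(x)]
end

# Gradient of `x[i]` with respect to `x`, given the upstream gradient `dy`:
# zero everywhere except position `i`, built without mutating any array.
function dgetindex_sketch(x::AbstractVector, dy, i::Integer)
    return setindex_outofplace(zero(x), dy, i)
end

# First-order pullback for the out-of-place setindex. Given the upstream
# gradient `dout` of the result, slot `i` came from `v` and every other
# slot came from `x`.
function setindex_outofplace_pullback(dout::AbstractVector, i::Integer)
    dx = setindex_outofplace(dout, zero(eltype(dout)), i)
    dv = dout[i]
    return dx, dv
end
```

With this shape, differentiating `dgetindex_sketch` a second time only needs `setindex_outofplace_pullback`, rather than a hand-written second-order gradient.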
Also, what's the motivation for nobacksies? We need that in Tracker to avoid silently dropping gradients, but since there's no Zygote equivalent of data that shouldn't be an issue here; at worst if something's not differentiable then that part will just error.
the types get so large that base decides to go on vacation while it thinks about whether or not it might be willing to compile a function with a type of such complexity
Somewhat surprisingly, the huge types appear not to be a problem (or at least not to make things significantly worse here). I suspect the main issue is simply that we're generating, differentiating and inferring the equivalent of several thousand lines of code. If so, we may not be able to solve this without either running AD on typed IR with aggressive DCE, or (perhaps in the short term) switching to tracing, where DCE is more valuable than other optimisations.
In general it'd be cleaner if we could avoid defining second order gradients in favour of e.g. having dgetindex use setindex (out of place) and adding a first-order adjoint for that. Seems like getfield should also just work since it doesn't use mutation, is there some other issue with that (performance?)
Yes, agreed, but I just wanted to make things work for now.
Also, what's the motivation for nobacksies? We need that in Tracker to avoid silently dropping gradients, but since there's no Zygote equivalent of data that shouldn't be an issue here; at worst if something's not differentiable then that part will just error.
Just as a debugging aid for now. While we don't have data, it's not super uncommon to run into a gradient that accidentally drops it, so I wanted to have a way to mark gradients I hadn't yet manually inspected.
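As a rough illustration of what such a marker could look like, here is a sketch in plain Julia. The name `nobacksies_sketch`, its signature, and the error message are all assumptions for illustration, not the definition used in this PR or in Tracker. The assumed behaviour is that the value passes through unchanged on the forward pass, while the backward pass raises an error instead of returning a possibly missing gradient.

```julia
# Hypothetical sketch, not the definition used in this PR or in Tracker.
# Returns the value unchanged together with a pullback that refuses to run,
# so differentiating through the marked value fails loudly instead of
# silently yielding a missing gradient.
function nobacksies_sketch(name::Symbol, x)
    pullback = Δ -> error("Nested AD not defined for $name")
    return x, pullback
end
```

Wrapping the output of a not-yet-inspected gradient this way means any attempt to differentiate through it surfaces immediately, tagged with the name of the gradient that still needs review.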
Finding third order derivatives is very slow. Is this the related PR to fix it?
BTW: this branch is quite outdated and has a lot of conflicts with master; hoping someone can fix it :smiley: