AbstractDifferentiation.jl
API for user code to detect if it's being differentiated
The most recent iteration of this discussion was https://github.com/FluxML/NNlib.jl/pull/434, but it goes back to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/547 and much further. Given that not all ADs are ChainRules-compatible but this package seeks to support (/eventually be supported by) all ADs, it feels like a much better home for such functionality than a domain-specific package like NNlib.
As we briefly discussed in FluxML/NNlib.jl#434, there's another way to think about the problem. If I understand the motivation correctly, there are functions that behave differently during training (i.e. being differentiated) and normal execution, e.g. dropout or batch normalization. Let's use this generic example for simplicity:
function func(x, training=false)
    y = nothing
    if training
        y = do_this(x)
    else
        y = do_that(x)
    end
    return y
end
Understanding branches is a hard problem for AD, so many systems assume code execution always takes the same path. For example, Yota would first trace the code with the provided inputs and assume the path never changes. E.g. after tracing it will see the equivalent of either this:
# when training == true
function func(x, not_used_training_flag)
    return do_this(x)
end
or this:
# when training == false
function func(x, not_used_training_flag)
    return do_that(x)
end
This leads to a lot of funny things. For example, if we hardcode training=true into rrule() (which sounds reasonable, because rrule is part of differentiation, right?), then the branch taken will depend on whether we first trace and then transform rrule(), or first transform and then trace. Ironically, this is exactly the implementation detail I changed in Yota in the most recent major release, and I didn't even know it would break this behavior!
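To make the failure mode concrete, here is a minimal sketch of what "hardcoding training=true into the rule" could look like; do_this/do_that are hypothetical stand-ins and the rule assumes ChainRulesCore is available:

using ChainRulesCore            # assumed available

do_this(x) = 2 .* x             # hypothetical stand-in for the training branch
do_that(x) = x                  # hypothetical stand-in for the inference branch

# The flag argument is ignored: the rule always takes the training branch.
function ChainRulesCore.rrule(::typeof(func), x, training)
    y = do_this(x)
    func_pullback(dy) = (NoTangent(), 2 .* dy, NoTangent())
    return y, func_pullback
end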
within_gradient() attempts to solve this problem by letting the function author take the appropriate branch depending on the context. But in a multistage execution pipeline (e.g. tracing -> differentiation -> some other transformations), analyzing corner cases becomes quite hard.
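For reference, this is roughly how such a check can be implemented on top of ChainRules (a sketch, not NNlib's exact code): the primal definition returns false, and an rrule overrides the result to true whenever an rrule-consuming AD differentiates through it.

using ChainRulesCore            # assumed available

within_gradient(x) = false      # plain execution: we are not being differentiated

function ChainRulesCore.rrule(::typeof(within_gradient), x)
    # an rrule-based AD replaces the call with this rule, so the primal
    # result becomes `true` inside a gradient computation
    return true, _ -> (NoTangent(), NoTangent())
end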
Another approach is to provide a simple, AD-friendly way to express branches: a kind of cond function. Just to illustrate the idea:
cond(flag, fn1, fn2) = ...

function func_with_cond(x, flag)
    y = cond(flag, () -> do_this(x), () -> do_that(x))
    return y
end

function rrule(::typeof(cond), flag, fn1, fn2)
    res, pb = flag ? rrule_via_ad(fn1) : rrule_via_ad(fn2)  # a kind of...
    function cond_pullback(dy)
        dxs = pb(dy)
        return NoTangent(), NoTangent(), dxs...
    end
    return res, cond_pullback
end
The code is just an illustration, but I hope you get the idea. An AD system can then overload cond() directly (e.g. Tracker.jl), rewrite the call to rrule(::typeof(cond), ...), or do something more sophisticated. But there's no longer any need to treat the differentiation context in a special way - it's a general-purpose if, just written in an unusual way.
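For completeness, a plausible fallback definition (an assumption on my part, matching the illustration above) so that code calling cond still runs under plain, non-AD execution:

# generic fallback: just call the branch selected by the flag
cond(flag::Bool, fn1, fn2) = flag ? fn1() : fn2()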
Another advantage of having a cond-like function is that it's compatible with conditionals in model exchange formats, like IF in ONNX, which should help bring more pre-trained models to Julia.
I think my worries with a cond-like construct are twofold:
- To avoid the within_gradient problem, it seems like you'd need specialized conds for every kind of branch that might change under AD. That, or duplicate code for any additional conditionals, i.e. within_gradient(...) && othercondition now needs to become othercondition ? cond(...) : cond(...) or cond(flag, () -> othercondition ? ..., () -> othercondition ? ...).
- Chains of ... rrule -> rrule_via_ad -> ... rrule -> rrule_via_ad -> ... tend to trip up the recursion limit heuristic. This leads to additional dynamic dispatch, compilation and other overhead. It's not clear to me whether we can rely on work like https://github.com/JuliaLang/julia/pull/48059 landing any time soon to alleviate this.
Otherwise I don't have any objections.
1. Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here (see the sketch after this list).
2. It's a valid point, and actually one of the reasons I currently experiment with a very restrictive approach to AD in Remix.jl. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things, so I'm trying to find a less compile-intensive approach. In particular, my target vision for cond, as for any other allowed operation, is to be traced just once, transformed at runtime and only compiled to the configured backend at the very end. But it's also a huge experiment, and right now I have neither a prototype nor an exact design to share.
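A sketch of the @nospecialize suggestion, assuming the zero-argument-closure form of cond from above; the branch closures are left unspecialized so that each new pair of closures does not force a fresh specialization of cond itself:

function cond(flag::Bool, @nospecialize(fn1), @nospecialize(fn2))
    return flag ? fn1() : fn2()
end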
Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here.
Well that works for my simple case above, but then you have ||, elseif, etc. @nospecialize might help with cond itself, but my bigger worry would be redundancy across the branch functions. Maybe cond could pass flag to each branch so that one could feasibly use a single callback with an if statement for both (roughly sketched below)? Not sure if that's compatible with tracing.
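A hypothetical sketch of that "pass the flag through" variant - the names cond1 and create_mask are made up here - where a single callback receives the flag and branches internally:

cond1(flag::Bool, fn) = fn(flag)                # hypothetical single-callback variant

# usage: one closure covers both paths, branching on the forwarded flag
create_mask(x) = rand(Bool, size(x))            # stand-in helper
x, active = rand(3), true                       # stand-in inputs
y = cond1(active, flag -> flag ? x .* create_mask(x) : x)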
2. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things...
Agreed. Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs. Given such a mechanism does not exist at present, this seems like the pragmatic solution.
Well that works for my simple case above, but then you have ||, elseif, etc.
Hm, perhaps we put different meanings into somecondition and othercondition. I assume both are just ordinary booleans, so their combination is an ordinary boolean too. cond() behaves just like ifelse(), but maybe with a slightly more complex implementation or special methods. Do you assume some other setup?
elseif is similar to nested ifelse(..., ifelse(...)) and is harder to analyze, but that shouldn't happen too often in neural networks, I guess. In fact, I know only two major functions with conditionals - dropout and batchnorm - and both have exactly one conditional branch in them.
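For illustration, an if/elseif/else chain could be written as nested conds (a sketch assuming the zero-argument-closure form of cond; branch1/branch2/branch3 and the inputs are placeholders):

branch1(x) = x .+ 1; branch2(x) = x .- 1; branch3(x) = x   # placeholder branches
x, flag1, flag2 = rand(3), false, true                     # placeholder inputs

y = cond(flag1, () -> branch1(x),
         () -> cond(flag2, () -> branch2(x), () -> branch3(x)))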
Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs
If you mean an operator that can have 2 branches - one actually taken and another unevaluated - then I have a design for such a feature in Umlaut. Something like this:
struct IfElse
    cond::Union{Variable, Any}
    true_branch::Union{Tape, Unevaluated}
    false_branch::Union{Tape, Unevaluated}
end
where Unevaluated is a structure holding the whole execution context - an instance of IRCode, the execution frame, mappings between variables, etc. This way we analyze one branch during the initial tracing and postpone analyzing the other one until it's actually taken.
But of course it will only work for Umlaut, maybe for Tracker, but is unlikely to work in Zygote or Diffractor since they don't have a user-level execution graph.
Do you assume some other setup?
My understanding of cond(flag, true_fn, false_fn) is that it obeys the following truth table:
| differentiating? | flag | branch taken |
| --- | --- | --- |
| T | T | true_fn |
| T | F | false_fn |
| F | T | false_fn |
| F | F | false_fn |
In other words, there is an implicit "are we differentiating?" variable which is and-ed with flag. This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing. Assuming true_fn and false_fn are zero-arg functions, notice how this doesn't provide fine-grained control over half of the possible cases. In fact, I don't think taking flag as an argument actually adds much here over just capturing it in one or both branches as desired. Passing it through as an arg to each branch (and not implicitly and-ing it) might be more optimization friendly, but I imagine that breaks the mental symmetry with ifelse.
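Written as code, the truth table above corresponds to an implicit and; differentiating() is a purely hypothetical query here, and the fact that it cannot be answered reliably by a tracer is exactly the problem being described:

# pseudo-code for the semantics in the table above
cond(flag, true_fn, false_fn) = (differentiating() && flag) ? true_fn() : false_fn()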
This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing.
I think this is where we disagree - the point of cond is exactly to avoid the problem with tracing in any context. Let's take it step by step. Imagine a function like this (roughly equivalent to the dropout implementation):
function dropout(x, active::Bool)
    if active
        mask = create_mask(x)
        return x .* mask
    else
        return x
    end
end
What we want to do with this code is to:
- Trace it into a computational graph.
- Transform the graph for differentiation.
The problem arises during the first stage, when the tracer has to choose only one branch and ignore the other. Depending on the value of active, the graph will be either:
%1 = create_mask(x)
%2 = x .* %1
ret %2
or
ret x
The information about the other branch is simply lost. All further transformations thus have to make assumptions about the tracer behavior or pass some implicit flags.
The idea behind cond is to preserve the information about both branches:
%1 = true_fn
%2 = false_fn
%3 = cond(active, true_fn, false_fn, x)
ret %3
(here I replicate the API of JAX's cond, which seems to be a better option than what I posted previously)
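Under plain execution, this JAX-style calling convention could fall back to something like the following sketch, where operands are passed explicitly instead of being captured in closures:

cond(pred::Bool, true_fn, false_fn, operands...) =
    pred ? true_fn(operands...) : false_fn(operands...)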
Further transformations now have all the information they need to behave correctly in all cases. The ChainRules-based AD transformation, for example, can produce a graph like this:
%1 = true_fn
%2 = false_fn
%3, %4 = rrule(cond, active, true_fn, false_fn, x) # %3 == val, %4 == pullback
%5 = 1 # seed, i.e. initial dy
%6 = %4(%5) # dxs = pullback(seed)
rrule(cond, ...) has access to the flag active, the functions of both branches and their arguments, so something like this should work:
function rrule(::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    return flag ? rrule_via_ad(true_fn, args...) : rrule_via_ad(false_fn, args...)
end
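For completeness, a slightly fuller sketch that goes through ChainRulesCore's RuleConfig mechanism and also reports (non-)tangents for cond itself, the flag and the untaken branch; this is my own guess at the bookkeeping, not part of the proposal above:

using ChainRulesCore            # assumed available

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cond),
                              flag::Bool, true_fn, false_fn, args...)
    fn = flag ? true_fn : false_fn
    y, fn_pullback = rrule_via_ad(config, fn, args...)
    function cond_pullback(dy)
        d = fn_pullback(dy)                 # (d_fn, d_args...)
        dfn, dargs = d[1], d[2:end]
        dtrue  = flag ? dfn : NoTangent()   # untaken branch gets NoTangent()
        dfalse = flag ? NoTangent() : dfn
        return NoTangent(), NoTangent(), dtrue, dfalse, dargs...
    end
    return y, cond_pullback
end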
A more sophisticated AD system can also choose to treat cond in a special way, and instead of
rrule(cond, active, true_fn, false_fn, x)
record
cond(active, rrule_via_ad(true_fn, x), rrule_via_ad(false_fn, x))
or even something more complicated and efficient.
So there's no need for a special differentiating flag. There may still be several conditions, but I don't see any issues with them either. E.g. for ||:
%1 = flag1 || flag2
%2 = cond(%1, true_fn(x), false_fn(x))
for elseif:
# main graph
%1 = cond(flag1, true_fn(x), false_fn(x))
# false_fn subgraph
%1 = cond(flag2, true_fn2(x), false_fn2(x))
and elseif after the ChainRules transformation:
# main graph
%1, %2 = rrule(cond, flag1, true_fn(x), false_fn(x))
# false_fn subgraph - invoked inside of `rrule_via_ad(false_fn, x)`
%1, %2 = rrule(cond, flag2, true_fn2(x), false_fn2(x))
...
I'm not sure I understand how this tracing works then. To get a value for active, you ultimately need to call something which returns true when not differentiating and false when differentiating. Doing any branching on the value of active (including && and ||, which lower to control flow/branches) will lead to the within_grad issue. cond works because it's special-cased as an opaque function to the tracer, but being limited to only non-short-circuiting bool operations between getting the value of active and passing it to cond seems quite limiting (though still workable for Flux).
In fact, I don't even know if cond belongs in a diff-related library - it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl. Alternatively, could there be a way to mark active as an opaque value for the tracer such that an IfElse is always generated for conditionals branching on it?
To get a value for active, you ultimately need to call something which returns true when not differentiating and false when differentiating.
That's the point - active and differentiation are independent. All four combinations are valid:
| active | differentiating | example |
| --- | --- | --- |
| F | F | dropout(x, false) |
| T | F | dropout(x, true) |
| F | T | rrule(dropout, x, false) |
| T | T | rrule(dropout, x, true) |
Usually, people set active = true while training and differentiate code while training, but strictly speaking nobody forbids you from setting active = true during inference. active is a flag passed from the top level (e.g. via trainmode!()). Differentiation is a transformation that can work on any valid primitive. cond is one such primitive. All three can be mixed or used independently.
In fact, I don't even know if cond belongs in a diff-related library - it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl.
Absolutely! At the moment, cond itself is a hypothetical op :) But once it gets shaped, something like TracingPrimitives sounds reasonable.
That's the point - active and differentiation are independent.
Usually, people set active = true while training and differentiate code while training, but strictly speaking nobody forbids you from setting active = true during inference. active is a flag passed from the top level (e.g. via trainmode!()).
It doesn't affect the outcome of this issue discussion, so I think this can be continued elsewhere, but the impetus for this whole thread was the default case where people don't set active at all. Then active and differentiation become tightly coupled, and this of course upsets tracing because it expects them not to be. The proximal solution seems to be creating a primitive like cond that tracing can't pierce and being very, very careful about not getting troublesome values like active caught in conditionals, but it would also be nice to step back a bit and brainstorm whether there are more generic solutions. For example, TorchDynamo has a concept of "graph breaks", which allows it to avoid the pitfall of only capturing the traced branch.
Yeah, it looks like we went quite off topic :) The whole discussion also makes me think about how hard it is to design a common API for very different AD systems. I think in terms of computational graphs and transformations on them, much like JAX, while you provide examples from PyTorch-like systems. Maybe such discussions will be easier once we have a working prototype in at least one system.