
API for user code to detect if it's being differentiated

Issue opened by ToucheSir · 11 comments

The most recent iteration of this discussion was https://github.com/FluxML/NNlib.jl/pull/434, but it goes back to https://github.com/JuliaDiff/ChainRulesCore.jl/issues/547 and much further. Given that not all ADs are ChainRules-compatible, but this package seeks to support (/eventually be supported by) all ADs, it feels like a much better home for such functionality than a domain-specific package like NNlib.

ToucheSir · Nov 24 '22

As we briefly discussed in FluxML/NNlib.jl#434, there's another way to think about the problem. If I understand the motivation correctly, there are functions that behave differently during training (i.e. while being differentiated) and during normal execution, e.g. dropout or batch normalization. Let's use this generic example for simplicity:

function func(x, training=false)
    y = nothing
    if training
        y = do_this(x)
    else
        y = do_that(x)
    end
    return y
end

Understanding branches is a hard problem for AD, so many systems assume code execution always takes the same path. For example, Yota would first trace the code with the provided inputs and assume the path never changes, i.e. after tracing it will see the equivalent of either this:

# when training == true
function func(x, not_used_training_flag)
    return do_this(x)
end

or this:

# when training == false
function func(x, not_used_training_flag)
    return do_that(x)
end

This leads to a lot of funny things. For example, if we hardcode training=true into rrule() (which sounds reasonable, because rrule is part of differentiation, right?), then the branch taken will depend on whether we first trace and then transform rrule(), or first transform and then trace. Ironically, this is exactly the implementation detail I changed in Yota in the most recent major release, and I didn't even realize it would break this behavior!

within_gradient() attempts to solve this problem by letting the function author take the appropriate branch depending on the context. But in a multi-stage execution (e.g. tracing -> differentiation -> some other transformations), analyzing the corner cases becomes quite hard.
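For reference, a within_gradient-style check is typically implemented as a function whose primal returns false and whose rrule returns true, so the answer flips whenever a ChainRules-based AD drives the call. A minimal sketch (is_differentiating is a hypothetical name, not the actual NNlib API):

using ChainRulesCore

# Primal: if nothing intercepts the call, we are not being differentiated.
is_differentiating(x) = false

# A ChainRules-based AD hits this rrule instead, so the same call returns true.
function ChainRulesCore.rrule(::typeof(is_differentiating), x)
    is_differentiating_pullback(dy) = (NoTangent(), NoTangent())
    return true, is_differentiating_pullback
end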

Another approach is to provide a simple AD-friendly way to include branches, a kind of cond function. Just to illustrate the idea:

cond(flag, fn1, fn2) = ...

function func_with_cond(x, flag)
    y = cond(flag, () -> do_this(x), () -> do_that(x))
    return y
end

function rrule(::typeof(cond), flag, fn1, fn2)
    res, pb = flag ? rrule_via_ad(fn1) : rrule_via_ad(fn2)  # a kind of...
    function cond_pullback(dy)
        dxs = pb(dy)
        return NoTangent(), NoTangent(), dxs...
    end
    return res, cond_pullback
end

The code is just an illustration, but I hope you get the idea. An AD system can then overload cond() (e.g. Tracker.jl), rewrite it to rrule(::typeof(cond), ...) or do something more sophisticated. But there's no longer any need to treat the differentiation context in a special way - it's a general-purpose if, just written in an unusual way.
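To ground the cond(flag, fn1, fn2) = ... placeholder, one possible non-AD fallback (just a sketch, assuming the thunk-style signature from the example above):

# Plain Julia fallback: outside of any AD or tracing context, cond is just an if.
cond(flag::Bool, fn1, fn2) = flag ? fn1() : fn2()

# A ChainRules-based AD would hit an rrule like the one sketched above instead of
# this fallback, while a tracing AD would record cond as an opaque primitive and
# keep both fn1 and fn2 in its graph.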

Another advantage of having a cond-like function is that it's compatible with conditionals in model exchange formats, like the If operator in ONNX, which should help bring more pre-trained models to Julia.

dfdx · Jan 08 '23

I think my worries with a cond-like construct are twofold:

  1. To avoid the within_gradient problem, it seems like you'd need specialized conds for every kind of branch that might change under AD. That or duplicate code for any additional conditionals, i.e. within_gradient(...) && othercondition now needs to become othercondition ? cond(...) : cond(...) or cond(flag, () -> othercondition ? ..., () -> othercondition ? ...).
  2. Chains of ... rrule -> rrule_via_ad -> ... rrule -> rrule_via_ad -> ... tend to trip up the recursion limit heuristic. This leads to additional dynamic dispatch, compilation and other overhead. It's not clear to me whether we can rely on work like https://github.com/JuliaLang/julia/pull/48059 landing any time soon to alleviate this.

Otherwise I don't have any objections.

ToucheSir · Jan 08 '23

  1. Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here.
  2. It's a valid point, and actually one of the reasons I'm currently experimenting with a very restrictive approach to AD in Remix.jl. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things, so I'm trying to find a less compile-intensive approach. In particular, my target vision for cond, as for any other allowed operation, is to be traced just once, transformed at runtime and finally compiled to the configured backend only at the very end. But it's also a huge experiment, and right now I have neither a prototype nor an exact design to share.

dfdx · Jan 09 '23

  1. Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here.

Well that works for my simple case above, but then you have ||, elseif, etc. @nospecialize might help with cond itself, but my bigger worry would be redundancy across the branch functions. Maybe cond could pass flag to each branch so that one could feasibly use a single callback with an if statement for both? Not sure if that's compatible with tracing.
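As a sketch of that single-callback variant (cond_with_flag, and the reuse of do_this/do_that from the earlier example, are hypothetical):

# Hypothetical variant: the (possibly AD-adjusted) flag is handed to a single
# callback, so both cases share one body with an ordinary `if`.
cond_with_flag(flag::Bool, body) = body(flag)

func_single_branch(x, flag) = cond_with_flag(flag, f -> f ? do_this(x) : do_that(x))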

2. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things...

Agreed. Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs. Given such a mechanism does not exist at present, this seems like the pragmatic solution.

ToucheSir · Jan 10 '23

Well that works for my simple case above, but then you have ||, elseif, etc.

Hm, perhaps we put different meanings into somecondition and othercondition. I assume both are just ordinary booleans, so their combination is an ordinary boolean too. cond() behaves just like ifelse(), but maybe with a slightly more complex implementation or special methods. Do you assume some other setup?

elseif is similar to nested ifelse(..., ifelse(...)) and is harder to analyze, but that shouldn't happen too often in neural networks, I guess. In fact, I know of only two major functions with conditionals - dropout and batchnorm - and both have exactly one conditional branch in them.
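Spelled out, the nested form is just:

# elseif written as nested ifelse (illustration only; note that Base.ifelse
# evaluates all of its arguments eagerly, unlike if/elseif)
y = ifelse(flag1, a, ifelse(flag2, b, c))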

Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs

If you mean an operator that can have 2 branches - one actually taken and another unevaluated - then I have a design for such a feature in Umlaut. Something like this:

struct IfElse
    cond::Union{Variable, Any}              # tape variable or a literal condition value
    true_branch::Union{Tape, Unevaluated}   # already-traced subtape, or deferred
    false_branch::Union{Tape, Unevaluated}  # likewise for the branch not taken yet
end

where Unevaluated is a structure holding the whole execution context - an instance of IRCode, the execution frame, mappings between variables, etc. This way we analyze one branch during the initial tracing and postpone analyzing the other until it's actually taken.

But of course this will only work for Umlaut, and maybe for Tracker, but it's unlikely to work in Zygote or Diffractor since they don't have a user-level execution graph.

dfdx · Jan 10 '23

Do you assume some other setup?

My understanding of cond(flag, true_fn, false_fn) is that it obeys the following truth table:

| differentiating? | flag | branch taken |
| ---------------- | ---- | ------------ |
| T                | T    | true_fn      |
| T                | F    | false_fn     |
| F                | T    | false_fn     |
| F                | F    | false_fn     |

In other words, there is an implicit "are we differentiating?" variable which is and-ed with flag. This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing. Assuming true_fn and false_fn are zero-arg functions, notice how this doesn't provide fine-grained control over half of the possible cases. In fact, I don't think taking flag as an argument actually adds much here over just capturing it in one or both branches as desired. Passing it through as an arg to each branch (and not implicitly and-ing it) might be more optimization-friendly, but I imagine that breaks the mental symmetry with ifelse.
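For reference, the semantics of that table written as code (differentiating stands for the implicit bit, which cond would have to obtain internally; this is a specification of the behavior, not an implementation):

# Branch selection per the truth table above: the implicit "are we differentiating?"
# bit is and-ed with flag before a branch is chosen.
branch_taken(differentiating::Bool, flag::Bool) =
    (differentiating && flag) ? :true_fn : :false_fn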

ToucheSir · Jan 11 '23

This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing.

I think this is where we disagree - the point of cond is exactly to avoid the tracing problem in any context. Let's take it step by step. Imagine a function like this (roughly equivalent to the dropout implementation):

function dropout(x, active::Bool)
    if active
        mask = create_mask(x)
        return x .* mask
    else
        return x
    end
end

What we want to do with this code is to:

  1. Trace it into a computational graph.
  2. Transform the graph for differentiating.

The problem arises during the first stage, when the tracer has to choose only one branch and ignore the other. Depending on the value of active, the graph will be either:

%1 = create_mask(x)
%2 = x .* %1
ret %2

or

ret x

The other part of the information is simply lost. All further transformations thus have to make assumptions about the tracer's behavior or pass around implicit flags.

The idea behind cond is to preserve the information about both branches:

%1 = true_fn
%2 = false_fn
%3 = cond(active, true_fn, false_fn, x)
ret %3

(here I replicate the API of JAX's cond, which seems to be a better option than what I posted previously)
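For reference, a plain-Julia sketch of that calling convention (operands passed explicitly instead of being captured in closures):

# JAX-style cond: pred picks the branch and the operands are explicit arguments,
# which is what lets a tracer see them (non-AD fallback, sketch only).
cond(pred::Bool, true_fn, false_fn, operands...) =
    pred ? true_fn(operands...) : false_fn(operands...)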

Further transformations now have all the information to behave correctly in all cases. The ChainRules-based AD transformation, for example, can produce a graph like this:

%1 = true_fn
%2 = false_fn
%3, %4 = rrule(cond, active, true_fn, false_fn, x)   # %3 == val, %4 == pullback
%5 = 1                 # seed, i.e. initial dy
%6 = %4(%5)            # dxs = pullback(seed)

rrule(cond, ...) has access to the flag active, the functions of both branches and their arguments, so something like this should work:

function rrule(::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    return flag ? rrule_via_ad(true_fn, args...) : rrule_via_ad(false_fn, args...)
end
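For completeness, a slightly fuller sketch using ChainRulesCore's actual calling convention (rrule_via_ad takes a RuleConfig); cond itself remains a hypothetical primitive:

using ChainRulesCore

# Plain fallback for the JAX-style signature.
cond(flag::Bool, true_fn, false_fn, args...) = flag ? true_fn(args...) : false_fn(args...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    # Differentiate only the branch that is actually taken.
    y, inner_pb = flag ? rrule_via_ad(config, true_fn, args...) :
                         rrule_via_ad(config, false_fn, args...)
    function cond_pullback(dy)
        dall = inner_pb(dy)                      # (d_branch_fn, d_args...)
        d_branch, d_args = dall[1], dall[2:end]
        dtrue  = flag ? d_branch : NoTangent()
        dfalse = flag ? NoTangent() : d_branch
        return (NoTangent(), NoTangent(), dtrue, dfalse, d_args...)
    end
    return y, cond_pullback
end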

A more sophisticated AD system can also choose to treat cond in a special way, and instead of

rrule(cond, active, true_fn, false_fn, x)

record

cond(active, rrule_via_ad(true_fn, x), rrule_via_ad(false_fn, x))

or even something more complicated and efficient.


So there's no need for a special differentiating flag. There may still be several conditions, but I don't see any issues with them either. E.g. for ||:

%1 = flag1 || flag2
%2 = cond(%1, true_fn(x), false_fn(x))

for elseif:

# main graph
%1 = cond(flag1, true_fn(x), false_fn(x))

# false_fn subgraph
%1 = cond(flag2, true_fn2(x), false_fn2(x))

elseif after ChainRules transformation:

# main graph
%1, %2 = rrule(cond, flag1, true_fn(x), false_fn(x))

# false_fn subgraph - invoked inside of `rrule_via_ad(false_fn, x)`
%1, %2 = rrule(cond, flag2, true_fn2(x), false_fn2(x))
...

dfdx · Jan 11 '23

I'm not sure I understand how this tracing works then. To get a value for active, you ultimately need to call something which returns true when differentiating and false when not differentiating. Doing any branching on the value of active (including && and ||, which lower to control flow/branches) will lead to the within_grad issue. cond works because it's special-cased as an opaque function to the tracer, but being limited to only non-short-circuiting boolean operations between getting the value of active and passing it to cond seems quite limiting (though still workable for Flux).

In fact, I don't even know if cond belongs in a diff-related library—it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl. Alternatively, could there be a way to mark active as an opaque value for the tracer such that an IfElse is always generated for conditionals branching on it?

ToucheSir · Jan 11 '23

To get a value for active, you ultimately need to call something which returns true when differentiating and false when not differentiating.

That's the point - active and differentiation are independent. All four combinations are valid:

| active | differentiating | example                  |
| ------ | --------------- | ------------------------ |
| F      | F               | dropout(x, false)        |
| T      | F               | dropout(x, true)         |
| F      | T               | rrule(dropout, x, false) |
| T      | T               | rrule(dropout, x, true)  |

Usually, people set active = true while training and differentiate code while training, but strictly speaking nobody forbids you from setting active = true during inference. active is a flag passed from the top level (e.g. via trainmode!()). Differentiation is a transformation that can work on any valid primitive. cond is one such primitive. All three can be mixed or used independently.

In fact, I don't even know if cond belongs in a diff-related library—it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl.

Absolutely! At the moment, cond itself is a hypothetical op :) But once it takes shape, something like TracingPrimitives sounds reasonable.

dfdx · Jan 11 '23

That's the point - active and differentiation are independent.

Usually, people set active = true while training and differentiate code while training, but strictly speaking nobody forbids you to set active = true during inference. active is a flag passed from the top-level (e.g. via trainmode!()).

It doesn't affect the outcome of this issue's discussion, so I think this can be continued elsewhere, but the impetus for this whole thread was the default case where people don't set active at all. Then active and differentiation become tightly coupled, and this of course upsets tracing, because it expects them not to be. The most immediate solution seems to be creating a primitive like cond that tracing can't pierce, and being very, very careful about not getting troublesome values like active caught in conditionals, but it would also be nice to step back a bit and brainstorm whether there are more generic solutions. For example, TorchDynamo has a concept of "graph breaks", which allows it to avoid the pitfall of only capturing the traced branch.

ToucheSir · Jan 11 '23

Yeah, it looks like we went quite far off topic :) The whole discussion also makes me think about how hard it is to design a common API for very different AD systems. I think in terms of computational graphs and transformations on them, much like JAX, while you provide examples from PyTorch-like systems. Maybe such discussions will be easier once we have a working prototype in at least one system.

dfdx · Jan 11 '23