NNlib.jl
Add `within_gradient`
Motivated by https://github.com/FluxML/Tracker.jl/issues/133
Possibly it could have a better name, but what?
is_differentiating was suggested at some point, IIRC?
I may have suggested that. But aren't we (in English) differentiating the function, not the array x?
Maybe worth trying to line up with whatever no_gradient function we need too.
Is that argument x needed? It seems we can just define within_gradient() = false, since we don't (or can't?) detect whether the function is differentiated w.r.t. x.
I think x is needed for Tracker to work, as it will never notice a function call which doesn't get a TrackedArray.
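For concreteness, the Tracker side could be as small as a couple of method overloads dispatching on its tracked types - a hypothetical sketch, not part of this PR:
import NNlib
using Tracker: TrackedArray, TrackedReal
# With an argument to dispatch on, Tracker just intercepts tracked values;
# a zero-argument within_gradient() would give it nothing to hook into.
NNlib.within_gradient(::TrackedArray) = true
NNlib.within_gradient(::TrackedReal) = true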
I may have suggested that. But aren't we (in English) differentiating the function, not the array x? Maybe worth trying to line up with whatever no_gradient function we need too.
I was thinking of the correspondence with CRC.non_differentiable, but agree that the name gets weird when you introduce that array arg.
FWIW, PyTorch calls it GradMode::is_enabled() (ex).
Edit: I just realized NNlib already has within_gradient() after reading https://github.com/JuliaDiff/ChainRulesCore.jl/issues/547. To make the purpose of the one-arg method clearer, x could be rebranded as test_val or some such and specifically documented.
Hah I completely forgot that I already added that... and it's still there:
https://github.com/FluxML/NNlib.jl/blob/0c8396e2f2707d4c223fb45348897eae28b62e2e/src/softmax.jl#L90-L91
It could just grow to within_grad(x...) or something? Maybe Tracker would like the use of it in ∇softmax, to check the array.
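For reference, the mechanism there is the usual trick of a function whose rrule lies about the primal value - roughly this (a paraphrase, not the exact lines at the link):
within_gradient(x) = false
ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())
Any ChainRules-aware AD replaces the call with the rrule's forward value true, while plain evaluation returns false.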
Doesn't is_differentiating make sense from the user's perspective? Like:
function myf(x, y)
if is_differentiating(myf, x, y)
# do something special
else
# normal
end
end
I think we want the convention to pass everything that would get passed to rrule. E.g. x could be a TrackedArray but not y. And in English, "are we differentiating myf w.r.t. x or y?"
Can you explain more about what you see this version, which takes a function, doing? Does it call the function, or infer something about it (like whether it returns a non-diff type or a TrackedArray or something)? Or would you overload this for some functions, to have non-standard behaviour? Or would is_differentiating(myf, x, y) == is_differentiating(identity, x, y) always?
Presumably, no, it does not make sense to overload specific functions. If you want that behavior, then why even have is_X as a check at all in the definition of f? But I was thinking from the Tracker perspective: you want to test if any of the inputs are tracked. I only included the function for consistency.
I guess I don't see what's gained by passing it more arguments. You can check any(within_gradient, (x, y, z)), or it could be spelled within_grad(x...). You may only care about whether a particular argument is tracked, or know that it's sufficient to test one.
If I want the function to work regardless of AD system, then I need to check all (assuming I care about all of them). You could certainly write any(within_gradient, (x, y, z)).
Either way, my point was to bring up two things:
1. is_deriving(x)/is_differentiating(x) makes more sense in English to me, because you care about whether you are currently taking a derivative with respect to the arguments passed. within_gradient seems more appropriate for something that never takes any arguments.
2. Given (1), calling any(is_deriving, (x, y, z)) seems like the most common use case for the function (versus caring about only a specific argument). For some AD systems these are equivalent, but for others not. So shortening this common case to is_deriving(x, y, z) seemed worth it.
Only a small comment, no need to follow it if we don't think it is worth it.
One reason not to write the splat version is that dispatch doesn't work very well for that. You need a lot of methods to allow for the 3rd argument being the only TrackedArray, without ambiguity.
One reason to dislike any is that Zygote knows it returns Bool, and hence stays out:
julia> gradient(x -> within_gradient(x) * x, 2.0) # using this PR, works fine
(1.0,)
julia> gradient(x -> any(within_gradient, (x,x)) * x, 2.0)
(0.0,)
julia> gradient(x -> (within_gradient(x) || within_gradient(x)) * x, 2.0) # could change
(1.0,)
Here ∇softmax_data only cares about one argument.
In something like https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L109-L110 ... "training" probably means the layer's parameters are tracked, ideally, but not whether the data is.
I would say is_deriving(x)/is_differentiating(x) is kind of weird for non-Tracker AD. It sounds like you are checking whether the pullback gets a NoTangent, i.e. whether the argument is treated as non-differentiable. Actually, that means this function would have different semantics for different ADs.
Yes. The idea seems very clear for Tracker. But for non-Tracker AD it seems a bit fragile really.
What a Flux layer wants to know is whether you are training. But what it checks is whether the forward pass calls the rrule. There seems to be no reason a smarter AD couldn't notice that the condition in if is_training() can never get a gradient, and thus that it need not run the rrule.
In fact it looks like Yota is smart enough to do that:
julia> Zygote.withgradient(x -> within_gradient(x) ? x^2 : x, 2.0)
(val = 4.0, grad = (4.0,))
julia> Diffractor.gradient(x -> within_gradient(x) ? x^2 : x, 2.0)
(4.0,)
julia> Yota.grad(x -> within_gradient(x) ? x^2 : x, 2.0)
(2.0, (ZeroTangent(), 1))
In fact it looks like Yota is smart enough to do that:
julia> Yota.grad(x -> within_gradient(x) ? x^2 : x, 2.0)
(2.0, (ZeroTangent(), 1))
I am a little confused. For me, it seems we are actually differentiating w.r.t. x, so I would expect within_gradient to return true?
Paging @dfdx for a real answer, but I think it's assuming that the forward pass of rrule should always agree with the original function. Then in an expression cond ? a : b there is never a need to trace into cond and replace functions with rrules, because their pullbacks will never get any nonzero input.
but I think it's assuming that the forward pass of rrule should always agree with the original function
Correct. In Yota, rrule(f, args...)[1] giving a different result than f(args...) is kind of undefined behavior - in the latest version it works as in your example, but in the previous one it would most likely take the different branch. And since Yota currently doesn't support control flow, the other branch will never show up.
One way to hide these details from Yota (and, perhaps, from all CR-based AD systems) is to add the training/differentiating flag to rrules themselves. For example:
batch_norm(...; training) = ...
rrule(batch_norm, ...; training) = ...
Since AD usually doesn't analyze the contents of the rrule itself, it should be safe to put any branching logic into it. Though, I'm not sure it won't invalidate any of the compiler optimizations in Zygote and Diffractor.
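A concrete toy version of that sketch, with made-up bodies just to show the keyword reaching both the primal and the rule (this is not the real batch_norm):
using ChainRulesCore
toy_norm(x; training::Bool = false) = training ? 2 .* x : x
function ChainRulesCore.rrule(::typeof(toy_norm), x; training::Bool = false)
    y = toy_norm(x; training)
    toy_norm_pullback(dy) = (NoTangent(), training ? 2 .* dy : dy)
    return y, toy_norm_pullback
end
The open question, raised just below, is how that keyword would ever get set from deep inside a model.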
Thanks!
The hope with this rule is to magically infer whether a given evaluation of BatchNorm is during training or inference, even though the code for it is deep inside some model. I don't think there's an obvious way to pass the rrule a flag.
My attempt to invent a better magic rule still seems to be defeated; I guess Yota traces the function and remembers which branch was taken, before replacing the function with the rrule.
istracked_and_val(xs...) = (any(istracked, xs), xs...)
istracked(x) = false # overload this in Tracker
ChainRulesCore.rrule(::typeof(istracked_and_val), xs...) = (true, xs...), Tuple
julia> function dropout(x1::AbstractArray, p::Real = 0.5)
flag, x2 = istracked_and_val(x1)
if flag # then we are inside a gradient call, or x isa TrackedArray
@. ifelse(rand()>p, x2, zero(x2))
else
x2
end
end;
julia> sum(dropout(ones(10)))
10.0
julia> grad(x -> sum(dropout(x)), ones(10))
(10.0, (ZeroTangent(), [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]))
So what's left for Flux? train! could just call testmode!(m, false) on the model first. Or maybe CRC should own some special token which AD has to know to look for & replace?
Re branches in Yota: yes, in the latest version tracing and rrule-transformation are two different steps, which turned out to be a much, much more robust implementation than a unified single-step pass. Thus, the tracer sees only the primal definition of istracked_and_val and follows the corresponding branch.
But you don't have to trick Yota - if the consensus is to use istracked(), I can add it directly to the library.
As a general thought though, if we wish to inject some special calls into functions to account for control flow, then to me the most logical step would be to add control flow constructs themselves, e.g. like JAX's cond. I guess it would add flexibility not only to the training/testing issue, but to a number of other use cases too.
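Purely as a hypothetical sketch of what such a construct could look like in ChainRules terms (none of these names exist anywhere; the rule differentiates only the branch actually taken):
using ChainRulesCore
# cond_call(pred, iftrue, iffalse, x): evaluate one of two branch functions on x.
cond_call(pred::Bool, iftrue, iffalse, x) = pred ? iftrue(x) : iffalse(x)
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(cond_call), pred, iftrue, iffalse, x)
    y, back = rrule_via_ad(cfg, pred ? iftrue : iffalse, x)  # only the chosen branch is traced
    function cond_call_pullback(dy)
        df, dx = back(dy)
        return (NoTangent(), NoTangent(), pred ? df : NoTangent(), pred ? NoTangent() : df, dx)
    end
    return y, cond_call_pullback
end
A layer would then write cond_call(training, train_branch, test_branch, x) instead of an if, keeping the condition itself out of anything AD traces.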
Even with something like cond, presumably it would have to be hosted in some common package so all ADs and downstream libraries can use it without taking on unwanted deps? Or do you see a way for ADs to pick up on it without explicitly checking for/defining rules for it?
Yes, definitely, there's no way to avoid a common library. My point is not to avoid the dependency, but to minimize the API that AD systems must support. Having conditionals as a part of the API is quite common and seems to help in this case too (hopefully not only for Yota :D). istracked() seems to be a new concept in the ML space, so it's harder for me to analyze its corner cases. (Which doesn't imply we shouldn't explore it!)
It seems quite tricky. I think @non_differentiable any at present stops Zygote from looking inside any(within_gradient, (x,x)) at all. But if there's a magic token it has to obey, then it has to keep looking... when does it stop, only when it hits ccall?
Another option would be to demand that every AD system set some task_local_storage, a bit like how @allowscalar works:
https://github.com/JuliaGPU/GPUArrays.jl/blob/master/lib/GPUArraysCore/src/GPUArraysCore.jl#L40-L112
I believe that's a global namespace, so different packages could actually all set task_local_storage(:InsideReverseModeGrad, true) without depending on a common package.
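A rough sketch of that idea, using the key name from above (no AD actually does this today):
# What a library like NNlib would query:
inside_reverse_grad() = get(task_local_storage(), :InsideReverseModeGrad, false)
# What each AD's gradient entry point would wrap around its forward pass:
with_grad_flag(f) = task_local_storage(f, :InsideReverseModeGrad, true)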
RE task-local state, my understanding is that it doesn't propagate to child tasks and could thus be a problem for differentiating async code. That, and the risk of the AD not cleaning task-local state up correctly on errors, leading to false positives for code outside of AD.
istracked() seems to be a new concept in the ML space, so it's harder for me to analyze its corner cases. (Which doesn't imply we shouldn't explore it!)
I was under the impression that something like cond was newer? When digging for https://github.com/FluxML/NNlib.jl/pull/434#issuecomment-1232053472, it seems like PyTorch has had an "are we in AD" function for quite some time to help with choosing derivative routines.
This issue of branches within AD may also deserve some thought in relation to https://github.com/JuliaDiff/ForwardDiff.jl/issues/480. That's at last fixed in ForwardDiff, where it seemed to cause the most problems since e.g. you push numbers through det rather than having a rule for it. But if you do have a branch based on values inside AD, the other systems show the same bad behaviour:
julia> gradient(prod1, x)[1] # correct
1×5 Matrix{Float32}:
0.0 0.0 24.0 0.0 0.0
julia> Diffractor.gradient(prod2, x)[1] # wrong
1×5 Matrix{Float32}:
0.0 0.0 2.0 0.0 0.0
# Same for Tracker, Zygote, Enzyme, but no longer for ForwardDiff.
julia> ChainRulesCore.@non_differentiable eltype(::Any)
julia> Yota.grad(prod2, x) # worse?
(0.0f0, (ZeroTangent(), NoTangent()))
I was under the impression that something like cond was newer? When digging for https://github.com/FluxML/NNlib.jl/pull/434#issuecomment-1232053472, it seems like PyTorch has had an "are we in AD" function for quite some time to help with choosing derivative routines.
A symbolic condition operator existed in TensorFlow 1 (created in 2015), Theano (created in 2008), and perhaps even earlier. I think my cognitive resistance to accepting the equivalence of istracked() and PyTorch's GradMode::is_enabled() comes from the difference between the symbolic (Theano, TensorFlow, JAX) and overload-based (PyTorch) families of ADs. As far as I can say, Zygote, Diffractor and Yota analyze code as a symbolic graph, so using tools from such systems (e.g. cond) instead of the differentiation flag looks more natural to me. But maybe that's just my background - I'm pretty curious to see how istracked() works in this context anyway.
It seems quite tricky. I think @non_differentiable any at present stops Zygote from looking inside any(within_gradient, (x,x)) at all.
This particular example doesn't sound like a problem. Given a function:
function foo(x)
if any(within_gradient, (x,x))
return f(...)
else
return g(...)
end
end
Zygote should be able to analyze its IR and replace all calls to within_gradient(...) with just true (or equivalent). Of course, it's possible to nest within_gradient() in another auxiliary function, e.g.:
bar(x) = any(within_gradient, (x,x))
function foo(x)
if bar(x)
return f(...)
else
return g(...)
end
end
But for me it looks fair enough to just forbid such usage.
julia> Yota.grad(prod2, x) # worse? (0.0f0, (ZeroTangent(), NoTangent()))
Thanks for the report! I've created an issue to track it.
Thanks for the correction, I misremembered those operations as behaving more like ifelse than actually being able to lazily evaluate branch subgraphs.
This particular example doesn't sound like a problem...
Currently that example will hit https://github.com/JuliaDiff/ChainRules.jl/blob/v1.44.5/src/rulesets/Base/nondiff.jl#L99, so AD never gets to generate a pullback for within_grad. Zygote could in theory do enough constant prop + folding to realize that any(within_gradient, ...) == true, but in practice I don't think anyone is interested in writing such a pass (using IRTools, because we're still stuck with that) that matches Julia's built-in behaviour from scratch!
So a cond-like construct could be the path of least resistance. It does feel somewhat unidiomatic to write out each branch as a function instead of using a conditional, but at least for Zygote there is value in hiding said conditional where AD can't see it. My main concern would be whether adding an extra 2+ functions to compile for each branch like this negatively impacts latency, but that can be easily tested once a prototype exists.
Shall we touch this up & merge it? It's not perfect, but nor is what we have right now.
Let's do this.
@dfdx, do you mind fleshing that cond idea out a bit more and linking it back here once done? I wouldn't want to have it lost after this PR is done. Location of the writeup doesn't matter, but if you're looking for one I opened https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/66 a while back.
We can add this here, but ChainRulesCore seems a better place.
One reason against ChainRules is that we'd like this to work for ForwardDiff (https://github.com/FluxML/Flux.jl/issues/2122) and Tracker (https://github.com/FluxML/Tracker.jl/issues/133) - that's really the point of moving it from Flux - and neither depends on CR.
Today's 744309a adds ForwardDiff via Requires.
Another is that this is a bit of a dodgy mechanism: it's what Flux uses, but it really needs tests for each use to make sure AD hasn't out-smarted us.
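To make "tests for each use" concrete, the kind of check each downstream AD would need is something like this (assuming the ForwardDiff extension works by detecting Dual numbers, which is a guess on my part):
using Test, NNlib, Zygote, ForwardDiff
@test NNlib.within_gradient(1.0) == false
@test Zygote.gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)       # as in the PR example above
@test ForwardDiff.derivative(x -> NNlib.within_gradient(x) * x, 2.0) == 1.0   # true * x has slope 1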