Enzyme.jl
Active integer differentiation
I think NamedTuples are allowed as Actives, but currently, we get (may be a known problem):
julia> autodiff(x -> x.a * x.b, Active((a = 2, b = 3)))
((a = 0, b = 0),)
I think the issue here is that the type is integer rather than floating point. What happens if you do 2.0 and 3.0?
The justification for integers being inactive is that integers are essentially rounded values with derivative zero everywhere except at the rounding point (which can be thought of as a delta function).
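To make that concrete, here's a minimal sketch (plain Julia, no Enzyme; the function f and the point x = 2.3 are just hypothetical choices) of why a function that rounds its input has zero slope almost everywhere:

```julia
# round(x) is a step function: its slope is zero everywhere except at the
# half-integer jump points, so a derivative taken away from a jump is 0.
f(x) = round(x) * 3.0

h = 1e-6
x = 2.3                                  # away from the jump at 2.5
fd = (f(x + h) - f(x - h)) / (2h)        # central finite difference -> 0.0
```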
Ah, indeed:
julia> autodiff(x -> x.a * x.b, Active((a = 2.0, b = 3.0)))
((a = 0, b = 0),)
Ah, I didn't realize that Active converts scalar integers to floating point (but doesn't do this recursively for NamedTuples, of course). Kind of a trap for the unaware user, though. :-)
I guess Enzyme can't do gradients for integers due to the way it works?
May be non-trivial to implement, but what if Active could throw an exception when passed something with integers in it?
Can easily happen that a user supplies an integer as part of a tuple of fit parameters or so, and generic code between that and Enzyme may think this perfectly fine and not complain about it. And the silent zero-gradient may stay undetected in a larger application ... and when noticed, the reason for it may be hard to track down.
Added some information regarding integer values to #72.
Also wait @oschulz, your example with floats isn't what I'd expect (see my use below). What version are you using?
julia> autodiff(x -> x.a * x.b, Active((a = 2.0, b = 3.0)))
((a = 3.0, b = 2.0),)
The other reason integers are inactive by default is that most of the time you really don't want to differentiate with respect to them - e.g. when they are the size of an array, etc. That said, using an integer as part of an active computation (e.g. multiplying by an integer) should work perfectly fine (and propagate that), so I'm curious what sorts of larger cases you're thinking about?
your example with floats isn't what I'd expect (see my use below). What version are you using?
That was with Enzyme.jl v0.6.0.
integers are inactive by default, for example, is that most of the time you really don't want to differentiate with respect to them. E.g. they are the size of an array, etc.
It's kinda different to what users are used to in ForwardDiff, Zygote, etc., though. Since Enzyme has a very nice mechanism to control what's active and what's not, the user can keep things like array sizes out of differentiation, right?
so I'm curious what sorts of larger cases you're thinking about
Let's say a user wants to run a larger optimization problem, and part of the parameter set is a NamedTuple with some values. The user writes down some starting value (e.g. just (a = 0, b = 0, ...)) - if the code is written in a generic fashion, those integers may propagate through to Enzyme. And if things are complex enough (involved calculations, many parameters, etc.), the cause may be hard to track down when debugging. And ForwardDiff and Zygote (if compatible with the code) would return the correct gradient.
Then there's the section "One must be allowed to take derivatives of integer arguments" in the ChainRules docs (CC @oxinabox), so Julia users are used to differentiating with respect to integers. And since Enzyme.Active does silently convert plain integers to floats, Enzyme seems (from a user point of view) to do just that in simple cases, even though it really doesn't.
I don't mean to criticize (Enzyme is awesome!) and I do understand that Enzyme isn't Julia-only. I'm just a bit worried that this may become a "trap" for Julia users in some cases.
Then there's the section "One must be allowed to take derivatives of integer arguments" in the ChainRules docs (CC @oxinabox), so Julia users are used to differentiating with respect to integers. And since Enzyme.Active does silently convert plain integers to floats, Enzyme seems (from a user point of view) to do just that in simple cases, even though it really doesn't.
Finally someone agrees with me :P
At some point I need to write a long post about embedded subspaces, types that represent the space, vs "computation conveniences".
Diffing with respect to integers is nonsense except:
When it is a computational convenience.
i.e. this value could have been a float, but we converted it to an integer because it happened to be an integer right now and our algorithm is much faster if the input is an integer.
This happens a lot for things like structured matrices, but I can imagine doing it for exp2:
julia> fast_exp2(x) = exp2(isinteger(x) ? Int(x) : x)
fast_exp2 (generic function with 1 method)
julia> @benchmark exp2(k) setup=(k=rand((1.0, 2.0, 3.0, 4.0)))
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 3.638 ns (0.00% GC)
median time: 3.870 ns (0.00% GC)
mean time: 4.075 ns (0.00% GC)
maximum time: 14.902 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
julia> @benchmark fast_exp2(k) setup=(k=rand((1.0, 2.0, 3.0, 4.0)))
BenchmarkTools.Trial:
memory estimate: 0 bytes
allocs estimate: 0
--------------
minimum time: 2.157 ns (0.00% GC)
median time: 2.723 ns (0.00% GC)
mean time: 2.708 ns (0.00% GC)
maximum time: 19.106 ns (0.00% GC)
--------------
samples: 10000
evals/sample: 1000
When it actually represents a continuous quantity
E.g. rather than a floating point representation, one can have some fixed-point representation, which is going to be backed by an Integer representing some minimum increment. It's just as valid as floating point, and has some advantages.
We have fixed point libraries, e.g. FixedPointDecimals.jl.
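A minimal hand-rolled sketch of the idea (not FixedPointDecimals.jl's actual API; Fixed and value are made-up names): the stored field is an Int, but it represents a continuous quantity, and arithmetic on the backing Int is arithmetic on that quantity:

```julia
struct Fixed
    raw::Int                        # count of 1/100 increments
end

value(x::Fixed) = x.raw / 100       # the continuous quantity it represents
Base.:+(a::Fixed, b::Fixed) = Fixed(a.raw + b.raw)

a = Fixed(250)                      # represents 2.5
b = Fixed(125)                      # represents 1.25
value(a + b)                        # 3.75, via integer addition on raw
```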
An example in the stdlibs is that DateTimes are backed by a count of ticks since an epoch. Adding a period to a DateTime results in calling + on that backing count.
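This can be seen with the Dates stdlib: Dates.value exposes the backing count (milliseconds for a DateTime), and adding a period shifts exactly that count:

```julia
using Dates

t = DateTime(2021, 1, 1)
raw = Dates.value(t)                 # the backing Int64 millisecond count
t2 = t + Hour(1)
Dates.value(t2) - raw                # 3_600_000 ms: + acted on the count
```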
I am not sure where decimal floating point (like Decimals.jl) and Rationals fit in there; but both are backed by integers. Can one actually AD them by the operations then called on their constituent parts? I suspect you can (excluding projecting back to the fixed point quantity).
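For Rationals at least, the backing integers are directly visible as fields, so the "operations on constituent parts" view is easy to see:

```julia
r = 2 // 6
(r.num, r.den)        # (1, 3): stored in lowest terms as two Ints
r + 1 // 3            # 2//3, computed via Int arithmetic on num/den
```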
In case my tired "create rebuttals for paper reviews" haze caused the message to come across wrong: I do agree Enzyme should handle differentiation of integer quantities -- though it doesn't do so right now.
Work for differentiating integers is in progress (and essentially was stalled for a bit as I began tackling some more pressing issues like GPU support, garbage collection, etc.): https://github.com/wsmoses/Enzyme/issues/158
Currently the way I see the design is to explicitly distinguish between non-differentiable integers (perhaps just called integer types in type analysis) and differentiable integers (perhaps called fixed point in type analysis, though still including regular integers). The reason we've (historically) assumed integers to be non differentiable is to avoid potentially asking questions like what is d/di array[i] (like in the example from Chainrules). By explicitly not implementing a derivative of this and asserting when something like it happened has been incredibly useful for debugging and ensuring the correctness of Enzyme's passes.
The fact that we do want certain integer operations to be differentiated (while still potentially others cannot be) is the reason why I'm currently thinking about this split type approach.
Also this is not at all set in stone (just early design musings) so if you have suggestions/ideas (especially if you have time to try them out), feel free!
As a final comment, the fast_exp2 case actually worries me a lot. Suppose the integer version were implemented as something like exponentiation by squaring:
res = 1
while (i > 0) {
    if (i & 1)
        res *= base;
    base *= base;
    i /= 2;
}
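A runnable Julia version of that sketch (intpow is a hypothetical name): note that every use of i sits in control flow, the loop condition and the bit test, so no data flows from i to the result:

```julia
# Exponentiation by squaring: i only steers branches, never enters the
# arithmetic that produces res, so an i-derivative via data flow is zero.
function intpow(base, i)
    res = one(base)
    while i > 0
        if i & 1 == 1        # bit test: control flow only
            res *= base
        end
        base *= base
        i >>= 1
    end
    res
end

intpow(2.0, 10)   # 1024.0
```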
Even if i were a differentiable quantity, there'd be a zero derivative result since the only dependence on i is via control flow. Moreover, what should the semantic meaning of a differentiable i & 1 be? Activity analysis deduces the (integer) boolean result is only used in control flow and thus is inactive, so should we proceed with it being inactive -- or alternatively, what is the adjoint of that instruction?
Thanks for all the explanations, Lyndon and William! Should I close this or leave it open?