Tracker.jl
Nested derivatives do not work for functions with arrays as arguments
See the following code. Pseudocode + explanations:
f = x1 * x2
grad_f = (x2, x1)
fg = sum(grad_f) = x2+x1
grad_fg = [1, 1]
Actual code:
using Tracker
x = param([1, 2]) # Tracked 2-element Vector{Float64}
f(x) = prod(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)
The value of gg is [-2.0, -0.5] instead of [1, 1].
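As a cross-check (a sketch assuming ForwardDiff.jl is available; it is not a dependency of Tracker), forward-mode AD gives the expected nested gradient:
using ForwardDiff
f(x) = prod(x)
fg(x) = sum(ForwardDiff.gradient(f, x))
ForwardDiff.gradient(fg, [1.0, 2.0]) # [1.0, 1.0]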
Another related example, using sum instead of prod:
Pseudocode + explanations:
f = x1 + x2
grad_f = (1, 1)
fg = sum(grad_f) = 2
grad_fg = [0, 0]
Actual code:
using Tracker
x = param([1, 2]) # Tracked 2-element Vector{Float64}
f(x) = sum(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)
This code gives the following: ERROR: MethodError: no method matching back!(::Float64)
As side information, the Hessian of the prod function works in AutoGrad.jl:
using AutoGrad
x = Param([1,2,3]) # user declares parameters
p(x) = prod(x)
hess(f,i=1) = grad((x...)->grad(f)(x...)[i])
hess(p, 1)(x)
hess(p, 2)(x)
hess(p, 3)(x)
Returns the correct result:
3-element Vector{Float64}:
0.0
3.0
2.0
3-element Vector{Float64}:
3.0
0.0
1.0
3-element Vector{Float64}:
2.0
1.0
0.0
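These values match the analytic Hessian of prod: for i ≠ j, H[i,j] = prod(x)/(x[i]*x[j]), and the diagonal is zero since prod is linear in each variable. A quick check in plain Julia (no AD):
x = [1, 2, 3]
H = [i == j ? 0.0 : prod(x)/(x[i]*x[j]) for i in 1:3, j in 1:3]
# [0.0 3.0 2.0; 3.0 0.0 1.0; 2.0 1.0 0.0]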
Some really good news: if I use two separate params the Hessian works. Next is to adjust this to work for arrays.
using Tracker
x1 = param(1)
x2 = param(2)
f(x1, x2) = x1*x2
∇f(x1, x2) = gradient(f, x1, x2, nest=true)
H11 = gradient(x1->∇f(x1, x2)[1], x1)[1]
H12 = gradient(x2->∇f(x1, x2)[1], x2)[1]
H21 = gradient(x1->∇f(x1, x2)[2], x1)[1]
H22 = gradient(x2->∇f(x1, x2)[2], x2)[1]
H = [H11 H12; H21 H22]
Result:
2×2 Matrix{Tracker.TrackedReal{Float64}}:
0.0 1.0
1.0 0.0
Higher-order derivatives of functions with a single scalar argument work correctly:
using Tracker
f(x) = sin(cos(x^2))
df(x) = gradient(f, x, nest=true)[1]
d2f(x) = gradient(u->df(u), x, nest=true)[1]
d3f(x) = gradient(u->d2f(u), x, nest=true)[1]
x0 = param(1)
df(x0) # -1.4432122981268867 (tracked)
d2f(x0) # -4.7534826540186135 (tracked)
d3f(x0) # -5.683220233612525 (tracked)
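The first derivative can be checked by hand via the chain rule, d/dx sin(cos(x^2)) = cos(cos(x^2)) * (-sin(x^2)) * 2x:
x0 = 1.0
cos(cos(x0^2)) * (-sin(x0^2)) * 2x0 # -1.4432122981268867, matching df above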
I tried to replicate the two-separate-params approach above with the following code, but it yields incorrect results:
using Tracker
f(x) = prod(x)
∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x)
hess1([1, 2]) # ([-2.0, 0.0] (tracked),) -- expected ([0.0, 1.0],)
Getting closer. I have manually created the corresponding graph for the gradient and the Hessian. The rrules used are correct. There is something wrong with the Tracker algorithm in the "nested" gradients: the extra graph recorded for the derivatives is perhaps not created correctly.
using Tracker
x = param([1, 2])
p = x->prod(x)
q = x->p(x)./x # the hard-coded gradient of prod
jacobian(q, x) # jacobian of the gradient = hessian
Gives the correct result:
Tracked 2×2 Matrix{Float64}:
0.0 1.0
1.0 0.0
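For x = [1, 2] we have q(x) = [x[2], x[1]], so the Jacobian above is exactly right; a finite-difference sketch (plain Julia, no AD) confirms it:
q(x) = prod(x) ./ x
x0 = [1.0, 2.0]; h = 1e-6
J = hcat([(q(x0 .+ h .* (1:2 .== j)) .- q(x0)) ./ h for j in 1:2]...)
# J ≈ [0 1; 1 0]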
I have created a function to print the graph under a specific node: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L140
One can notice that Tracker does not record the methods performed, but rather the pullbacks of these methods. This is because we only need the pullbacks when we back-propagate through the graph.
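A minimal sketch of the idea (field names here are illustrative, not Tracker's actual internals): each tracked value carries the pullback closure returned by the rule plus references to its parent nodes, and back-propagation just walks these closures.
struct SketchNode
    data::Any      # primal value
    pullback::Any  # closure: output cotangent -> input cotangents
    parents::Tuple # upstream nodes of the graph
end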
using Tracker
f(x) = prod(x)
∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x, nest=true)
h1 = hess1([3, 4])
Tracker.print_graph(stdout, h1[1])
# Prints the following
TrackedData
-data=[-1.3333333333333333, 0.0]
-Tracker=
--isleaf=false
--grad=UndefInitializer()
--Call=
---f
----back
-----func.f=+
-----func.args=([0.0, 0.0], [-1.3333333333333333, -0.0] (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=UndefInitializer()
-----Call=
------f
-------back
--------func.f=partial
--------func.args=(Base.RefValue{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}}(Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}(Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}(Base.Broadcast.var"#11#13"()), Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}(Base.Broadcast.var"#17#19"())), Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}(Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}(Base.Broadcast.var"#27#28"())), Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}(Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}(Base.Broadcast.var"#23#24"())), /), *)), [1.0, 0.0] (tracked), 2, 12.0, [3.0, 4.0] (tracked), 1)
------args
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), [1.0, 0.0] (tracked), 2, [0.0, 0.0], [4.0, 3.0] (tracked))
---------args
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------#714
--------------func.f=getindex
------------args
-------------nothing
-------------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------back
--------------func.f=#12
--------------func.args=(12.0, [3.0, 4.0] (tracked), 1)
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=UndefInitializer()
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=true
-----------------grad=[0.0, 0.0]
-----------------Call=
------------------f
------------------args
-------------nothing
-------nothing
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------#718
---------args
----------Tracker=
-----------isleaf=true
-----------grad=[0.0, 0.0]
-----------Call=
------------f
------------args
-------nothing
Thanks for digging! As you can see this package doesn't get a lot of attention, but fixes are very welcome.
I do not know what's going wrong in these cases; I never thought much about how this package handles second derivatives. There is a way to mark rules as only suitable for first derivatives, which the prod rule does not use.
I do suspect some broadcasting rule does not work correctly. If you look at the graph above, there is a long chain of broadcasts, which looks weird.
This is the graph of the example which works correctly. Looks clean.
using Tracker
x1 = param(1)
x2 = param(2)
f_(x1, x2) = x1*x2
∇f_(x1, x2) = gradient(f_, x1, x2, nest=true)
H12 = gradient(x2->∇f_(x1, x2)[1], x2, nest=true)[1]
Tracker.print_graph(stdout, H12)
## Prints the following:
TrackedData
-data=1.0
-Tracker=
--isleaf=false
--grad=0.0
--Call=
---f
----back
-----func.f=+
-----func.args=(0.0, 1.0 (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=0.0
-----Call=
------f
-------#229
--------func.b=1
------args
-------Tracker=
--------isleaf=false
--------grad=0.0
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), 1, 2, 0.0, 2.0 (tracked))
---------args
----------nothing
----------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=0.0
-----------Call=
------------f
-------------#230
--------------func.a=1
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=0.0
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=false
-----------------grad=0.0
-----------------Call=
------------------f
-------------------#718
------------------args
-------------------Tracker=
--------------------isleaf=true
--------------------grad=1.0
--------------------Call=
---------------------f
---------------------args
-------nothing
I have reviewed the back-propagation algorithm for both simple and nested gradients and everything seems correct. So it is probably a good idea to switch to ChainRules now, as those rules are more robust. Another internal design decision is whether to record the original function or the pullback function in the graph. The algorithm will be mostly the same in both options. Currently the pullback is stored, but this makes printing and visualizing the graph a bit more difficult (and so is debugging). I also don't know if storing the pullback has any advantage over storing the original function, so I would be tempted to store the original function instead and invoke the rrule of this original function only during back-propagation through the graph.
To summarize, 3 actions are needed to improve the robustness and ease of development of this package:
- switch to ChainRules
- in the graph, store the original function instead of the pullback
- document the code and the public functions, generate docs
Calling the rrule only during the backwards pass would probably not work great, because it'd end up with us recomputing the primal. Did you perhaps mean calling the pullback function rrule returns? I know Zygote and Tracker have not been great at giving those names, but ChainRules is pretty good about doing so. If I call rrule(f, args...), I'll probably get a (primal, pullback::typeof(f_pullback)) back.
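For instance (standard ChainRules usage):
using ChainRules: rrule
y, sin_pullback = rrule(sin, 1.0) # y == sin(1.0)
sin_pullback(1.0)                 # (NoTangent(), 0.5403023058681398), i.e. cos(1.0)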
@ToucheSir you are right, we have to store the pullback, not the original function. This is also clear now to me: https://discourse.julialang.org/t/second-order-derivatives-with-chainrules/103606
Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91. Second derivatives still do not work, because not all operations are tracked and sent to ChainRules; I just added +(x, y) for instance: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/lib/array.jl#L455.
Next is to:
- continue the integration of ChainRules: track some missing operations (like x + y = track(+, x, y), where x and y are TrackedArrays, etc.)
- make sure all first order derivatives work
- fix second order derivatives
- remove obsolete rules defined by Tracker
- document the code & generate documentation
The following first order derivative works via ChainRules:
using Tracker
f2(x, y) = prod(x+y)
dx = gradient(f2, [3, 4], [7, 9], nest=true)
Prints:
[ Info: Chainrules for +
[ Info: Chainrules for prod
([13.0, 10.0], [13.0, 10.0])
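The values agree with the analytic gradient, since the gradient of prod(x+y) with respect to x is prod(x+y) ./ (x+y):
s = [3, 4] + [7, 9] # [10, 13]
prod(s) ./ s        # [13.0, 10.0], matching the output above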
Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91
Not that easy, son. It is easy only for first order derivatives. I have to further track the operations performed by these first order derivatives, so that second order derivatives can be computed.
I'm currently stuck, as I don't know how to deal with the ChainRules rrule(s) here.
The logic in Tracker is to define an untracked forward pass and a tracked pullback, and the AD engine will do the rest:
- f(x::TrackedReal) = f(data(x)) # untracked forward pass, data(x) is a Float64
- back(d) = d*∂f(x) # tracked pullback, x is a TrackedReal
What ChainRules offers:
y, back = rrule(sin, x::TrackedReal)
This would yield sin(x) (tracked) and back also tracked. What we want to get is sin(data(x)) (not tracked) and back tracked. So basically I want to be able to modify the forward pass.
With DiffRules this is easily achievable, as we can define our forward pass and pullback separately, e.g.
sin(x) = sin(data(x)) # untracked
DiffRules.diffrule(sin, x) = cos(x) # tracked, implicitly provided by DiffRules
@ToucheSir, @mcabbott Any idea? Or ping somebody who can help?
This would yield sin(x) (tracked) and back also tracked. What we want to get is sin(data(x)) (not tracked) and back tracked.
I'm not sure I understand: why shouldn't it be sin(x) (tracked)? If the primal output is not tracked, AD will just stop working for subsequent operations.
The tracking of sin(x) is done by the Tracker.jl logic, not by the rrule. For sin, this is what happens:
sin(xs::TrackedReal) = track(sin, xs)
The definition of track for sin works like this:
function track(sin, xs::TrackedReal)
    # y is not tracked, as sin is applied only to data(xs);
    # the pullback cos(xs) is tracked, as cos is applied to a TrackedReal --
    # a track call fires for it, just like for sin
    y, back = sin(data(xs)), cos(xs) # cos(xs) is the pullback provided by DiffRules
    # create another TrackedReal whose data is y, recording the pullback
    # and the previous node in the graph (tracker.(xs))
    track(Call(back, tracker.(xs)), y) # the tracking of the primal y happens here
end
I have checked the AutoGrad.jl engine and also the HIPS/autograd engine. These engines call the differentiation rules (rrules) in the backward pass. So does Tracker.jl.
The right thing to do here is to thunk the primal computation for all rrules in ChainRules, or any other implementation which avoids the primal computation. Yes, the primal is needed most of the time with rrules, but not always!
I think we can currently achieve the following:
- If only first order derivatives are needed, we can simply use rrules during the forward pass and register the pullback for the backward pass. This way the primal will be computed only once.
- If higher order derivatives are needed, we will have to run the primal on the data, while calling rrules on TrackedReal etc. in the backward pass. This means the primal function will be called twice.
Sample code for item 2:
using ChainRules
using ChainRules: rrule
using ChainRulesCore
import Base: +, *
struct _Tracked <: Real
    data::Float64
    f::Any
    _Tracked(data, f) = new(data, f)
    _Tracked(data) = new(data, nothing)
end
function track(f, a, b)
    data, delayed = _forward(f, a, b)
    return _Tracked(data, delayed)
end
function _forward(f, a, b)
    data = f(a.data, b.data) # primal, no tracking
    delayed = ()->rrule(f, a, b) # delayed pullback, with tracking
    return data, delayed
end
a::_Tracked + b::_Tracked = track(+, a, b)
a::_Tracked * b::_Tracked = track(*, a, b)
##
a = _Tracked(10)
b = _Tracked(20)
tr = a*b
back2 = tr.f()[2]
da, db = back2(_Tracked(1.0))[2:end]
# result: da = _Tracked(20.0, ...), db = _Tracked(10.0, ...)
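Note that in this sketch the primal is computed once, untracked, inside _forward, while rrule later runs on the tracked arguments and therefore evaluates the primal a second time; this is exactly the cost mentioned in item 2 above.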
Status:
Test Summary: | Pass Fail Error Broken Total Time
Tracker | 357 1 15 3 376 22.0s
gradtests 1 | 16 16 0.6s
gradtests 1.1 | 10 10 0.4s
indexing & slicing | 1 1 0.1s
concat | 181 181 2.1s
getindex (Nabla.jl - #139) | 1 1 0.0s
gradtests 2 | 22 2 24 0.9s
mean | 5 5 0.1s
maximum | 5 5 0.1s
minimum | 5 5 0.1s
gradtests 3 | 9 4 13 0.5s
transpose | 48 48 0.0s
conv, 1d | 2 1 3 2.3s
conv, 2d | 2 1 3 1.1s
conv, 3d | 2 1 3 8.0s
pooling | 4 4 2.0s
equality & order | 16 16 0.0s
reshape | 8 8 0.0s
Intermediates | 2 2 0.0s
Fallbacks | 1 1 0.0s
collect | 1 1 0.1s
Hooks | 1 1 0.1s
Checkpointing | 2 1 1 4 0.2s
Updates | 2 2 0.0s
Params | 1 1 0.0s
Forward | 2 2 0.2s
Custom Sensitivities | 1 1 0.1s
PDMats | 1 1 0.1s
broadcast | 2 1 3 0.0s
logabsgamma | 2 2 0.1s
Jacobian | 1 1 0.8s
withgradient | 3 3 1.3s
NNlib.within_gradient | 2 2 0.0s
ERROR: Some tests did not pass: 357 passed, 1 failed, 15 errored, 3 broken