Tracker.jl

Nested derivatives do not work for functions with arrays as argument

Open • MariusDrulea opened this issue 9 months ago • 19 comments

See the following code. First, the pseudocode and the expected result:

f = x1 * x2
grad_f = (x2, x1)
fg = sum(grad_f) = x2+x1
grad_fg = [1, 1]
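In general, the quantity computed here is the gradient of the sum of the gradient. That is the Hessian of f applied to the all-ones vector (the row sums of the Hessian):

```latex
\nabla_x\Big(\sum_i \frac{\partial f}{\partial x_i}\Big) = H_f(x)\,\mathbf{1},
\qquad (H_f)_{ji} = \frac{\partial^2 f}{\partial x_j \,\partial x_i}

f = x_1 x_2:\quad H_f = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix},\quad
H_f\,\mathbf{1} = \begin{pmatrix} 1 \\ 1 \end{pmatrix}
```

So a correct implementation must return [1, 1] here. For f = x1 + x2 the Hessian is zero, and the expected result is [0, 0].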

Actual code:

using Tracker

x = param([1, 2])  # tracked 2-element Vector{Float64}
f(x) = prod(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)

The value of gg is [-2.0, -0.5] instead of the expected [1, 1].

MariusDrulea, Sep 06 '23 13:09

Another related example, using sum instead of prod. Pseudocode and expected result:

f = x1 + x2
grad_f = (1, 1)
fg = sum(grad_f) = 2
grad_fg = [0, 0]

Actual code:

using Tracker

x = param([1, 2])  # tracked 2-element Vector{Float64}
f(x) = sum(x)
g(x) = gradient(f, x, nest=true)
fg(x) = sum(g(x)[1])
gg = gradient(fg, x)

This code fails with the following error:

ERROR: MethodError: no method matching back!(::Float64)
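Both expected results can be checked numerically, without any AD, by nesting central finite differences. A minimal sketch in plain Julia; `fd_gradient` and `fg_of` are hypothetical helpers (not part of Tracker), and the step `h = 1e-4` is an arbitrary choice:

```julia
# Central-difference gradient of a scalar function (hypothetical helper, no AD involved)
function fd_gradient(f, x::Vector{Float64}; h = 1e-4)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h                                   # perturb only coordinate i
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

# fg(x) = sum of the (numerical) gradient of f at x
fg_of(f) = x -> sum(fd_gradient(f, x))

x = [1.0, 2.0]
gg_prod = fd_gradient(fg_of(prod), x)   # ≈ [1.0, 1.0]
gg_sum  = fd_gradient(fg_of(sum), x)    # ≈ [0.0, 0.0]
```

Finite differences only give approximations, but they are accurate enough here to show that [-2.0, -0.5] and the MethodError are both wrong.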

MariusDrulea, Sep 06 '23 14:09

As a side note, the Hessian of prod is computed correctly by AutoGrad.jl:

using AutoGrad

x = Param([1,2,3])		# user declares parameters
p(x) = prod(x)

hess(f,i=1) = grad((x...)->grad(f)(x...)[i])  # row i of the Hessian of f
hess(p, 1)(x)
hess(p, 2)(x)
hess(p, 3)(x)

Returns the correct result:

3-element Vector{Float64}:
 0.0
 3.0
 2.0

3-element Vector{Float64}:
 3.0
 0.0
 1.0

3-element Vector{Float64}:
 2.0
 1.0
 0.0
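For reference, the Hessian of prod has a simple closed form: the off-diagonal entry H[i, j] is the product of all entries except x[i] and x[j], and the diagonal is zero. A small sketch (`prod_hessian` is a hypothetical helper) that reproduces the AutoGrad rows above and, for two inputs, the expected [0 1; 1 0]:

```julia
# Closed-form Hessian of p(x) = prod(x) (hypothetical helper, for cross-checking):
# H[i, j] = product of all x[k] with k not in {i, j} for i != j, and H[i, i] = 0
function prod_hessian(x::AbstractVector{<:Real})
    n = length(x)
    H = zeros(n, n)
    for i in 1:n, j in 1:n
        i == j && continue
        # init = 1.0 handles the empty product when n == 2
        H[i, j] = prod((x[k] for k in 1:n if k != i && k != j); init = 1.0)
    end
    return H
end

prod_hessian([1.0, 2.0, 3.0])   # rows [0 3 2], [3 0 1], [2 1 0], as in the AutoGrad output
prod_hessian([1.0, 2.0])        # [0.0 1.0; 1.0 0.0]
```

Unlike the `p(x) ./ x` trick, this form does not divide by x[i], so it also holds when some entries are zero.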

MariusDrulea, Sep 13 '23 08:09

Some really good news: if I use 2 separate params, the Hessian works. Next is to adapt this to arrays.

using Tracker

x1 = param(1)
x2 = param(2)

f(x1, x2) = x1*x2
∇f(x1, x2) = gradient(f, x1, x2, nest=true)

H11 = gradient(x1->∇f(x1, x2)[1], x1)[1]
H12 = gradient(x2->∇f(x1, x2)[1], x2)[1]
H21 = gradient(x1->∇f(x1, x2)[2], x1)[1]
H22 = gradient(x2->∇f(x1, x2)[2], x2)[1]

H = [H11 H12; H21 H22]

Result:

2×2 Matrix{Tracker.TrackedReal{Float64}}:
 0.0  1.0
 1.0  0.0

MariusDrulea, Sep 13 '23 19:09

Higher-order derivatives of functions with a single scalar argument work correctly:

using Tracker

f(x) = sin(cos(x^2))
df(x) = gradient(f, x, nest=true)[1]
d2f(x) = gradient(u->df(u), x, nest=true)[1]
d3f(x) = gradient(u->d2f(u), x, nest=true)[1]

x0 = param(1)
df(x0) # -1.4432122981268867 (tracked)
d2f(x0) # -4.7534826540186135 (tracked)
d3f(x0) # -5.683220233612525 (tracked)
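The scalar results can be cross-checked against hand-derived derivatives. With u = x^2 and c = cos(u): f'(x) = -2x sin(u) cos(c) and f''(x) = -2 sin(u) cos(c) - 4x^2 cos(u) cos(c) - 4x^2 sin(u)^2 sin(c). A sketch in plain Julia (the `_exact` names are hypothetical):

```julia
# Hand-derived derivatives of f(x) = sin(cos(x^2)) (hypothetical helper names)
f_exact(x) = sin(cos(x^2))
df_exact(x) = -2x * sin(x^2) * cos(cos(x^2))
function d2f_exact(x)
    u = x^2        # innermost argument
    c = cos(u)     # argument of the outer sin
    return -2sin(u) * cos(c) - 4x^2 * cos(u) * cos(c) - 4x^2 * sin(u)^2 * sin(c)
end

df_exact(1.0)    # compare with df(x0) above
d2f_exact(1.0)   # compare with d2f(x0) above
```

At x = 1.0 these agree with the Tracker values df(x0) and d2f(x0) above to well within 1e-6.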

MariusDrulea, Sep 14 '23 17:09

> Some really good news. If I use 2 separate params the hessian works. Next is to adjust this to work for arrays.
> f(x1, x2) = x1*x2
> ∇f(x1, x2) = gradient(f, x1, x2, nest=true)
> H11 = gradient(x1->∇f(x1, x2)[1], x1)[1]

I tried to replicate this approach for arrays with the following code, but it yields incorrect results:

f(x) = prod(x)
∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x)
hess1([1, 2]) # ([-2.0, 0.0] (tracked),), but ∇f(u)[1] = u[2], so the expected result is ([0.0, 1.0],)

MariusDrulea, Sep 14 '23 17:09

Getting closer. I have manually built the corresponding graph for the gradient and the Hessian, and the rrules used are correct. So something is wrong with the Tracker algorithm for "nested" gradients; perhaps the extra graph that records the derivatives is not built correctly.

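using Tracker

x = param([1, 2])      # assumed: same input as in the first example
p = x -> prod(x)
q = x -> p(x) ./ x     # the gradient of prod, hard-coded (valid when no x[i] is zero)
jacobian(q, x)         # Jacobian of the gradient = Hessian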

Gives the correct result:

Tracked 2×2 Matrix{Float64}:
 0.0  1.0
 1.0  0.0

MariusDrulea, Sep 16 '23 09:09

I have created a function to print the graph under a specific node: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L140

Note that Tracker does not record the operations themselves, but their pullbacks. This is because only the pullbacks are needed when back-propagating through the graph.
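To illustrate the idea, here is a toy scalar reverse-mode tape in plain Julia. It is a hypothetical sketch, not Tracker's implementation (`Var`, `toposort` and `toy_gradient` are made-up names): each operation computes its value on plain Float64 data and records one pullback per input. The pullbacks are written in terms of `Var`, so running them builds a new graph that can itself be differentiated, and an input that a result does not depend on gets a zero gradient rather than an error.

```julia
# Toy scalar reverse-mode AD that records pullbacks (hypothetical sketch, not Tracker's code)
import Base: +, *

mutable struct Var
    val::Float64
    parents::Vector{Any}   # (parent::Var, pullback) pairs
end
Var(v::Real) = Var(Float64(v), Any[])   # leaf or constant node

# Values are computed on plain Float64 data; the pullbacks act on Vars,
# so running them during the backward pass builds a new, differentiable graph.
+(a::Var, b::Var) = Var(a.val + b.val, Any[(a, Δ -> Δ), (b, Δ -> Δ)])
*(a::Var, b::Var) = Var(a.val * b.val, Any[(a, Δ -> Δ * b), (b, Δ -> Δ * a)])

# Nodes ordered so that every node comes before its parents (y first)
function toposort(y::Var)
    order = Var[]
    seen = IdDict{Var,Bool}()
    function visit(v::Var)
        haskey(seen, v) && return
        seen[v] = true
        for (p, _) in v.parents
            visit(p)
        end
        push!(order, v)
    end
    visit(y)
    return reverse(order)
end

# Gradient of y w.r.t. xs; the entries are Vars, so they can be differentiated again
function toy_gradient(y::Var, xs::Vector{Var})
    adj = IdDict{Var,Var}()
    adj[y] = Var(1.0)                       # seed dy/dy = 1
    for v in toposort(y)
        haskey(adj, v) || continue
        Δ = adj[v]
        for (p, pb) in v.parents
            c = pb(Δ)                       # run the recorded pullback
            adj[p] = haskey(adj, p) ? adj[p] + c : c
        end
    end
    # an input that y does not depend on gets a zero gradient instead of an error
    return [get(adj, x, Var(0.0)) for x in xs]
end

x1, x2 = Var(1.0), Var(2.0)
g  = toy_gradient(x1 * x2, [x1, x2])         # values [2.0, 1.0]
gg = toy_gradient(g[1] + g[2], [x1, x2])     # values [1.0, 1.0]
s  = toy_gradient(x1 + x2, [x1, x2])         # values [1.0, 1.0]
ss = toy_gradient(s[1] + s[2], [x1, x2])     # values [0.0, 0.0]
H  = [toy_gradient(g[i], [x1, x2])[j].val for i in 1:2, j in 1:2]   # [0.0 1.0; 1.0 0.0]
```

On the two examples from the top of this issue the toy returns [1, 1] and [0, 0], and the Hessian of x1 * x2 comes out as [0 1; 1 0].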

using Tracker
f(x) = prod(x)
∇f = x->gradient(f, x, nest=true)[1]
hess1 = x->gradient(u->∇f(u)[1], x, nest=true)

h1 = hess1([3, 4])
Tracker.print_graph(stdout, h1[1])

# Prints the following
TrackedData
-data=[-1.3333333333333333, 0.0]
-Tracker=
--isleaf=false
--grad=UndefInitializer()
--Call=
---f
----back
-----func.f=+
-----func.args=([0.0, 0.0], [-1.3333333333333333, -0.0] (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=UndefInitializer()
-----Call=
------f
-------back
--------func.f=partial
--------func.args=(Base.RefValue{Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}}(Base.Broadcast.var"#12#14"{Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}, typeof(*)}(Base.Broadcast.var"#18#20"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}, Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}, Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}, Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}, typeof(/)}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#11#13"}(Base.Broadcast.var"#11#13"()), Base.Broadcast.var"#15#16"{Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}}(Base.Broadcast.var"#15#16"{Base.Broadcast.var"#17#19"}(Base.Broadcast.var"#17#19"())), Base.Broadcast.var"#25#26"{Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}}(Base.Broadcast.var"#25#26"{Base.Broadcast.var"#27#28"}(Base.Broadcast.var"#27#28"())), Base.Broadcast.var"#21#22"{Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}}(Base.Broadcast.var"#21#22"{Base.Broadcast.var"#23#24"}(Base.Broadcast.var"#23#24"())), /), *)), [1.0, 0.0] (tracked), 2, 12.0, [3.0, 4.0] (tracked), 1)
------args
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), [1.0, 0.0] (tracked), 2, [0.0, 0.0], [4.0, 3.0] (tracked))
---------args
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------#714
--------------func.f=getindex
------------args
-------------nothing
-------------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=UndefInitializer()
-----------Call=
------------f
-------------back
--------------func.f=#12
--------------func.args=(12.0, [3.0, 4.0] (tracked), 1)
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=UndefInitializer()
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=true
-----------------grad=[0.0, 0.0]
-----------------Call=
------------------f
------------------args
-------------nothing
-------nothing
-------nothing
-------Tracker=
--------isleaf=false
--------grad=UndefInitializer()
--------Call=
---------f
----------#718
---------args
----------Tracker=
-----------isleaf=true
-----------grad=[0.0, 0.0]
-----------Call=
------------f
------------args
-------nothing

MariusDrulea, Sep 20 '23 14:09

Thanks for digging! As you can see this package doesn't get a lot of attention, but fixes are very welcome.

I do not know what's going wrong in these cases, I never thought much about how this package handles second derivatives. There is a way to mark rules as only suitable for first derivatives, which the prod rule does not use.

mcabbott, Sep 20 '23 15:09

I suspect some broadcasting rule does not work correctly: the graph above contains a long chain of nested Base.Broadcast closures, which looks wrong. (Screenshot from 2023-09-20 21-31-34 not reproduced here.)

MariusDrulea, Sep 20 '23 18:09

For comparison, this is the graph of the two-scalar example, which works correctly. It looks clean.

using Tracker
x1 = param(1)
x2 = param(2)

f_(x1, x2) = x1*x2
∇f_(x1, x2) = gradient(f_, x1, x2, nest=true)
H12 = gradient(x2->∇f_(x1, x2)[1], x2, nest=true)[1]

Tracker.print_graph(stdout, H12)

# Prints the following
TrackedData
-data=1.0
-Tracker=
--isleaf=false
--grad=0.0
--Call=
---f
----back
-----func.f=+
-----func.args=(0.0, 1.0 (tracked))
---args
----nothing
----Tracker=
-----isleaf=false
-----grad=0.0
-----Call=
------f
-------#229
--------func.b=1
------args
-------Tracker=
--------isleaf=false
--------grad=0.0
--------Call=
---------f
----------back
-----------func.f=partial
-----------func.args=(Base.RefValue{typeof(+)}(+), 1, 2, 0.0, 2.0 (tracked))
---------args
----------nothing
----------nothing
----------nothing
----------nothing
----------Tracker=
-----------isleaf=false
-----------grad=0.0
-----------Call=
------------f
-------------#230
--------------func.a=1
------------args
-------------nothing
-------------Tracker=
--------------isleaf=false
--------------grad=0.0
--------------Call=
---------------f
----------------#718
---------------args
----------------Tracker=
-----------------isleaf=false
-----------------grad=0.0
-----------------Call=
------------------f
-------------------#718
------------------args
-------------------Tracker=
--------------------isleaf=true
--------------------grad=1.0
--------------------Call=
---------------------f
---------------------args
-------nothing

MariusDrulea, Sep 20 '23 18:09

> Thanks for digging! As you can see this package doesn't get a lot of attention, but fixes are very welcome.

I have reviewed the back-propagation algorithm for both simple and nested gradients, and everything seems correct. So it is probably a good idea to switch to ChainRules now, as its rules are more robust. Another internal design decision is whether to record the original function or its pullback in the graph; the algorithm would be mostly the same either way. Currently the pullback is stored, which makes printing and visualizing the graph (and therefore debugging) harder. I also don't know whether storing the pullback has any advantage over storing the original function, so I am tempted to store the original function instead and invoke its rrule only during back-propagation through the graph.

To summarize, 3 actions are needed to improve the robustness and ease of development of this package:

  • switch to ChainRules
  • in the graph, store the original function instead of the pullback
  • document the code and the public functions, generate docs

MariusDrulea, Sep 20 '23 19:09

Calling the rrule only during the backwards pass would probably not work great, because it'd end up with us recomputing the primal. Did you perhaps mean calling the pullback function rrule returns? I know Zygote and Tracker have not been great at giving those names, but ChainRules is pretty good about doing so. If I call rrule(f, args...), I'll probably get a (primal, pullback::typeof(f_pullback)) back.
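The trade-off can be made concrete with a counter. In this sketch `toy_rrule` is a hypothetical stand-in for a ChainRules-style rrule (it returns the primal and a pullback, and counts primal evaluations); it is not the ChainRules API itself. Storing the pullback computes sin(x) once, while storing only the function and calling the rrule in the backward pass computes it twice:

```julia
# Hypothetical stand-in for a ChainRules-style rrule; counts primal evaluations
const PRIMAL_CALLS = Ref(0)

function toy_rrule(::typeof(sin), x::Float64)
    PRIMAL_CALLS[] += 1
    y = sin(x)
    sin_pullback(Δ) = Δ * cos(x)
    return y, sin_pullback
end

# Option A: run the rrule in the forward pass and store its pullback in the graph
function forward_storing_pullback(x)
    y, back = toy_rrule(sin, x)
    return y, back
end

# Option B: store the original function; call the rrule only in the backward pass
function forward_storing_function(x)
    PRIMAL_CALLS[] += 1
    y = sin(x)                              # the forward pass still needs the primal
    back = Δ -> toy_rrule(sin, x)[2](Δ)     # the rrule recomputes sin(x)
    return y, back
end

PRIMAL_CALLS[] = 0
y_a, back_a = forward_storing_pullback(0.5)
back_a(1.0)
calls_a = PRIMAL_CALLS[]                    # 1

PRIMAL_CALLS[] = 0
y_b, back_b = forward_storing_function(0.5)
back_b(1.0)
calls_b = PRIMAL_CALLS[]                    # 2
```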

ToucheSir, Sep 20 '23 20:09

@ToucheSir you are right, we have to store the pullback, not the original function. This is also clear now to me: https://discourse.julialang.org/t/second-order-derivatives-with-chainrules/103606

Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91 Second derivatives still do not work, because not all operations are tracked and sent to ChainRules; so far I have only added +(x, y), for instance: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/lib/array.jl#L455.

Next is to:

  • continue the ChainRules integration: track the missing operations (e.g. x + y = track(+, x, y), where x and y are TrackedArrays, etc.)
  • make sure all first order derivatives work
  • fix second order derivatives
  • remove obsolete rules defined by Tracker
  • document the code & generate documentation

The following first-order derivative already works via ChainRules:

using Tracker
f2(x, y) = prod(x+y)
dx = gradient(f2, [3, 4], [7, 9], nest=true)

Prints:

[ Info: Chainrules for +
[ Info: Chainrules for prod
([13.0, 10.0], [13.0, 10.0])

MariusDrulea, Sep 21 '23 14:09

> Integrating ChainRules is pretty easy: https://github.com/MariusDrulea/Tracker.jl/blob/master/src/Tracker.jl#L91

Not that easy, son. It is easy only for first-order derivatives. The operations performed inside these first-order pullbacks must themselves be tracked, so that second-order derivatives can be computed.

MariusDrulea, Sep 22 '23 09:09

I'm currently stuck as I don't know how to deal with the ChainRules.rrule(s) here.

The logic in Tracker is to define an untracked forward pass and a tracked pullback and the AD engine will do the rest.

  • f(x::TrackedReal) = f(data(x)) # untracked forward pass, data(x) is Float64
  • back(d)=d*∂f(x) # tracked pullback, x is a TrackedReal

What ChainRules offers:

y, back = rrule(sin, x::TrackedReal)

This would yield sin(x) (tracked) and back, also tracked. What we want is sin(data(x)) (not tracked) and back tracked. So basically I want to be able to modify the forward pass. With DiffRules this is easily achievable, as we can define the forward pass and the pullback separately, e.g.

sin(x) = sin(data(x)) # untracked
DiffRules.diffrule(sin, x) = cos(x) # tracked, implicitly provided by DiffRules

@ToucheSir, @mcabbott Any idea? Or ping somebody who can help?

MariusDrulea, Sep 22 '23 19:09

> This would yield sin(x) (tracked) and back also tracked. What we want to get is sin(data(x)) (not tracked) and back tracked.

I'm not sure I understand, why shouldn't it be sin(x) (tracked)? If the primal output is not tracked, AD will just stop working for subsequent operations.

ToucheSir, Sep 23 '23 17:09

The tracking of sin(x) is done by the Tracker.jl logic, not by rrule. For sin this is what happens: sin(xs::TrackedReal) = track(sin, xs::TrackedReal)

Conceptually, track for sin does the following:

function track(::typeof(sin), xs::TrackedReal)
  # y is not tracked, as sin is applied only to data(xs)
  # the pullback is tracked, because cos acts on the TrackedReal xs: a track call is made for cos too, just like for sin
  y, back = sin(data(xs)), Δ -> Δ * cos(xs) # cos(xs) is the derivative provided by DiffRules

  # here we create another TrackedReal object whose data is y and we also record the pullback and the previous node in the graph (tracker.(xs))
  track(Call(back, tracker.(xs)), y) # the tracking of primal y happens here
end

MariusDrulea, Sep 23 '23 18:09

I have checked the AutoGrad.jl and HIPS/autograd engines. Both call the differentiation rules (their equivalent of rrules) during the backward pass, and so does Tracker.jl.

The right thing to do here would be to thunk the primal computation in all ChainRules rrules, or to use some other mechanism that avoids computing the primal. Yes, the primal is needed most of the time with rrules, but not always!

I think we can currently achieve the following:

  1. If only first-order derivatives are needed, we can simply call the rrules during the forward pass and register the pullbacks for the backward pass. This way the primal is computed only once.
  2. If higher-order derivatives are needed, we have to compute the primal on the untracked data in the forward pass, and call the rrules on the tracked values (TrackedReal etc.) in the backward pass. This means the primal function is called twice.

Sample code for item 2:

using ChainRules
using ChainRules: rrule
using ChainRulesCore
import Base: +, *

struct _Tracked <: Real
    data::Float64
    f::Any
    _Tracked(data, f) = new(data, f)
    _Tracked(data) = new(data, nothing)
end

function track(f, a, b)
    data, delayed = _forward(f, a, b)
    return _Tracked(data, delayed)
end

function _forward(f, a, b)
    data = f(a.data, b.data) # primal no tracking
    delayed = ()->rrule(f, a, b) # delayed pullback with tracking
    return data, delayed
end

a::_Tracked + b::_Tracked = track(+, a, b)
a::_Tracked * b::_Tracked = track(*, a, b)

##
a = _Tracked(10)
b = _Tracked(20)
tr = a*b
back2 = tr.f()[2]
da, db = back2(_Tracked(1.0))[2:end] 
# result: da =_Tracked(20.0, ...), db =_Tracked(10.0, ...)

MariusDrulea, Sep 25 '23 07:09

Status:

Test Summary:                | Pass  Fail  Error  Broken  Total   Time
Tracker                      |  357     1     15       3    376  22.0s
  gradtests 1                |   16                          16   0.6s
  gradtests 1.1              |   10                          10   0.4s
  indexing & slicing         |    1                           1   0.1s
  concat                     |  181                         181   2.1s
  getindex (Nabla.jl - #139) |    1                           1   0.0s
  gradtests 2                |   22            2             24   0.9s
  mean                       |    5                           5   0.1s
  maximum                    |    5                           5   0.1s
  minimum                    |    5                           5   0.1s
  gradtests 3                |    9            4             13   0.5s
  transpose                  |   48                          48   0.0s
  conv, 1d                   |    2                    1      3   2.3s
  conv, 2d                   |    2                    1      3   1.1s
  conv, 3d                   |    2                    1      3   8.0s
  pooling                    |    4                           4   2.0s
  equality & order           |   16                          16   0.0s
  reshape                    |    8                           8   0.0s
  Intermediates              |    2                           2   0.0s
  Fallbacks                  |    1                           1   0.0s
  collect                    |                 1              1   0.1s
  Hooks                      |                 1              1   0.1s
  Checkpointing              |    2     1      1              4   0.2s
  Updates                    |    2                           2   0.0s
  Params                     |    1                           1   0.0s
  Forward                    |                 2              2   0.2s
  Custom Sensitivities       |                 1              1   0.1s
  PDMats                     |                 1              1   0.1s
  broadcast                  |    2            1              3   0.0s
  logabsgamma                |    2                           2   0.1s
  Jacobian                   |    1                           1   0.8s
  withgradient               |    3                           3   1.3s
  NNlib.within_gradient      |    2                           2   0.0s
ERROR: Some tests did not pass: 357 passed, 1 failed, 15 errored, 3 broken

MariusDrulea, Oct 01 '23 19:10