TaylorDiff.jl

Error with nested calls to TaylorDiff

Open landreman opened this issue 5 months ago • 3 comments

Hi, I have the following nested AD problem that I'd like to get working with TaylorDiff. Starting with a vector-valued function f(params, x) of a scalar x, take a high-order derivative with respect to x and evaluate it at a specific value of x. Then apply some reduction function to the result to obtain a scalar-valued function g(params). Finally, I want to evaluate the gradient dg/dparams. Example:

using TaylorDiff

# Arbitrary function:
f(params, x) = [params[1] * x^3 + params[2], params[2] * sin(x - params[1]), sqrt(x + params[2])]

function g(params)
    closure(x) = f(params, x)
    some_x = 0.7
    d3f_dx3 = TaylorDiff.derivative(closure, some_x, Val(3))
    return sum(d3f_dx3)
end

some_params = [1.3, 2.1]

@show g(some_params)  # Fine, gives 6.095380076578732
TaylorDiff.derivative(g, some_params, [1.0, 0.0], Val(1))  # First element of the gradient

Results, using Julia 1.11.5 and TaylorDiff v0.3.3:

ERROR: MethodError: *(::TaylorScalar{Float64, 1}, ::TaylorScalar{Float64, 3}) is ambiguous.

Candidates:
  *(a::TaylorScalar, b::Number)
    @ TaylorDiff ~/.julia/packages/TaylorDiff/qw5aY/src/primitive.jl:119
  *(a::Number, b::TaylorScalar)
    @ TaylorDiff ~/.julia/packages/TaylorDiff/qw5aY/src/primitive.jl:114

Possible fix, define
  *(::TaylorScalar, ::TaylorScalar)

Stacktrace:
 [1] f(params::Vector{TaylorScalar{Float64, 1}}, x::TaylorScalar{Float64, 3})
   @ Main ./REPL[4]:1
 [2] (::var"#closure#1"{Vector{TaylorScalar{Float64, 1}}})(x::TaylorScalar{Float64, 3})
   @ Main ./REPL[5]:2
 [3] derivatives
   @ ~/.julia/packages/TaylorDiff/qw5aY/src/derivative.jl:41 [inlined]
 [4] derivative
   @ ~/.julia/packages/TaylorDiff/qw5aY/src/derivative.jl:16 [inlined]
 [5] g(params_in::Vector{TaylorScalar{Float64, 1}})
   @ Main ./REPL[5]:4
 [6] derivatives
   @ ~/.julia/packages/TaylorDiff/qw5aY/src/derivative.jl:41 [inlined]
 [7] derivative(f::Function, x::Vector{Float64}, l::Vector{Float64}, p::Val{1})
   @ TaylorDiff ~/.julia/packages/TaylorDiff/qw5aY/src/derivative.jl:17
 [8] top-level scope
   @ REPL[8]:1

Any idea how this could be made to work?

While Zygote-over-TaylorDiff does work for this problem, @btime shows that ForwardDiff-over-ForwardDiff is much faster (probably due to the overhead of reverse mode), so I imagine TaylorDiff-over-TaylorDiff (or ForwardDiff-over-TaylorDiff) might be faster still, given the high-order inner derivative. Thanks.
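
For concreteness, the two combinations being compared look roughly like this (a minimal sketch reusing f, g, and some_params from the example above; d3 and g_forward are just illustrative names, not the exact benchmark code):

using Zygote, ForwardDiff, TaylorDiff

# f, g, and some_params as defined in the example at the top of the issue.

# Zygote-over-TaylorDiff: reverse mode outside, Taylor mode inside.
grad_zygote = Zygote.gradient(g, some_params)[1]

# ForwardDiff-over-ForwardDiff: third derivative in x via three nested
# ForwardDiff.derivative calls, then a ForwardDiff gradient over params.
d3(h, x) = ForwardDiff.derivative(
    a -> ForwardDiff.derivative(b -> ForwardDiff.derivative(h, b), a), x)

function g_forward(params)
    closure(x) = f(params, x)
    return sum(d3(closure, 0.7))
end

grad_forward = ForwardDiff.gradient(g_forward, some_params)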

landreman avatar Jun 15 '25 11:06 landreman

The point of TaylorDiff is that you shouldn't need to nest calls to TaylorDiff; instead you just use a higher-order derivative. Doing the TaylorDiff derivative with Val(4) is the same as the extra nesting here, but with less overhead, so that's what's recommended.
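
For same-variable nesting, the single higher-order call looks like this (a minimal sketch with an arbitrary scalar function h):

using TaylorDiff

h(x) = sin(x) * exp(x)

# One order-4 Taylor-mode pass instead of an order-1 call
# wrapped around an order-3 call in the same variable:
d4h = TaylorDiff.derivative(h, 0.7, Val(4))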

Maybe there could be an overload of TaylorDiff.derivative(closure, some_x, Val(N)) where, if some_x is a Taylor dual number, it internally calls Val(N+1) and munges the result around. I presume that wouldn't be hard to add, but it doesn't exist right now.

ChrisRackauckas avatar Jun 15 '25 12:06 ChrisRackauckas


I agree this is a reasonable use of TaylorDiff, although previous work has focused more on Zygote-over-TaylorDiff and/or Enzyme-over-TaylorDiff, since reverse mode is usually faster when there are many parameters in the gradient.

TaylorDiff could in principle support nesting, but to handle that rigorously one needs to implement a tag system to avoid perturbation confusion, as in ForwardDiff, which I haven't had time to implement yet.

Before I finally get to that, maybe you could try out Enzyme and see if Enzyme-over-TaylorDiff is fast enough for your use case?
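
For reference, that combination could be wired up roughly like this (a sketch reusing f, g, and some_params from the original example; whether it compiles for a given problem depends on Enzyme's ability to differentiate through the TaylorDiff call):

using Enzyme, TaylorDiff

# f, g, and some_params as in the original example above.
# Reverse-mode Enzyme gradient over the Taylor-mode inner derivative;
# the exact return shape of Enzyme.gradient varies between Enzyme versions.
Enzyme.gradient(Enzyme.Reverse, g, some_params)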

tansongchen avatar Jun 16 '25 19:06 tansongchen

Thanks for this info. If it's going to take substantial effort to implement a tag system to get around this error, I'll look for other approaches in the meantime. Enzyme-over-TaylorDiff does work well for the little example above, though I haven't had success with it in my real, larger-scale application. (After several minutes of trying to evaluate, it throws a StackOverflowError, probably due to suboptimal code structure on my part.)

Regarding Chris's comment: in the real application, the map from ∂^n f / ∂x^n to g is a more complicated nonlinear function (not just sum()) that we'll want to be able to change easily. Just taking higher derivatives of f wouldn't capture the derivative of that additional map, while the nested-AD approach would be convenient.
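
For instance, a hypothetical stand-in for that kind of reduction (reduction and g_general are illustrative names, not the real application's code):

using TaylorDiff

# f as defined at the top of the issue; `reduction` is an arbitrary
# nonlinear map of the derivative vector, just to illustrate the point.
reduction(v) = log(sum(abs2, v)) + v[1] * v[3]

function g_general(params)
    closure(x) = f(params, x)
    d3f_dx3 = TaylorDiff.derivative(closure, 0.7, Val(3))
    return reduction(d3f_dx3)
end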

landreman avatar Jun 17 '25 21:06 landreman