ForwardDiff2.jl
Unexplained allocations
(in #3)
A rather scathing benchmark:
julia> using ForwardDiff2: dualrun
julia> using BenchmarkTools
julia> @btime rand()
3.199 ns (0 allocations: 0 bytes)
0.8383875026881109
julia> @btime dualrun(()->rand())
930.333 ns (5 allocations: 112 bytes)
0.7136361983030461
So I tried to track allocations for
julia> foo() = for i=1:10^6; dualrun(()->rand()); end
Which gave me:
- @inline function overdub(ctx::TaggedCtx{T}, f, args...) where {T}
287281248 if length(args) > 4
- return Cassette.recurse(ctx, f, args...)
- end
- # find the position of the dual number with the highest
- # precedence (dominant) tag
23520864 idx = find_dual(Tag{T}, args...)
0 if idx === 0
- # none of the arguments are dual
70562592 Cassette.recurse(ctx, f, args...)
- else
The file containing find_dual itself shows no allocations. There are in fact no allocations anywhere else within ForwardDiff2.
@oxinabox any help?
I am on holidays. Poke me in 3+ days.
3+ days @oxinabox 😆
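# Five or more arguments: fall back to plain recursion.
@inline function overdub(ctx::TaggedCtx{T}, f, arg1, arg2, arg3, arg4, arg5, args...) where {T}
return Cassette.recurse(ctx, f, arg1, arg2, arg3, arg4, arg5, args...)
end
# Up to four arguments: look for a dual number among them.
@inline function overdub(ctx::TaggedCtx{T}, f, args...) where {T}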
# find the position of the dual number with the highest
# precedence (dominant) tag
idx = find_dual(Tag{T}, args...)
if idx === 0
# none of the arguments are dual
Cassette.recurse(ctx, f, args...)
else
...
end
end
Or (since recursing is the default):
for nargs in 1:4
arg_names = ntuple(n->Symbol(:arg, n), nargs)
@eval @inline function overdub(ctx::TaggedCtx{T}, f, $(arg_names...)) where {T}
# find the position of the dual number with the highest
# precedence (dominant) tag
idx = find_dual(Tag{T}, $(arg_names...))
if idx === 0
# none of the arguments are dual
Cassette.recurse(ctx, f, $(arg_names...))
else
...
end
end
end
Which will remove the runtime splatting entirely, because each generated method takes a fixed number of named arguments.
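For anyone who wants to try the `@eval` pattern in isolation, here is a minimal sketch that runs without Cassette or ForwardDiff2. The names `find_float` and `demo` are hypothetical stand-ins for `find_dual` and `overdub`. The sketch only illustrates generating one method per arity plus a varargs fallback; it is not the actual fix.

```julia
# Hypothetical stand-in for `find_dual`: position of the first Float64
# argument, or 0 if none of the arguments is a Float64.
function find_float(args...)
    for (i, a) in enumerate(args)
        a isa Float64 && return i
    end
    return 0
end

# Generate one method per arity (1 to 4 arguments), each with named
# positional arguments instead of a varargs tuple.
for nargs in 1:4
    arg_names = ntuple(n -> Symbol(:arg, n), nargs)
    @eval function demo($(arg_names...))
        idx = find_float($(arg_names...))
        return idx === 0 ? :no_float : (:float_at, idx)
    end
end

# Fallback for five or more arguments, mirroring the "just recurse" branch.
demo(arg1, arg2, arg3, arg4, arg5, args...) = :too_many
```

Calling `methods(demo)` afterwards should list the four generated fixed-arity methods alongside the fallback.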
@YingboMa was this solved?
It hasn't been solved yet, but there is an improvement.
julia> using ForwardDiff2: dualrun
julia> using BenchmarkTools
julia> @btime dualrun(()->rand())
453.505 ns (1 allocation: 32 bytes)
0.6022948398960193