Thread Safety of ForwardDiff.GradientConfig
I'm not sure whether this is intended, but I'm finding that if I create a configuration with

```julia
cfg = ForwardDiff.GradientConfig(...)
```

and then use it to define a gradient function with ForwardDiff.gradient!, the resulting function is not thread-safe. If I omit cfg (and accept the associated performance hit), everything works as expected.
EDIT: The following example reproduces the problem. On my laptop with 4 threads, the threaded results disagree with the serial ones:
```julia
using ForwardDiff
using Random
using Printf
using LinearAlgebra

n = 10^6;

function f(x)
    return x[1]^2
end

# One config shared by every call of ∇f1; it holds a Dual work buffer
cfg = ForwardDiff.GradientConfig(f, zeros(Float64, 1));
∇f1 = x -> ForwardDiff.gradient(f, x, cfg);
# No explicit config: ForwardDiff builds a fresh one on each call
∇f2 = x -> ForwardDiff.gradient(f, x);

Random.seed!(100);
x = randn(n);
results_single = zeros(n);
results_thread_v1 = zeros(n);
results_thread_v2 = zeros(n);

# Serial reference
for j in 1:n
    results_single[j] = ∇f1([x[j]])[1]
end

# Threaded, shared cfg
Threads.@threads for j in 1:n
    results_thread_v1[j] = ∇f1([x[j]])[1]
end

# Threaded, per-call config
Threads.@threads for j in 1:n
    results_thread_v2[j] = ∇f2([x[j]])[1]
end

@show norm(results_single .- results_thread_v1);
@show norm(results_single .- results_thread_v2);
```
I'm seeing this as well. On my machine, the last two lines of the example above print:
```julia
julia> @show norm(results_single .- results_thread_v1);
norm(results_single .- results_thread_v1) = 66.82892270098449

julia> @show norm(results_single .- results_thread_v2);
norm(results_single .- results_thread_v2) = 0.0
```
The cfg object is used as a cache / work buffer so sharing it between threads is not possible.
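Since the config is a work buffer, one way to keep it while threading is to give every task its own copy. The sketch below assumes the workload can be split into index chunks up front; `chunked_gradients` is a hypothetical helper name, not part of ForwardDiff. Each spawned task builds a private `GradientConfig` and input buffer once and reuses them for its whole chunk, so the allocation cost is paid once per task rather than once per call.

```julia
using ForwardDiff
using Base.Threads: @spawn, nthreads

f(x) = x[1]^2

# Hypothetical helper: gradient of f at each scalar xs[j], split across tasks.
# Each task owns its GradientConfig and input buffer, so no work buffer
# is ever shared between concurrently running tasks.
function chunked_gradients(f, xs::Vector{Float64})
    out = similar(xs)
    chunk_len = max(1, cld(length(xs), nthreads()))
    tasks = map(Iterators.partition(eachindex(xs), chunk_len)) do idx
        @spawn begin
            buf = zeros(1)                            # task-private input
            cfg = ForwardDiff.GradientConfig(f, buf)  # task-private cache
            for j in idx
                buf[1] = xs[j]
                out[j] = ForwardDiff.gradient(f, buf, cfg)[1]
            end
        end
    end
    foreach(wait, tasks)   # rethrows if any task failed
    return out
end

xs = randn(10^5)
g = chunked_gradients(f, xs)
```

Writing to distinct indices of `out` from different tasks is safe; only the config needs to be private.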
It seems, then, that cfg is only safe to use interactively: a package that uses ForwardDiff internally would have to avoid it to stay thread-safe.
Would it make sense for GradientConfig to allocate n buffers if n threads are present?
I wouldn't say that cfg is only safe to use interactively; you just have to ensure that the same cfg object is never used concurrently.
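The most direct way to guarantee that a single shared config is never used concurrently is to guard it with a lock. This is a minimal sketch; `CFG`, `CFG_LOCK` and `∇f_locked` are illustrative names. It is correct but serializes every gradient call, so it only pays off when the gradient is a small part of the work done per iteration.

```julia
using ForwardDiff

f(x) = x[1]^2

const CFG = ForwardDiff.GradientConfig(f, zeros(1))
const CFG_LOCK = ReentrantLock()

# Only one task at a time may touch CFG's internal buffer.
# The non-mutating ForwardDiff.gradient returns a newly allocated array,
# so the result stays valid after the lock is released.
function ∇f_locked(x)
    lock(CFG_LOCK) do
        ForwardDiff.gradient(f, x, CFG)
    end
end

xs = randn(10^4)
out = zeros(length(xs))
Threads.@threads for j in eachindex(xs)
    out[j] = ∇f_locked([xs[j]])[1]
end
```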
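The problem with using n buffers for n threads is that it only works with static scheduling: with dynamic scheduling a task can be suspended or migrate to another thread, so a buffer selected by Threads.threadid() is not guaranteed to be exclusive to one task. It needs a different pattern, or the buffers need to be managed on the user side. I'd be interested in a generic solution, since I've run into this and similar problems several times.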
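For reference, the one-buffer-per-thread idea can be sketched as follows, and it is only sound under `:static` scheduling, where each thread runs exactly one non-migrating task over its chunk of iterations. This assumes Julia 1.9 or later (for `Threads.maxthreadid()`, which also covers thread ids from an interactive threadpool) and that the loop is not itself nested inside another `@threads` loop, where `:static` is not allowed. With the default `:dynamic` scheduler the same indexing would be unsafe.

```julia
using ForwardDiff

f(x) = x[1]^2

# One config per possible thread id (Threads.maxthreadid requires Julia >= 1.9)
cfgs = [ForwardDiff.GradientConfig(f, zeros(1)) for _ in 1:Threads.maxthreadid()]

xs = randn(10^4)
out = zeros(length(xs))
# :static runs one sticky task per thread, so threadid() is fixed for the
# whole loop body and no two iterations on one thread overlap in time.
Threads.@threads :static for j in eachindex(xs)
    cfg = cfgs[Threads.threadid()]
    out[j] = ForwardDiff.gradient(f, [xs[j]], cfg)[1]
end
```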
For a long time my solution has been to store temporaries Lux-style in nested NamedTuples (in Lux this would be the state st). But I would prefer temporaries to be managed more locally.
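A rough sketch of that style, under the assumption that the caller owns all temporaries and threads them through explicitly: `make_state` and `loss_gradient` are hypothetical names, and the nesting `st.grad.cfg` only mimics how a Lux state might be organised. Thread safety then follows from each task creating its own state.

```julia
using ForwardDiff

f(x) = x[1]^2

# Hypothetical: all temporaries live in a nested NamedTuple owned by the caller
make_state(f, x) = (grad = (cfg = ForwardDiff.GradientConfig(f, x),),)

function loss_gradient(f, x, st)
    g = ForwardDiff.gradient(f, x, st.grad.cfg)
    return g, st   # state is handed back so the caller can pass it on
end

xs = randn(8)
out = zeros(length(xs))
@sync for j in eachindex(xs)
    Threads.@spawn begin
        st = make_state(f, zeros(1))   # each task builds its own state
        g, _ = loss_gradient(f, [xs[j]], st)
        out[j] = g[1]
    end
end
```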
I have started experimenting with thread-safe object pools, but this feels a bit like a hack to me.
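One simple form of such a pool uses a `Channel` as the free list: `take!` blocks until a config is available and `put!` returns it, so a config is never held by two tasks at once, and unlike indexing by `threadid()` this does not depend on the scheduler. This is a minimal sketch; `make_cfg_pool` and `pooled_gradient` are hypothetical helpers, and each call pays the locking cost of the channel.

```julia
using ForwardDiff

f(x) = x[1]^2

# Hypothetical helper: a fixed-size pool of configs stored in a Channel
function make_cfg_pool(f, x, n)
    cfgs = [ForwardDiff.GradientConfig(f, similar(x)) for _ in 1:n]
    pool = Channel{eltype(cfgs)}(n)
    foreach(c -> put!(pool, c), cfgs)   # capacity n, so these never block
    return pool
end

function pooled_gradient(f, x, pool)
    cfg = take!(pool)   # waits until some config is free
    try
        return ForwardDiff.gradient(f, x, cfg)
    finally
        put!(pool, cfg) # hand the config back even if f throws
    end
end

pool = make_cfg_pool(f, zeros(1), Threads.nthreads())
xs = randn(10^4)
out = zeros(length(xs))
Threads.@threads for j in eachindex(xs)
    out[j] = pooled_gradient(f, [xs[j]], pool)[1]
end
```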