CellListMap.jl
ReverseDiff gradients
Hi, I wanted to try getting some gradients from a function involving map_pairwise, as I saw in the docs that automatic differentiation is available. The issue is that my input consists of hundreds of thousands of variables (a 3×~1e5 matrix) and my output is a loss score, so ForwardDiff is quite inefficient. I tried just replacing it with ReverseDiff, but I got this:
ERROR: LoadError: ArgumentError: cannot reinterpret `ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}` as `SVector{3, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}}`, type `SVector{3, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}}}` is not a bits type
and with Zygote:
ERROR: LoadError: Compiling Tuple{CellListMap.var"##set_number_of_batches!#38", Bool, typeof(CellListMap.set_number_of_batches!), CellList{3, Float64}, Tuple{Int64, Int64}}: try/catch is not supported.
My question would be, first, whether it is even possible to use reverse-mode differentiation with CellListMap. If so, would it be possible to add some examples to the docs on how to do so? The type-conversion trick used for ForwardDiff does not work. Thanks in advance for your help.
I don't remember having tried reverse diff. I'll take a look at it asap.
Here is the situation. Currently you cannot really reverse-differentiate through the whole construction of the cell lists, but you can bypass all that and differentiate only the computation of the objective function, if the coordinates are provided (redundantly) through a closure.
For example, consider the following simple function, which sums the squared distance between particle coordinates:
sum_sqr(d2, s) = s += d2
This could be mapped over all pairs of particles (here given as a matrix of size (3, N)) with:
using CellListMap

coordinates = rand(3, 1000)
box = Box([1, 1, 1], 0.05)
cl = CellList(coordinates, box)
map_pairwise((_, _, _, _, d2, s) -> sum_sqr(d2, s), 0.0, box, cl)
(the x, y, i, j parameters are replaced by _ because they are not used in the sum_sqr function).
This can be forward-differentiated as shown in the manual, but reverse differentiation does not work, basically because the construction of the cell lists requires mutation of arrays, and the current reverse-mode infrastructure does not support that (maybe Enzyme could handle it, but there is a simpler alternative).
The trick is to define a function that uses only the indices of the particles and computes the property of interest from the coordinates provided in a closure. That is, the above function would be defined as:
sum_sqr(i, j, s, coordinates) = s += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
Note that coordinates here is the complete set of coordinates, which will now be closed over in the call to map_pairwise, and does not correspond to any of the inner input parameters. That is:
coordinates = rand(3,1000)
box = Box([1,1,1], 0.05)
cl = CellList(coordinates, box)
map_pairwise( (_, _, i, j, _, s) -> sum_sqr(i, j, s, coordinates), 0.0, box, cl)
Now we use the i and j internal parameters, but we close over the coordinates. The call to map_pairwise can now be differentiated with respect to the coordinates, because the construction of the cell lists is not part of the equation. For that, we wrap the complete call in a function that receives box and cl as parameters:
julia> function sum_sqr(coordinates, box, cl)
           sum_sqr = map_pairwise!(
               (_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
               zero(eltype(coordinates)), box, cl,
           )
           return sum_sqr
       end
sum_sqr (generic function with 3 methods)
And this can be both forward- and reverse-differentiated:
julia> using ForwardDiff, ReverseDiff
julia> coordinates = rand(3,1000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> gr = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
-0.0875518 0.0 0.0 -0.0848635 0.0 0.0 0.0 … 0.0 0.0 0.00620111 -0.0335964 0.0 0.0234671
-0.0442765 0.0 0.0 -0.0575914 0.0 0.0 0.0 0.0 0.0 -0.014573 0.0649681 0.0 0.0607623
0.000192838 0.0 0.0 -0.0623673 0.0 0.0 0.0 0.0 0.0 0.0316766 0.0451708 0.0 -0.00635728
julia> gr = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
-0.0875518 0.0 0.0 -0.0848635 0.0 0.0 0.0 … 0.0 0.0 0.00620111 -0.0335964 0.0 0.0234671
-0.0442765 0.0 0.0 -0.0575914 0.0 0.0 0.0 0.0 0.0 -0.014573 0.0649681 0.0 0.0607623
0.000192838 0.0 0.0 -0.0623673 0.0 0.0 0.0 0.0 0.0 0.0316766 0.0451708 0.0 -0.00635728
As expected, reverse differentiation is much faster here:
julia> revg(coordinates, box, cl) = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
revg (generic function with 1 method)
julia> forg(coordinates, box, cl) = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
forg (generic function with 1 method)
julia> using BenchmarkTools

julia> @btime revg($coordinates, $box, $cl);
550.999 μs (11188 allocations: 567.33 KiB)
julia> @btime forg($coordinates, $box, $cl);
60.827 ms (73754 allocations: 24.44 MiB)
Hi, thanks a lot for your answer! I've been able to replicate your example with some similar data. But sometimes I get
ERROR: LoadError: UndefRefError: access to undefined reference
and other times it segfaults and quits Julia:
...
main at julia (unknown line)
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 282841011 (Pool: 282818657; Big: 22354); GC: 103
Segmentation fault (core dumped)
My particular application is a histogram (or two-point function) that I would like to differentiate through, but I keep getting the segfault. Could it be because of the size of my data (600k coordinates)? Though that would still not explain why it segfaults with 1k coordinates.
Segfaults are usually related to some corrupted memory access, not to the size of the data; when the data is too big to fit in memory you get OutOfMemory errors.
Without further details, I can't speculate on what may be going on there.
One thing that may be related: when running CellListMap in parallel, there is machinery to avoid concurrency issues among threads, and I don't know whether the differentiation routines handle that properly. One test is to run the calculation without parallelization.
I've made a small test here, and the results in this simple example are the same. But in your case you are probably updating a shared histogram array as the output, and the example above is a scalar function, so it is not exactly the same thing:
julia> function sum_sqr(coordinates, box, cl; parallel=true)
           sum_sqr = map_pairwise!(
               (_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
               zero(eltype(coordinates)), box, cl; parallel=parallel
           )
           return sum_sqr
       end
sum_sqr (generic function with 1 method)
julia> coordinates = rand(3,5000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
3×5000 Matrix{Float64}:
-0.0344465 0.0 -0.00210207 -0.0551787 -0.0152113 0.0472379 … 0.0680135 -0.0428575 0.0 0.0715126 -2.16583
0.0527802 0.0 -0.0280015 -0.0101965 -0.0427886 -0.0660361 0.10064 0.00204247 0.0 -0.0447712 0.0216365
0.0416164 0.0 0.103765 -0.0231282 -0.0654189 -0.00256454 -0.0578568 -3.82477 0.0 0.10979 0.0366878
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates)
3×5000 Matrix{Float64}:
-0.0344465 0.0 -0.00210207 -0.0551787 -0.0152113 0.0472379 … 0.0680135 -0.0428575 0.0 0.0715126 -2.16583
0.0527802 0.0 -0.0280015 -0.0101965 -0.0427886 -0.0660361 0.10064 0.00204247 0.0 -0.0447712 0.0216365
0.0416164 0.0 0.103765 -0.0231282 -0.0654189 -0.00256454 -0.0578568 -3.82477 0.0 0.10979 0.0366878
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates) ≈
ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
true
In fact, with a simple histogram-like function, ReverseDiff fails even serially:
julia> function hist(coordinates, box, cl; parallel=true)
           h = map_pairwise!(
               (_, _, i, j, d2, h) -> begin
                   if sqrt(d2) < box.cutoff / 2
                       h[1] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
                   else
                       h[2] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
                   end
                   return h
               end,
               zeros(eltype(coordinates), 2), box, cl; parallel=parallel
           )
           return h
       end
hist (generic function with 1 method)
julia> hist(coordinates, box, cl)
2-element Vector{Float64}:
406.8648388535888
3223.091284798098
julia> ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates)
ERROR: DimensionMismatch: new dimensions (2, 10000) must be consistent with array size 10000
If you change that to compute the histogram by passing the hist variable in the closure, then you can certainly get corrupted memory accesses among threads. And I actually couldn't make it work.
Maybe one alternative is to compute each bin of the histogram independently, as that would provide scalar returns to the function. I'm not a specialist in autodiff, so I can't be more precise about what to suggest there.
I see; however, it seems you get a different kind of error than I do (you get DimensionMismatch and I get UndefRefError). My complete function actually computes a scalar in the end, since it computes the MSE against a reference 2pt function, but I guess the same undefined access applies.
Can you share something about your code, so that we can at least localize the issue?
Sure! Here are my "core" functions. I believe the issue can be reproduced with randomly distributed particles in a box, since my dataset is a bit large (though the "before" dataset I shared in a past issue may work too):
bin_edges = 10 .^ range(-2, stop=log10(50), length=11)
positions = 2000.0 .* rand(3, 600000)
box_size = [2e3 for _ = 1:3]
box = Box(box_size, 5.0)
cl = CellList(positions, box)
function coordinate_separation(a, b, box_size)
    delta = abs(a - b)
    return (delta > 0.5 * box_size ? delta - box_size : delta) * sign(a - b)
end

function diff_build_histogram!(i, j, hist, coordinates, bin_edges, box_size)
    d2 = sum(abs2, coordinate_separation.(view(coordinates, :, i), view(coordinates, :, j), box_size))
    ibin = searchsortedlast(bin_edges, sqrt(d2))
    if ibin > 0 && ibin <= length(hist)  # hist has length(bin_edges) - 1 bins
        hist[ibin] += 1
    end
    return hist
end
using Statistics  # for mean

function loss(positions, box, cl, bin_edges, box_size)
    hist = zeros(Int, size(bin_edges, 1) - 1)
    println("Counting pairs...")
    # Run calculation
    map_pairwise!(
        (_, _, i, j, _, hist) -> diff_build_histogram!(i, j, hist, positions, bin_edges, box_size),
        hist, box, cl; show_progress = true
    )
    println("Done")
    N = size(positions, 2)
    hist = hist / (N * (N - 1))
    norm = @. (4 / 3) * π * (bin_edges[2:end]^3 - bin_edges[1:end-1]^3) / (box_size[1] * box_size[2] * box_size[3])
    hist ./= norm
    mean(abs.(hist - xi_ref))  # I think for testing purposes xi_ref can be 0.
end
ReverseDiff.gradient((x) -> loss(x, box, cl, bin_edges, box_size), positions)
Sorry, actually in my last test it seems to work (I did deactivate parallelization). All gradients seem to be 0, but that may be because the histogram is just not differentiable.
I would try to compute a single count (of one bin) in a regular scalar variable to see how that works.
Then, if that works, maybe it is possible to build the histogram with an immutable structure (an SVector, for example).
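Something along these lines might work (a minimal sketch of that idea, untested, assuming a two-bin histogram, serial execution, and the same setup as above; the hist_sv name and the bin split are illustrative):

using CellListMap, StaticArrays, ReverseDiff

function hist_sv(coordinates, box, cl)
    z = zero(coordinates[1, 1])  # zero of the (possibly tracked) element type
    h = map_pairwise!(
        (_, _, i, j, d2, h) -> begin
            v = sum(abs2, @views(coordinates[:, i] - coordinates[:, j]))
            # rebuild the immutable accumulator instead of mutating it in place
            h = sqrt(d2) < box.cutoff / 2 ? h + SVector(v, z) : h + SVector(z, v)
            return h
        end,
        SVector(z, z), box, cl; parallel = false,
    )
    return h
end

# ReverseDiff.gradient needs a scalar output, so differentiate a reduction:
# ReverseDiff.gradient(x -> sum(hist_sv(x, box, cl)), coordinates)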
Just to add: if I compute a single bin of the histogram, the differentiation apparently works, but returns all zeros, as you observed. I'm not sure if this is correct:
julia> using CellListMap, LinearAlgebra

julia> function hist(coordinates, box, cl; parallel=true)
           h = map_pairwise!(
               (_, _, i, j, _, h) -> begin
                   d = norm(@views(coordinates[:,i] - coordinates[:,j]))
                   if d < box.cutoff / 2
                       h += 1
                   #else
                   #    h[2] += 1
                   end
                   return h
               end,
               0, box, cl; parallel=parallel
           )
           return h
       end
hist (generic function with 1 method)
julia> coordinates = rand(3,1000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> hist(coordinates, box, cl)
24
julia> all(==(0), ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true
julia> all(==(0), ForwardDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true
(example fixed @dforero0896)
Indeed it seems to work. The issue of the zeros is just that this "exact" way of histogramming is not differentiable. An approximate histogram could be built with something like
function diff_build_histogram!(i, j, hist, coordinates, bin_widths, box_size, bin_centers)
    d2 = sum(abs2, coordinate_separation.(view(coordinates, :, i), view(coordinates, :, j), box_size))
    hist .+= exp.(-((sqrt(d2) .- bin_centers) ./ bin_widths) .^ 2)  # Gaussian kernel per bin
    return hist
end
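For concreteness, a sketch of how that could be wired into the loss (the bin_centers, bin_widths, and mean(abs, hist) reduction here are illustrative stand-ins, not my exact code, and I run it serially):

using CellListMap, ReverseDiff, Statistics

bin_centers = 0.5 .* (bin_edges[1:end-1] .+ bin_edges[2:end])
bin_widths = bin_edges[2:end] .- bin_edges[1:end-1]

function smooth_loss(positions, box, cl, bin_widths, box_size, bin_centers)
    hist = map_pairwise!(
        (_, _, i, j, _, hist) ->
            diff_build_histogram!(i, j, hist, positions, bin_widths, box_size, bin_centers),
        zeros(eltype(positions), length(bin_centers)), box, cl; parallel = false,
    )
    return mean(abs, hist)  # stand-in for the MSE against a reference 2pt function
end

ReverseDiff.gradient(x -> smooth_loss(x, box, cl, bin_widths, box_size, bin_centers), positions)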
So it is clear how the end product depends on the coordinates. Thanks for your help!
Yes, cool, I was thinking about that problem. Exactly: the histogram has a zero gradient because no infinitesimal move of the particles will cause a particle to change from one bin to another. Not only is the derivative discontinuous; mostly it is zero.
The problem of obtaining a differentiable distribution is indeed interesting. Thanks for posting.
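One way to see this numerically (a quick sketch, assuming the hist function from the single-bin example above): perturb the coordinates slightly and check that the count does not change, so the derivative is zero almost everywhere.

coordinates = rand(3, 1000)
box = Box([1, 1, 1], 0.05)
cl = CellList(coordinates, box)
h0 = hist(coordinates, box, cl)
# a tiny perturbation leaves every pair on the same side of the threshold,
# so the count (and hence the gradient) does not change
h1 = hist(coordinates .+ 1e-9 .* randn(3, 1000), box, cl)
h0 == h1  # true unless some pair distance sits within 1e-9 of the threshold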
I will update the docs with some examples that came out of this discussion, and will close the issue when I do that. Thank you very much for the feedback; it will be useful for others to know how to apply ReverseDiff here.
Glad my question was helpful. There are some other packages that have implemented differentiable histogramming in other ways; they may be useful for someone looking into this too.