Zygote.jl
LoadError: Non-differentiable function Core.Intrinsics.lshr_int
Hi,
here are some approaches that perform a sorted search.
Each row of x should be searched in the corresponding column of s.
All get_indices functions give the correct answer.
However, only the first two work with gradient; get_indices_3 throws an error:
LoadError: Non-differentiable function Core.Intrinsics.lshr_int
If I run get_indices_3 on the GPU, the whole Julia session crashes; I have no idea whether that is an issue with Zygote.jl or CUDA.jl.
using BenchmarkTools
using Flux
using Zygote
todevice = cpu
# todevice = gpu
x = [
1 2 3 4 5
6 7 8 9 10
0.1 0.2 0.3 0.4 0.5
] |> todevice
s = sort(rand(10, 3) .* 10, dims=1) |> todevice
# %%
function get_indices_1(s, x)
buffer = Zygote.Buffer(x, typeof(firstindex(x)))
for var_index in axes(x, 1)
buffer[var_index, :] = searchsortedlast.(Ref(view(s, :, var_index)), view(x, var_index, :))
end
copy(buffer)
end
@btime get_indices_1(s, x)
@btime gradient((s, x) -> sum(get_indices_1(s, x)), s, x)
# %%
function get_indices_2(s, x)
buffer = Zygote.Buffer(x, typeof(firstindex(x)))
for var_index in axes(x, 1)
reducer(x) = searchsortedlast(view(s, :, var_index), x)
buffer[var_index, :] = reducer.(view(x, var_index, :))
end
copy(buffer)
end
@btime get_indices_2(s, x)
@btime gradient((s, x) -> sum(get_indices_2(s, x)), s, x)
# %%
function get_indices_3(s, x)
searchsortedlast.(eachcol(s), x)
end
@btime get_indices_3(s, x)
@btime gradient((s, x) -> sum(get_indices_3(s, x)), s, x)
Versions: julia 1.7.1; [052768ef] CUDA v3.8.0; [587475ba] Flux v0.12.9; [e88e6eb3] Zygote v0.6.34
You can fix the error with ChainRulesCore.@non_differentiable searchsortedlast(x, y), which ideally should be added here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl#L380 . The immediate error comes from [5] (::typeof(∂(>>>)))(Δ::Int64), and bit-shift functions like >>> should probably also be non-differentiable.
Since your function returns only indices, not continuous values, it could also be marked non-differentiable as a whole, which would be even faster.
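To illustrate, here is a minimal CPU-only sketch of the first suggestion, marking searchsortedlast itself non-differentiable (the exact argument names in the macro call are placeholders; only the call shape matters):

```julia
using Zygote
using ChainRulesCore

# Declare searchsortedlast non-differentiable, so Zygote no longer tries to
# differentiate through its integer bit-shift internals (>>> / lshr_int).
ChainRulesCore.@non_differentiable searchsortedlast(v, x)

s = sort(rand(10, 3) .* 10, dims=1)  # each column sorted
x = rand(3, 5) .* 10                 # row i is searched in column i of s

# same shape as get_indices_3 above
get_indices(s, x) = searchsortedlast.(eachcol(s), x)

# No more LoadError; the output is purely indices, so both gradients
# come back as nothing.
grads = gradient((s, x) -> sum(get_indices(s, x)), s, x)
```

Marking the small primitive rather than the whole wrapper has the advantage that any other function built on searchsortedlast is fixed as well.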
Thank you!
I used the following code, and on the CPU it does work without errors. Unfortunately, if I try to run it on the GPU, the whole Julia session still crashes.
And, strangely enough, though possibly unrelated: my whole model (with a workaround so it does not crash) reliably converges on the CPU and diverges on the GPU.
using BenchmarkTools
using Flux
using Zygote
using ChainRulesCore
# todevice = cpu
todevice = gpu
x = [
1 2 3 4 5
6 7 8 9 10
0.1 0.2 0.3 0.4 0.5
] |> todevice
s = sort(rand(10, 3) .* 10, dims=1) |> todevice
function get_indices_3(s, x)
searchsortedlast.(eachcol(s), x)
end
@non_differentiable get_indices_3(s, x)
@btime get_indices_3(s, x)
@btime gradient((s, x) -> sum(get_indices_3(s, x)), s, x)