
LoadError: Non-differentiable function Core.Intrinsics.lshr_int

drrmmng opened this issue 3 years ago · 2 comments

Hi,

here are some approaches that perform a sorted search. Each row in x should be searched for in the corresponding column of s.

All get_indices functions give the correct answer. However, only the first two work with gradient; get_indices_3 throws an error:

LoadError: Non-differentiable function Core.Intrinsics.lshr_int

If I run get_indices_3 on the GPU, the whole Julia session crashes; I have no idea whether that is an issue with Zygote.jl or CUDA.jl.

using BenchmarkTools
using Flux
using Zygote

todevice = cpu
# todevice = gpu

x = [
    1 2 3 4 5
    6 7 8 9 10
    0.1 0.2 0.3 0.4 0.5
] |> todevice

s = sort(rand(10, 3) .* 10, dims=1) |> todevice

# %%
function get_indices_1(s, x)
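    # Zygote.Buffer works like `similar`: a mutable array (same shape as x,
    # integer eltype) that can be written to inside code being differentiated.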
    buffer = Zygote.Buffer(x, typeof(firstindex(x)))

    for var_index in axes(x, 1)
        buffer[var_index, :] = searchsortedlast.(Ref(view(s, :, var_index)), view(x, var_index, :))
    end

    copy(buffer)
end

@btime get_indices_1(s, x)
@btime gradient((s, x) -> sum(get_indices_1(s, x)), s, x)

# %%
function get_indices_2(s, x)
    buffer = Zygote.Buffer(x, typeof(firstindex(x)))

    for var_index in axes(x, 1)
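        # The closure captures the current column of s, so the broadcast only maps over the row of x.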
        reducer(x) = searchsortedlast(view(s, :, var_index), x)
        buffer[var_index, :] = reducer.(view(x, var_index, :))
    end

    copy(buffer)
end

@btime get_indices_2(s, x)
@btime gradient((s, x) -> sum(get_indices_2(s, x)), s, x)

# %%
function get_indices_3(s, x)
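    # Broadcast: column i of s is searched for each entry in row i of x.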
    searchsortedlast.(eachcol(s), x)
end

@btime get_indices_3(s, x)
@btime gradient((s, x) -> sum(get_indices_3(s, x)), s, x)

Versions: Julia 1.7.1
  [052768ef] CUDA v3.8.0
  [587475ba] Flux v0.12.9
  [e88e6eb3] Zygote v0.6.34

drrmmng · Feb 06 '22

You can fix the error with ChainRulesCore.@non_differentiable searchsortedlast(x, y), which ideally should be added here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl#L380. The immediate error comes from the stack-trace frame [5] (::typeof(∂(>>>)))(Δ::Int64), and bit-shift functions like >>> should probably also be non-differentiable.
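For example, a minimal local sketch of that rule (untested here, covering just the two-argument method) would be:

using ChainRulesCore

# Sketch: declare searchsortedlast non-differentiable, so Zygote treats its
# integer output as constant with respect to both arguments.
@non_differentiable searchsortedlast(x, y)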

Since your function returns only indices, not continuous values, it could also be marked non-differentiable, which would be even faster.

mcabbott · Feb 11 '22

Thank you!

I used the following code, and on the CPU it does work without errors. Unfortunately, if I try to run it on the GPU, the whole Julia session still crashes.

And, strangely enough but possibly unrelated: my whole model (with a workaround so it doesn't crash) reliably converges on the CPU and diverges on the GPU.

using BenchmarkTools
using Flux
using Zygote
using ChainRulesCore

# todevice = cpu
todevice = gpu

x = [
    1 2 3 4 5
    6 7 8 9 10
    0.1 0.2 0.3 0.4 0.5
] |> todevice

s = sort(rand(10, 3) .* 10, dims=1) |> todevice


function get_indices_3(s, x)
    searchsortedlast.(eachcol(s), x)
end

# Mark the whole index-returning function as non-differentiable; its rrule then
# returns NoTangent() for both arguments.
@non_differentiable get_indices_3(s, x)

@btime get_indices_3(s, x)
@btime gradient((s, x) -> sum(get_indices_3(s, x)), s, x)

drrmmng · Feb 11 '22