
Many basic CUDA functions do not work with Zygote

clintonTE opened this issue 4 years ago • 23 comments

MWE (not specific to the ones function, just using it as an example):

using Zygote, CUDA
function usecuda(x::CuVector)
  some1s = CUDA.ones(length(x))
  return CUDA.sum(some1s .+ x)
end

function nocuda(x::Vector)
  some1s = ones(length(x))
  return sum(some1s .+ x)
end

function mwe()
  x = rand(5)
  dx = x |> cu

  #verify functions work
  @assert nocuda(x) + usecuda(dx) > 0

  ∇nocuda(x) = gradient((x)->nocuda(x), x)
  ∇usecuda(xcu) = gradient((xcu)->usecuda(xcu), xcu)

  @info "deriv cpu: $(∇nocuda(x))"
  @info "deriv gpu: $(∇usecuda(dx))"
end

mwe()

Output:

[ Info: deriv cpu: ([1.0, 1.0, 1.0, 1.0, 1.0],)
┌ Error: Exception while generating log record in module Main at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1322
│   exception =
│    this intrinsic must be compiled to be called
│    Stacktrace:
│     [1] macro expansion at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0 [inlined]
│     [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:13    
│     [3] getindex at .\atomics.jl:347 [inlined]
│     [4] _pullback(::Zygote.Context, ::typeof(getindex), ::Base.Threads.Atomic{Int64}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [5] macro expansion at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\utils\call.jl:37 [inlined]
│     [6] macro expansion at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\cudadrv\libcuda.jl:669 [inlined]
│     [7] macro expansion at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\cudadrv\error.jl:108 [inlined]
│     [8] cuMemsetD32_v2 at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\utils\call.jl:93 [inlined]
│     [9] _pullback(::Zygote.Context, ::typeof(CUDA.cuMemsetD32_v2), ::CuPtr{UInt32}, ::UInt32, ::Int64) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [10] #set!#5 at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\cudadrv\memory.jl:372 [inlined]
│     [11] set! at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\lib\cudadrv\memory.jl:365 [inlined]
│     [12] _pullback(::Zygote.Context, ::typeof(CUDA.Mem.set!), ::CuPtr{UInt32}, ::UInt32, ::Int64) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [13] fill! at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\src\array.jl:364 [inlined]
│     [14] _pullback(::Zygote.Context, ::typeof(fill!), ::CuArray{Float32,1,Nothing}, ::Int64) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [15] ones at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\src\array.jl:350 [inlined]
│     [16] _pullback(::Zygote.Context, ::typeof(CUDA.ones), ::Type{Float32}, ::Int64) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [17] adjoint at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\lib\lib.jl:175 [inlined]
│     [18] _pullback at C:\Users\Clinton\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
│     [19] ones at C:\Users\Clinton\.julia\packages\CUDA\sjcZt\src\array.jl:352 [inlined]
│     [20] _pullback(::Zygote.Context, ::typeof(CUDA.ones), ::Int64) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [21] usecuda at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1302 [inlined]
│     [22] _pullback(::Zygote.Context, ::typeof(usecuda), ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [23] #16 at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1319 [inlined]
│     [24] _pullback(::Zygote.Context, ::var"#16#20", ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface2.jl:0
│     [25] _pullback(::Function, ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface.jl:38
│     [26] pullback(::Function, ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface.jl:44
│     [27] gradient(::Function, ::CuArray{Float32,1,Nothing}) at C:\Users\Clinton\.julia\packages\Zygote\EjVY4\src\compiler\interface.jl:53
│     [28] ∇usecuda at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1319 [inlined]
│     [29] macro expansion at .\logging.jl:322 [inlined]
│     [30] mwe() at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1322
│     [31] top-level scope at C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1325
│     [32] include_string(::Module, ::String, ::String) at .\loading.jl:1080
│     [33] (::Atom.var"#220#225"{String,String})() at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:174
│     [34] withpath(::Atom.var"#220#225"{String,String}, ::String) at C:\Users\Clinton\.julia\packages\CodeTools\VsjEq\src\utils.jl:30
│     [35] withpath(::Function, ::String) at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:9
│     [36] #219 at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:171 [inlined]
│     [37] with_logstate(::Atom.var"#219#224"{String,String}, ::Base.CoreLogging.LogState) at .\logging.jl:398
│     [38] with_logger at .\logging.jl:505 [inlined]
│     [39] #218 at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:170 [inlined]
│     [40] hideprompt(::Atom.var"#218#223"{String,String}) at C:\Users\Clinton\.julia\packages\Atom\isnka\src\repl.jl:127
│     [41] macro expansion at C:\Users\Clinton\.julia\packages\Media\ItEPc\src\dynamic.jl:24 [inlined]
│     [42] evalall(::String, ::Nothing, ::String) at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:160
│     [43] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N; kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at .\essentials.jl:712
│     [44] invokelatest(::Any, ::Any, ::Vararg{Any,N} where N) at .\essentials.jl:711
│     [45] macro expansion at C:\Users\Clinton\.julia\packages\Atom\isnka\src\eval.jl:41 [inlined]
│     [46] (::Atom.var"#188#189")() at .\task.jl:358
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1322

Edit: Just a note that the reduction operation (the sum) does seem to work, but other functions (I tried zeros, ones, and cumprod) all break Zygote.

clintonTE, Jul 09 '20

@clintonTE to turn this issue into something actionable, can you list all the CUDA functions (together with a corresponding MRE) that, to your knowledge, do not seem to work with Zygote.jl?

AzamatB, Jul 10 '20

Sure, hope this helps:

  • [ ] zeros
  • [ ] ones
  • [ ] fill
  • [ ] cumprod
  • [ ] device2host
  • [ ] reduce (two examples)
using Zygote, CUDA

function mwe(N=5)
  x = CUDA.rand(N)
  M = CUDA.rand(N, N^2)

  brokenfunctions = (
    shouldwork=()->CUDA.sum(M),
    ones=()->CUDA.ones(N),
    zeros=()->CUDA.zeros(N),
    fill=()->CUDA.fill(2f0, N),
    cumprod=()->CUDA.cumprod(M, dims=2),
    device2host=()->CUDA.rand(N, N) |> Matrix |> sum,
    reduce=()->CUDA.reduce(+, M),
    reduce2=()->CUDA.reduce(hcat, [M, M])
  )

  testcuda(f, x) = CUDA.sum(f() .+ x)
  for n ∈ propertynames(brokenfunctions)
    print("testing $n: ")
    f = brokenfunctions[n]
    #check the function works before differentiating
    @assert testcuda(f, x) > 0
    try
      ∇testcuda(x) = gradient((x)->testcuda(f, x), x)
      grad = ∇testcuda(x)
      @info "passed with val $grad"
    catch err
      @warn " function $n is broken ($err)"
    end
  end
end

mwe()

Output:

testing shouldwork: [ Info: passed with val (Float32[1.0, 1.0, 1.0, 1.0, 1.0],)
testing ones: ┌ Warning:  function ones is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing zeros: ┌ Warning:  function zeros is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing fill: ┌ Warning:  function fill is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing cumprod: ┌ Warning:  function cumprod is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing device2host: ┌ Warning:  function device2host is broken (ErrorException("Mutating arrays is not supported"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing reduce: ┌ Warning:  function reduce is broken (ErrorException("type CuArray has no field f"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334
testing reduce2: ┌ Warning:  function reduce2 is broken (ErrorException("Mutating arrays is not supported"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:1334

clintonTE, Jul 10 '20

For reduce(hcat, ...), xref #501. Does CUDA.reduce differ from reduce here?

For cumprod, xref #282

mcabbott, Jul 10 '20

Updated checklist:

  • [ ] zeros
  • [ ] ones
  • [ ] fill
  • [ ] cumprod
  • [ ] device2host
  • [ ] reduce +
  • [X] reduce hcat

Current output:

[ Info: passed with val (Float32[1.0, 1.0, 1.0, 1.0, 1.0],)
testing ones: ┌ Warning:  function ones is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing zeros: ┌ Warning:  function zeros is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing fill: ┌ Warning:  function fill is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing cumprod: ┌ Warning:  function cumprod is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing device2host: ┌ Warning:  function device2host is broken (ErrorException("Mutating arrays is not supported"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing reduce: ┌ Warning:  function reduce is broken (ErrorException("this intrinsic must be compiled to be called"))
└ @ Main C:\Users\Clinton\Dropbox\Projects\Capacity\src\Experimentation\mwe.jl:210
testing reduce2: [ Info: passed with val (Float32[50.0, 50.0, 50.0, 50.0, 50.0],)

clintonTE, Oct 12 '20

CUDA.rand_logn is another function that Zygote is not happy with. However, exp.(CUDA.randn(...)) is fine.

using Zygote, CUDA

function mweNotWorking(x)
    return sum(x * CUDA.rand_logn(2, mean=0.0, stddev=1.0))
end

function mweWorking(x)
    return sum(x * exp.(CUDA.randn(2, mean=0.0, stddev=1.0)))
end

mweNotWorking'(1.0)
mweWorking'(1.0)

mweNotWorking'(1.0) gives:

ERROR: LoadError: this intrinsic must be compiled to be called

sheevy, Dec 21 '20

mweWorking'(1.0) is not working on recent CUDA and Zygote versions either, same error!

ngphuoc, Feb 03 '21

Another example on a fresh installation (CUDA v3.12.0, Zygote v0.6.43):

using CUDA
using Zygote

function increment!(x)
    i = threadIdx().x
    x[i] += 1
    return nothing
end

function call_kernel(x)
    @cuda threads=4 increment!(x)
    return sum(x)
end

x = CuArray([1., 2., 3., 4.])

gradient(x -> call_kernel(x), x)

Error:

this intrinsic must be compiled to be called

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context, ::Core.IntrinsicFunction, ::String, ::Type{Int64}, ::Type{Tuple{Ptr{Int64}}}, ::Ptr{Int64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:9
  [3] _pullback
    @ ./atomics.jl:358 [inlined]
  [4] _pullback(ctx::Zygote.Context, f::typeof(getindex), args::Base.Threads.Atomic{Int64})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
  [5] _pullback (repeats 2 times)
    @ ~/.julia/packages/CUDA/DfvRa/lib/utils/threading.jl:25 [inlined]
  [6] _pullback
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/gpucompiler.jl:7 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(CUDA.device_properties), args::CuDevice)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/gpucompiler.jl:51 [inlined]
  [9] _pullback(::Zygote.Context, ::CUDA.var"##CUDACompilerTarget#206", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(CUDA.CUDACompilerTarget), ::CuDevice)
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [10] _pullback
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/gpucompiler.jl:51 [inlined]
 [11] _pullback
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:296 [inlined]
 [12] _pullback(::Zygote.Context, ::CUDA.var"##cufunction#221", ::Nothing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(cufunction), ::typeof(increment!), ::Type{Tuple{CuDeviceVector{Float64, 1}}})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:293 [inlined]
 [14] _pullback(::Zygote.Context, ::typeof(cufunction), ::typeof(increment!), ::Type{Tuple{CuDeviceVector{Float64, 1}}})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [15] macro expansion
    @ ~/.julia/packages/CUDA/DfvRa/src/compiler/execution.jl:102 [inlined]
 [16] _pullback
    @ ./In[1]:11 [inlined]
 [17] _pullback(ctx::Zygote.Context, f::typeof(call_kernel), args::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [18] _pullback
    @ ./In[1]:17 [inlined]
 [19] _pullback(ctx::Zygote.Context, f::var"#1#2", args::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [20] _pullback(f::Function, args::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:34
 [21] pullback(f::Function, args::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:40
 [22] gradient(f::Function, args::CuArray{Float64, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:75
 [23] top-level scope
    @ In[1]:17
 [24] eval
    @ ./boot.jl:373 [inlined]
 [25] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

renatobellotti, Aug 03 '22

Running CUDA kernels directly is not expected to work, but regular array operations would, i.e. replace CUDA.zeros with gpu(zeros(...)) and CUDA.randn with cu(randn(...)). A kernel launched with @cuda requires an rrule.
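
For illustration, a minimal sketch of that substitution (usecuda2 is a made-up name; cu transfers a host array to the device, and Zygote has an adjoint for that conversion):

using Zygote, CUDA

# cu(ones(...)) instead of CUDA.ones(...): the ones array is built on the
# host, where Zygote can handle it, then moved to the GPU.
usecuda2(x) = sum(cu(ones(Float32, length(x))) .+ x)

gradient(usecuda2, cu(rand(Float32, 5)))  # should give (Float32[1.0, 1.0, 1.0, 1.0, 1.0],)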

DhairyaLGandhi, Aug 03 '22

Thanks for your answer. I'm fairly new to Julia and just exploring whether I could use it for a project.

a kernel launched with @cuda requires an rrule

What does this mean?

renatobellotti, Aug 03 '22

When you take the gradient of a function call, Zygote will recursively enter each function in the call stack until it encounters either a primitive function that it knows how to differentiate or an expression it cannot recurse into (like the internals of a CUDA kernel). Every primitive function has a "rule" that tells Zygote how to differentiate it, and Zygote puts these rules together via the chain rule to differentiate the full call stack. An rrule is one of these rules for a reverse-mode AD like Zygote.

Defining an rrule for your kernel means that you will provide Zygote with the gradient definition for your kernel so that it does not try to recurse into it and throw the error you posted above. You can learn more about how to define an rrule in the ChainRulesCore.jl docs.
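
To make this concrete, below is a hypothetical sketch of an rrule for the increment! kernel posted above (call_kernel2 is a made-up wrapper name; since the kernel computes sum(x .+ 1), its gradient with respect to x is all ones):

using CUDA, ChainRulesCore, Zygote

function increment!(x)  # the kernel from the earlier comment
    i = threadIdx().x
    x[i] += 1
    return nothing
end

function call_kernel2(x)
    y = copy(x)  # mutate a copy so the array Zygote sees stays untouched
    @cuda threads=length(y) increment!(y)
    return sum(y)
end

# The wrapper computes sum(x .+ 1), whose gradient is one per element, so the
# pullback scales an all-ones device array by the incoming cotangent ȳ.
function ChainRulesCore.rrule(::typeof(call_kernel2), x::CuArray)
    y = call_kernel2(x)
    call_kernel2_pullback(ȳ) = (NoTangent(), ȳ .* CUDA.ones(eltype(x), size(x)...))
    return y, call_kernel2_pullback
end

gradient(call_kernel2, CuArray([1.0, 2.0, 3.0, 4.0]))  # should give ([1.0, 1.0, 1.0, 1.0],)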

darsnack, Aug 03 '22

I see, thank you very much for the clarifications. It seems to me like Zygote is not what I need for my project, then.

renatobellotti, Aug 04 '22

Just in case it wasn't clear before, you can use Zygote with CUDA without manually defining rrules. But you need to use the array programming interface instead of the kernel programming interface.
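
For instance, the increment! kernel from the earlier example has a one-line array-interface equivalent that Zygote differentiates out of the box (call_broadcast is a made-up name):

using Zygote, CUDA

call_broadcast(x) = sum(x .+ 1)  # broadcasting replaces the hand-written kernel

gradient(call_broadcast, CuArray([1.0, 2.0, 3.0, 4.0]))  # ([1.0, 1.0, 1.0, 1.0],)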

darsnack, Aug 04 '22

Thank you for highlighting this. Unfortunately, I need to apply different operations to different elements of the array. Therefore I think I really need custom kernels.

renatobellotti, Aug 08 '22

You may want to describe the kinds of kernels you are working with then, because the array interface is not limited to what you describe. It may be that an alternative AD like ForwardDiff or Enzyme is a better fit for your kernel, so unless you need features Zygote provides it may be fruitful to look at those.

ToucheSir, Aug 08 '22

Interesting, perhaps I have some misconceptions about the array interface. Thanks for pointing this out.

One of the things I'm interested in is Siddon's algorithm. Simplifying, the algorithm gives you the path travelled by a ray entering a voxel grid until a target voxel is reached. For that, you basically do the following steps:

  • Calculate intersection points of the ray with the grid points.
  • Sort the intersection points.
  • Sum up the path lengths between the points until you reach the target voxel.

All the bullet points should be run for many target voxels in parallel.

I see how I could implement the bullet points in a parallel way, but parallelising over the target voxels seems tricky to me without writing custom kernels. I'm more than happy to hear your thoughts about doing this with the array interface. :)

renatobellotti, Aug 09 '22

Interesting algorithm! I agree this does not seem possible (or at least not easy) to do with the array interface. Enzyme seems like the best option for differentiating GPU kernels these days, just be prepared for it to be a little rough around the edges. If you're able to express the algorithm as a Tullio expression like https://discourse.julialang.org/t/using-zygote-with-custom-cuda-kernels/65688/3, that's another possible avenue to explore.

ToucheSir, Aug 09 '22

Thanks for the hint, I'll keep Tullio in mind. However, I don't think it's flexible enough for my needs, especially because Siddon's algorithm is only one step in my project.

renatobellotti, Aug 11 '22

I think I found another example of non-functioning CUDA calls:

using Zygote, CUDA

function my_loss(v)
    # This works:
    # l = sum(v)
    # This does not work:
    l = reduce(+, v)
    return l
end

v = cu([1., 2.])
Zygote.gradient(my_loss, v)

Calling the reduction outside of Zygote works, but does not return a CUDA array. I guess this makes sense and is not a problem because we just need to copy a single float.

Does somebody know how to execute the reduction correctly?

renatobellotti, Aug 19 '22

Looking at the call chain, this would require a rule for either reduce(::typeof(+),::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})::Float32 or mapreduce(::typeof(identity),::typeof(+),::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})::Float32. ChainRules is probably the best place for that.
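
For reference, a hypothetical sketch of such a rule, specialized to + only (a real ChainRules rule would need to be more general):

using CUDA, ChainRulesCore

# reduce(+, v) is just sum(v), whose gradient is one per element, so the
# pullback broadcasts the incoming cotangent ȳ over an all-ones device array.
function ChainRulesCore.rrule(::typeof(reduce), ::typeof(+), v::CuArray)
    y = reduce(+, v)
    reduce_pullback(ȳ) = (NoTangent(), NoTangent(), ȳ .* CUDA.ones(eltype(v), size(v)...))
    return y, reduce_pullback
end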

ToucheSir, Aug 19 '22

Thanks for the hint, I've opened an issue in the ChainRules repo.

renatobellotti, Aug 20 '22

I also encountered the problem that CUDA.ones() can't be differentiated. What's the current workaround?

jackisdesigning, Jan 11 '23

A workaround, which gives those operations zero gradients, is to add the following to your code:

using ChainRulesCore
@non_differentiable CUDA.ones(::Any...)
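
For example, applied to the MWE from the top of the thread (a sketch; the declaration just tells AD that CUDA.ones is a constant with respect to its arguments):

using Zygote, CUDA, ChainRulesCore

@non_differentiable CUDA.ones(::Any...)

usecuda(x) = sum(CUDA.ones(length(x)) .+ x)

gradient(usecuda, cu(rand(5)))  # should give (Float32[1.0, 1.0, 1.0, 1.0, 1.0],)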

CarloLucibello, Jan 11 '23

Documentation for the above: https://juliadiff.org/ChainRulesCore.jl/stable/api.html#Ignoring-gradients. I'd go so far as to say this is the "official" recommendation for methods which aren't marked non-diff in libraries Zygote can't directly depend on.
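
The same docs also describe a call-site form, which avoids declaring a method-wide rule (a sketch, reusing the usecuda example; the derivative of the wrapped call is simply dropped):

using Zygote, CUDA, ChainRulesCore

usecuda_ignored(x) = sum(@ignore_derivatives(CUDA.ones(length(x))) .+ x)

gradient(usecuda_ignored, cu(rand(5)))  # should give (Float32[1.0, 1.0, 1.0, 1.0, 1.0],)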

ToucheSir, Jan 11 '23