AbstractFFTs.jl
Differentiating `rfft` on a `CuArray` leads to an error
Redirected from https://github.com/FluxML/Zygote.jl/issues/1406
Taking the `gradient` (based on the `rrule`) of `rfft` (the real-input Fast Fourier Transform from FFTW.jl) on a `CuArray` leads to an error. On the other hand, `fft` on a `CuArray`, or `rfft` on a CPU array, both run fine.
As pointed out by @ToucheSir, this comes from the fact that the AbstractFFTs `rrule` for `rfft` unconditionally creates a CPU array and uses it in https://github.com/JuliaMath/AbstractFFTs.jl/blob/v1.3.1/ext/AbstractFFTsChainRulesCoreExt.jl#L33-L40.
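For reference, the pattern at those lines looks roughly like the sketch below (a lightly annotated paraphrase of the v1.3.1 source, not a verbatim copy):

using AbstractFFTs, ChainRulesCore

function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
    y = rfft(x, dims)
    halfdim = first(dims)
    d = size(x, halfdim)
    n = size(y, halfdim)
    # The comprehension always produces a plain CPU Vector{Int}, no matter
    # what array type `x` is:
    scale = reshape(
        [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
        ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))),
    )
    project_x = ChainRulesCore.ProjectTo(x)
    function rfft_pullback(ȳ)
        # With ȳ::CuArray this broadcast mixes a GPU array with the CPU
        # `scale`, handing the kernel a non-isbits Vector{Int} argument.
        x̄ = project_x(brfft(ChainRulesCore.unthunk(ȳ) ./ scale, d, dims))
        return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
    end
    return y, rfft_pullback
end

The `Vector{Int64}` field of the `rfft_pullback#6` closure in the stack trace below matches this `scale`.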
(@v1.8) pkg> activate --temp
Activating new project at `/tmp/jl_yR7NT1`
(jl_yR7NT1) pkg> add CUDA, Flux, FFTW
Updating registry at `~/.julia/registries/General.toml`
Resolving package versions...
Installed Optimisers ─ v0.2.17
Installed CUDA ─────── v4.1.1
Updating `/tmp/jl_yR7NT1/Project.toml`
[052768ef] + CUDA v4.1.1
[7a1cc6ca] + FFTW v1.6.0
[587475ba] + Flux v0.13.14
Updating `/tmp/jl_yR7NT1/Manifest.toml`
[621f4979] + AbstractFFTs v1.3.1
[7d9f7c33] + Accessors v0.1.28
[79e6a3ab] + Adapt v3.6.1
[dce04be8] + ArgCheck v2.3.0
[a9b6321e] + Atomix v0.1.0
[ab4f0b2a] + BFloat16s v0.4.2
[198e06fe] + BangBang v0.3.37
[9718e550] + Baselet v0.1.1
[fa961155] + CEnum v0.4.2
[052768ef] + CUDA v4.1.1
[1af6417a] + CUDA_Runtime_Discovery v0.1.1
[082447d4] + ChainRules v1.48.0
[d360d2e6] + ChainRulesCore v1.15.7
[9e997f8a] + ChangesOfVariables v0.1.6
[bbf7d656] + CommonSubexpressions v0.3.0
[34da2185] + Compat v4.6.1
[a33af91c] + CompositionsBase v0.1.1
[187b0558] + ConstructionBase v1.5.1
[6add18c4] + ContextVariablesX v0.1.3
[9a962f9c] + DataAPI v1.14.0
[864edb3b] + DataStructures v0.18.13
[e2d170a0] + DataValueInterfaces v1.0.0
[244e2a9f] + DefineSingletons v0.1.2
[163ba53b] + DiffResults v1.1.0
[b552c78f] + DiffRules v1.13.0
[ffbed154] + DocStringExtensions v0.9.3
[e2ba6199] + ExprTools v0.1.9
[7a1cc6ca] + FFTW v1.6.0
[cc61a311] + FLoops v0.2.1
[b9860ae5] + FLoopsBase v0.1.1
[1a297f60] + FillArrays v0.13.10
[587475ba] + Flux v0.13.14
[9c68100b] + FoldsThreads v0.1.1
[f6369f11] + ForwardDiff v0.10.35
[069b7b12] + FunctionWrappers v1.1.3
[d9f16b24] + Functors v0.4.3
[0c68f7d7] + GPUArrays v8.6.5
[46192b85] + GPUArraysCore v0.1.4
[61eb1bfa] + GPUCompiler v0.18.0
[7869d1d1] + IRTools v0.4.9
[22cec73e] + InitialValues v0.3.1
[3587e190] + InverseFunctions v0.1.8
[92d709cd] + IrrationalConstants v0.2.2
[82899510] + IteratorInterfaceExtensions v1.0.0
[692b3bcd] + JLLWrappers v1.4.1
[b14d175d] + JuliaVariables v0.2.4
[63c18a36] + KernelAbstractions v0.9.1
[929cbde3] + LLVM v4.17.1
[2ab3a3ac] + LogExpFunctions v0.3.23
[d8e11817] + MLStyle v0.4.17
[f1d291b0] + MLUtils v0.4.1
[1914dd2f] + MacroTools v0.5.10
[128add7d] + MicroCollections v0.1.4
[e1d29d7a] + Missings v1.1.0
[872c559c] + NNlib v0.8.19
[a00861dc] + NNlibCUDA v0.2.7
[77ba4419] + NaNMath v1.0.2
[71a1bf82] + NameResolution v0.1.5
[0b1bfda6] + OneHotArrays v0.2.3
[3bd65402] + Optimisers v0.2.17
[bac558e1] + OrderedCollections v1.4.1
[21216c6a] + Preferences v1.3.0
[8162dcfd] + PrettyPrint v0.2.0
[33c8b6b6] + ProgressLogging v0.1.4
[74087812] + Random123 v1.6.0
[e6cf234a] + RandomNumbers v1.5.3
[c1ae055f] + RealDot v0.1.0
[189a3867] + Reexport v1.2.2
[ae029012] + Requires v1.3.0
[efcf1570] + Setfield v1.1.1
[605ecd9f] + ShowCases v0.1.0
[699a6c99] + SimpleTraits v0.9.4
[66db9d55] + SnoopPrecompile v1.0.3
[a2af1166] + SortingAlgorithms v1.1.0
[276daf66] + SpecialFunctions v2.2.0
[171d559e] + SplittablesBase v0.1.15
[90137ffa] + StaticArrays v1.5.19
[1e83bf80] + StaticArraysCore v1.4.0
[82ae8749] + StatsAPI v1.5.0
[2913bbd2] + StatsBase v0.33.21
[09ab397b] + StructArrays v0.6.15
[3783bdb8] + TableTraits v1.0.1
[bd369af6] + Tables v1.10.1
[a759f4b9] + TimerOutputs v0.5.22
[28d57a85] + Transducers v0.4.75
[013be700] + UnsafeAtomics v0.2.1
[d80eeb9a] + UnsafeAtomicsLLVM v0.1.0
[e88e6eb3] + Zygote v0.6.59
[700de1a5] + ZygoteRules v0.2.3
[02a925ec] + cuDNN v1.0.2
⌅ [4ee394cb] + CUDA_Driver_jll v0.4.0+2
[76a88914] + CUDA_Runtime_jll v0.4.0+2
[62b44479] + CUDNN_jll v8.8.1+0
[f5851436] + FFTW_jll v3.3.10+0
[1d5cc7b8] + IntelOpenMP_jll v2018.0.3+2
⌅ [dad2f222] + LLVMExtra_jll v0.0.18+0
[856f044c] + MKL_jll v2022.2.0+0
[efe28fd5] + OpenSpecFun_jll v0.5.5+0
[0dad84c5] + ArgTools v1.1.1
[56f22d72] + Artifacts
[2a0f44e3] + Base64
[ade2ca70] + Dates
[8bb1440f] + DelimitedFiles
[8ba89e20] + Distributed
[f43a241f] + Downloads v1.6.0
[7b1f6079] + FileWatching
[9fa8497b] + Future
[b77e0a4c] + InteractiveUtils
[4af54fe1] + LazyArtifacts
[b27032c2] + LibCURL v0.6.3
[76f85450] + LibGit2
[8f399da3] + Libdl
[37e2e46d] + LinearAlgebra
[56ddb016] + Logging
[d6f4376e] + Markdown
[a63ad114] + Mmap
[ca575930] + NetworkOptions v1.2.0
[44cfe95a] + Pkg v1.8.0
[de0858da] + Printf
[3fa0cd96] + REPL
[9a3f8284] + Random
[ea8e919c] + SHA v0.7.0
[9e88b42a] + Serialization
[6462fe0b] + Sockets
[2f01184e] + SparseArrays
[10745b16] + Statistics
[fa267f1f] + TOML v1.0.0
[a4e569a6] + Tar v1.10.1
[8dfed614] + Test
[cf7118a7] + UUIDs
[4ec0a83e] + Unicode
[e66e0078] + CompilerSupportLibraries_jll v1.0.1+0
[deac9b47] + LibCURL_jll v7.84.0+0
[29816b5a] + LibSSH2_jll v1.10.2+0
[c8ffd9c3] + MbedTLS_jll v2.28.0+0
[14a3606d] + MozillaCACerts_jll v2022.2.1
[4536629a] + OpenBLAS_jll v0.3.20+0
[05823500] + OpenLibm_jll v0.8.1+0
[83775a58] + Zlib_jll v1.2.12+3
[8e850b90] + libblastrampoline_jll v5.1.1+0
[8e850ede] + nghttp2_jll v1.48.0+0
[3f19e933] + p7zip_jll v17.4.0+0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated -m`
Precompiling project...
5 dependencies successfully precompiled in 54 seconds. 100 already precompiled.
julia> using CUDA, FFTW, Flux
julia> x = CUDA.randn(3)
3-element CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}:
0.19325934
0.55793864
0.08928435
julia> gradient(()->sum(abs.(fft(x))), Flux.params(x)) # this works
Grads(...)
julia> gradient(()->sum(abs.(rfft(x))), Flux.params(x))
ERROR: GPU compilation of broadcast_kernel(CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64) in world 32592 failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.
.2 is of type Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
.x is of type Vector{Int64} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/validation.jl:101
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:154 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/LHjFw/src/TimerOutput.jl:253 [inlined]
[4] macro expansion
@ ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:152 [inlined]
[5] emit_julia(job::GPUCompiler.CompilerJob; validate::Bool)
@ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:83
[6] emit_julia
@ ~/.julia/packages/GPUCompiler/anMCs/src/utils.jl:77 [inlined]
[7] compile(job::GPUCompiler.CompilerJob, ctx::LLVM.Context)
@ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:105
[8] #203
@ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:100 [inlined]
[9] JuliaContext(f::CUDA.var"#203#204"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
@ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/driver.jl:76
[10] compile
@ ~/.julia/packages/CUDA/N71Iw/src/compiler/compilation.jl:99 [inlined]
[11] actual_compilation(cache::Dict{UInt64, Any}, key::UInt64, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, world::UInt64, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:184
[12] cached_compilation(cache::Dict{UInt64, Any}, cfg::GPUCompiler.CompilerConfig{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}, ft::Type, tt::Type, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/anMCs/src/cache.jl:163
[13] macro expansion
@ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:310 [inlined]
[14] macro expansion
@ ./lock.jl:223 [inlined]
[15] cufunction(f::GPUArrays.var"#broadcast_kernel#28", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceVector{ComplexF32, 1}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Tuple{Base.OneTo{Int64}}, typeof(/), Tuple{Base.Broadcast.Extruded{CuDeviceVector{ComplexF32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{Vector{Int64}, Tuple{Bool}, Tuple{Int64}}}}, Int64}}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ CUDA ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:306
[16] cufunction
@ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:303 [inlined]
[17] macro expansion
@ ~/.julia/packages/CUDA/N71Iw/src/compiler/execution.jl:104 [inlined]
[18] #launch_heuristic#244
@ ~/.julia/packages/CUDA/N71Iw/src/gpuarrays.jl:17 [inlined]
[19] _copyto!
@ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:65 [inlined]
[20] copyto!
@ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:46 [inlined]
[21] copy
@ ~/.julia/packages/GPUArrays/XR4WO/src/host/broadcast.jl:37 [inlined]
[22] materialize
@ ./broadcast.jl:860 [inlined]
[23] (::AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64})(ȳ::CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer})
@ AbstractFFTs.AbstractFFTsChainRulesCoreExt ~/.julia/packages/AbstractFFTs/0uOAT/ext/AbstractFFTsChainRulesCoreExt.jl:40
[24] ZBack
@ ~/.julia/packages/Zygote/TSj5C/src/compiler/chainrules.jl:211 [inlined]
[25] Pullback
@ ~/.julia/packages/AbstractFFTs/0uOAT/src/definitions.jl:62 [inlined]
[26] Pullback
@ ./REPL[8]:1 [inlined]
[27] (::Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface2.jl:0
[28] (::Zygote.var"#118#119"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, Zygote.Pullback{Tuple{var"#5#6"}, Tuple{Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(rfft), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.Pullback{Tuple{typeof(ndims), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.ZBack{AbstractFFTs.AbstractFFTsChainRulesCoreExt.var"#rfft_pullback#6"{UnitRange{Int64}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float32, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, Vector{Int64}, Int64}}, Zygote.ZBack{ChainRules.var"#:_pullback#275"{Tuple{Int64, Int64}}}}}, Zygote.var"#4160#back#1438"{Zygote.var"#1434#1437"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(abs), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#1982#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1169"{Tuple{Nothing, Nothing}}}}, Zygote.var"#2841#back#683"{Zygote.var"#map_back#677"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}, Zygote.var"#2138#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4128#back#1421"{Zygote.var"#bc_fwd_back#1409"{1, CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 1, CUDA.Mem.DeviceBuffer}, Tuple{CuArray{ComplexF32, 1, CUDA.Mem.DeviceBuffer}}, Val{1}}}}}, Zygote.var"#1982#back#200"{typeof(identity)}}}, Zygote.var"#1955#back#190"{Zygote.var"#186#189"{Zygote.Context{true}, GlobalRef, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, Zygote.Context{true}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:389
[29] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote ~/.julia/packages/Zygote/TSj5C/src/compiler/interface.jl:97
[30] top-level scope
@ REPL[8]:1
[31] top-level scope
@ ~/.julia/packages/CUDA/N71Iw/src/initialization.jl:163
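The failure can be reproduced without any AD machinery, since it comes down to broadcasting a `CuArray` against a plain `Vector` (the pullback's `ȳ ./ scale`). A minimal sketch (my own illustration, not from the original report; moving `scale` to the device is only shown to confirm the diagnosis, not as a proposed patch):

using CUDA

ȳ = CUDA.rand(ComplexF32, 3)  # stands in for the cotangent reaching rfft_pullback
scale = [1, 2, 2]             # the rrule's scaling factors: a CPU Vector{Int}

ȳ ./ scale           # reproduces the KernelError: non-isbits kernel argument
ȳ ./ CuArray(scale)  # succeeds once scale lives on the same device as ȳ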