NeuralPDE.jl

GPU Compatibility Issue: Compilation Error with Complex-Valued Data in LuxCUDA Broadcasting Kernel

Status: Open • RomanSahakyan03 opened this issue 3 months ago • 8 comments

Bug Description


When solving an ODE problem with NeuralPDE's NNODE solver on a GPU, using a Lux network with complex-valued (ComplexF64) parameters and the LuxCUDA backend, GPU compilation of a broadcasting kernel fails.

Steps to Reproduce

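  • Define a Lux Chain whose weights are initialized as ComplexF64 (via kaiming_normal), convert the parameters to a ComponentArray, and move them to the GPU with gpu_device().
  • Construct an NNODE solver with the Adam(0.01) optimizer and StochasticTraining(300, 30000), and enable complex support with SciMLBase.allowscomplex.
  • Call solve on the ODE problem.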

Expected Behavior

The optimization problem should be solved without errors, using the GPU acceleration provided by the LuxCUDA package.

Observed Behavior

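GPU compilation of the GPUArrays broadcast kernel (broadcast_kernel) fails with "KernelError: passing and using non-bitstype argument". The offending argument holds a CPU Vector{ComplexF64} that is broadcast together with a GPU CuDeviceVector{ComplexF64}.

Code Snippet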

using Lux, LuxCUDA, ComponentArrays, Random
using NeuralPDE               # NNODE, StochasticTraining
using OptimizationOptimisers  # Adam

# Define neural network architecture
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

inner = 16
chain = Chain(Dense(1, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)),
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, 9; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)))
ps = Lux.setup(rng, chain)[1]
# Flatten into a ComponentArray, move it to the GPU, and convert every entry to ComplexF64
ps = ps |> ComponentArray |> gpud .|> ComplexF64
# ps is now a 729-element ComponentVector{ComplexF64, CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}}
# laid out as:
#   layer_1 = 1:32    (weight 16x1,  bias 16x1)
#   layer_2 = 33:304  (weight 16x16, bias 16x1)
#   layer_3 = 305:576 (weight 16x16, bias 16x1)
#   layer_4 = 577:729 (weight 9x16,  bias 9x1)
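opt = Adam(0.01)
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(300, 30000))
SciMLBase.allowscomplex(::NNODE) = true  # allow NNODE to accept complex-valued problems

# Attempt to solve the problem (`problem` is the complex-valued ODEProblem, not shown in this report)
sol = solve(problem, alg, verbose = true, maxiters = 1000, saveat = 0.001)

This fails with: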
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{ComplexF64} which is not isbits.
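Read from the inside out, the type of argument 4 describes the fused expression +(u0, *(-(t, t0), nn_out)): a host Vector{ComplexF64} added to a Float64 difference times a device CuDeviceVector{ComplexF64}. This resembles NNODE's trial solution u0 .+ (t - t0) .* NN(t), which would mean the ODE's initial condition is still a CPU array while the network output lives on the GPU; that interpretation is an assumption, not something confirmed in this report. The CPU-only sketch below rebuilds a lazy broadcast tree of the same shape, using hypothetical stand-ins u0, nn_out, t, and t0, and shows that the tree stops being isbits as soon as it stores a Vector.

```julia
using Base.Broadcast: broadcasted, materialize

# Hypothetical stand-ins (not NeuralPDE names): u0 plays the host Vector{ComplexF64}
# from the error, nn_out plays the network output (a CuDeviceVector on the GPU).
u0     = ComplexF64[1 + 1im, 2 - 1im]
nn_out = ComplexF64[0.5 + 0im, 0 + 0.5im]
t, t0  = 0.3, 0.1

# Lazy tree with the same shape as argument 4: +(u0, *(-(t, t0), nn_out))
bc = broadcasted(+, u0, broadcasted(*, broadcasted(-, t, t0), nn_out))

# CUDA kernels only accept isbits arguments.
println(isbitstype(ComplexF64))          # true: a complex scalar is plain bits
println(isbitstype(Vector{ComplexF64}))  # false: a Vector is a heap-allocated mutable object
println(isbits(bc))                      # false: bc stores the Vector u0 (and here nn_out too)

# On the CPU no GPU kernel is compiled, so the same tree evaluates fine.
result = materialize(bc)
println(result)
```

Under the assumption above, one thing to try is making the initial condition a device array as well (for example by passing u0 |> gpud when constructing the ODEProblem), so that every array in the fused broadcast is a CuArray; whether the rest of NNODE's training loop then stays on the GPU is not verified here.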

Additional Information

  • Environment: Julia 1.10.2, Lux v0.5.19, LuxCUDA v0.3.2, ComponentArrays v0.15.10
  • The error points to argument 4 of the broadcast kernel, a lazy Broadcasted expression that still holds a host Vector{ComplexF64}. A host array is not isbits, so it cannot be passed to a CUDA kernel.
  • As a result, an NNODE model with complex-valued parameters cannot be trained on the GPU with this setup.
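When the type dump is long, it can help to list which array types a lazy broadcast actually stores. The sketch below uses a hypothetical helper, leaf_array_types (not part of Base, Lux, or NeuralPDE), that walks a Broadcasted tree through Base.Broadcast internals (the args field and the Extruded wrapper seen in the error), which may change between Julia versions. On the GPU, a host Vector{ComplexF64} appearing next to CuArray or CuDeviceVector leaves would point to the same device mismatch as the error above.

```julia
using Base.Broadcast: Broadcasted, Extruded, broadcasted

# Hypothetical debugging helper: collect the concrete type of every array
# stored in a lazy broadcast tree (relies on Base.Broadcast internals).
leaf_array_types(x::AbstractArray) = Any[typeof(x)]
# Kernel-ready trees wrap array leaves in Extruded, as in the error message.
leaf_array_types(x::Extruded) = leaf_array_types(x.x)
# Recurse into nested Broadcasted nodes and concatenate the results.
leaf_array_types(bc::Broadcasted) =
    reduce(vcat, map(leaf_array_types, bc.args); init = Any[])
# Scalars and other non-array leaves contribute nothing.
leaf_array_types(_) = Any[]

u0     = ComplexF64[1 + 1im, 2 - 1im]
nn_out = ComplexF64[0.5 + 0im, 0 + 0.5im]
bc = broadcasted(+, u0, broadcasted(*, broadcasted(-, 0.3, 0.1), nn_out))

println(leaf_array_types(bc))
```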

RomanSahakyan03 • Apr 03 '24 18:04