Projections do not play well with GPUCompiler
Here's an example that does not play well with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4
A possible, but probably incomplete, fix (discussed with @mcabbott) is to add the following specializations:
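```julia
import ChainRulesCore
import ChainRulesCore: ProjectTo, NoTangent
# Float arrays: the element projector depends only on the eltype
function ProjectTo(x::AbstractArray{T}) where {T <: AbstractFloat}
    return ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
end
# Bool arrays have no tangent
ProjectTo(x::AbstractArray{T}) where {T <: Bool} = ProjectTo{NoTangent}()
# Return dx untouched if its eltype already fits, else project each entry
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S}) where {S <: Number}
    T = ChainRulesCore.project_type(project.element)
    return S <: T ? dx : map(project.element, dx)
end
```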
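To see what the `S <: T ? dx : map(project.element, dx)` line buys, here is a self-contained toy. It uses hypothetical stand-in types (`EltProjector`, `ArrayProjector` and a local `project_type`), none of which are ChainRulesCore's. When the incoming eltype already fits, the array comes back as the very same object, so no per-element work happens at all.

```julia
# Hypothetical stand-ins for illustration only; these are NOT ChainRulesCore types.
struct EltProjector{T<:Number} end                 # projects a single number to T
(::EltProjector{T})(dx::Number) where {T} = convert(T, dx)
project_type(::EltProjector{T}) where {T} = T      # plays the role of project_type above

struct ArrayProjector{P}
    element::P                                     # projector applied to each entry
end

function (p::ArrayProjector)(dx::AbstractArray{S}) where {S<:Number}
    T = project_type(p.element)
    # Fast path: if the eltype already fits, hand back dx itself (no map, no copy);
    # otherwise project entry by entry.
    return S <: T ? dx : map(p.element, dx)
end
```

For example, `ArrayProjector(EltProjector{Float64}())` returns a `Vector{Float64}` argument unchanged, but maps over a `Vector{Int}` to convert it.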
I think we should do something like the first fix. The whole constructor could use dispatch rather than branching, with no loss; it just happened to get written that way: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L181-L188
Edit: something like this commit, maybe: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/427/commits/7e5ae8edf5661f360f793982813fbbd2d3536dfd
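A toy sketch of choosing the projector by dispatch on the element type, using hypothetical types (`FloatProj`, `NoTangentProj`, `ArrayProj`, `make_projector`) rather than the real `ProjectTo`. This is not the code in the linked commit. The point it shows: each case is its own method, selected from the argument's type alone, so no element of `x` is ever read. That matters for GPU arrays, where scalar indexing is slow or disallowed.

```julia
# Hypothetical toy projector types, NOT ChainRulesCore's ProjectTo.
struct FloatProj{T<:AbstractFloat} end     # projects numbers onto T
struct NoTangentProj end                   # arrays with no meaningful tangent
struct ArrayProj{P,A}
    element::P                             # projector for every entry
    axes::A                                # axes of the original array
end

# Each case is a separate method chosen from the argument's type;
# none of these methods reads an element of `x`.
make_projector(::T) where {T<:AbstractFloat} = FloatProj{T}()
make_projector(x::AbstractArray{T}) where {T<:AbstractFloat} =
    ArrayProj(make_projector(zero(T)), axes(x))
make_projector(::AbstractArray{Bool}) = NoTangentProj()
```

Eltypes without a method, such as `Any`, simply hit a `MethodError` in this toy; a complete version would need a fallback for them.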
The second seems trickier: it avoids this `if hasproperty(project, :element)` branch:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L216
But also:
- It avoids the `reshape` completely. Can we bypass this in more cases?
- It assumes that the projector being applied to `dx::AbstractArray{<:Number}` has `.element`, but that need not be true, e.g. for `x = Any[1,2,3]`.
It's possible that we should insist that every array projector has `.element`, a trivial one if necessary. That might help with the in-place code too.
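A minimal sketch of "every array projector has `.element`, a trivial one if necessary", with hypothetical toy types (`IdentityElt`, `ConvertElt`, `ElementArrayProj`, `make_proj`), not ChainRulesCore's. Because the field always exists, applying the projector needs no `hasproperty` check, and the trivial case can still skip the `map` by dispatch.

```julia
# Hypothetical toy types, NOT ChainRulesCore's. Every array projector stores an
# `element`; when nothing is known about the entries it is a trivial one.
struct IdentityElt end                       # trivial element projector
(::IdentityElt)(dx) = dx

struct ConvertElt{T<:Number} end             # converts each entry to T
(::ConvertElt{T})(dx::Number) where {T} = convert(T, dx)

struct ElementArrayProj{P}
    element::P
end

make_proj(::AbstractArray{T}) where {T<:AbstractFloat} = ElementArrayProj(ConvertElt{T}())
make_proj(::AbstractArray) = ElementArrayProj(IdentityElt())   # e.g. x = Any[1, 2, 3]

# Applying never needs `hasproperty(p, :element)`: the field always exists.
(p::ElementArrayProj)(dx::AbstractArray) = map(p.element, dx)
# A trivial element lets us skip the map entirely.
(::ElementArrayProj{IdentityElt})(dx::AbstractArray) = dx
```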
It's also possible that we should mark the two cases in some way that is easier to dispatch on?
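Another toy, assuming the two cases are one projector shared by all entries versus one projector per entry (as for arrays of arrays). Giving each case its own hypothetical type (`SharedElementProj`, `PerElementProj`) turns the run-time field check into ordinary dispatch.

```julia
# Hypothetical toy types, NOT ChainRulesCore's: the two cases get separate types.
struct SharedElementProj{P}
    element::P                     # one projector applied to every entry
end
struct PerElementProj{A<:AbstractArray}
    elements::A                    # one projector per entry, e.g. for arrays of arrays
end

# Applying dispatches on the projector's type; no hasproperty check is needed.
(p::SharedElementProj)(dx::AbstractArray) = map(p.element, dx)
(p::PerElementProj)(dx::AbstractArray) = map((f, d) -> f(d), p.elements, dx)
```

The cost of this design is that every method accepting an array projector now has to handle both types.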
One iteration of this design had `ProjectTo{AbstractArray{Float32}}(...)`, always encoding the eltype. For every number type, the projector is in fact fully described by the element type. But literally storing the inner projector seemed simpler. For instance: would you then make `Projector{Diagonal{Float32}}(...)` too, or would wrapper types always delegate this?
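For comparison, a toy of the eltype-in-the-type-parameter design, assuming the entries are numbers convertible to the target type. The name `EltypeArrayProj` is hypothetical. The projector stores no fields at all, since `T` alone determines what to do.

```julia
# Hypothetical toy, NOT ChainRulesCore's: the eltype is a type parameter and
# nothing is stored, since for numbers T alone says how to project.
struct EltypeArrayProj{T<:Number} end
EltypeArrayProj(::AbstractArray{T}) where {T<:Number} = EltypeArrayProj{T}()

# Keep dx if its eltype already fits; otherwise convert entries to T
# (assumes the entries are convertible to T).
(::EltypeArrayProj{T})(dx::AbstractArray{S}) where {T,S<:Number} =
    S <: T ? dx : convert(AbstractArray{T}, dx)
```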
(It's also possible that we should `@inline` many things. I have no idea whether this would help here; I was reluctant to add clutter until finding at least one example where it does.)
ChainRulesCore v1.3.1, the latest release, has the above branch-free construction of projectors, but still uses `if hasproperty(project, :element)` when applying them.
If I try the linked gist on the latest versions of everything (CUDA v3.4.2), I get a scalar-indexing warning on the first run, but subsequent runs are fine. Can you confirm what you see, and whether you think there are still problems here?
```julia
julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f5e3878d710.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
...[etc]...
└ @ GPUArrays ~/.julia/packages/GPUArrays/UBzTm/src/host/indexing.jl:56
fill(-0.33681546256995853), fill(SVector{3, Float64}[[0.4074603842207669, -1.441045304947525, 0.5486342400104408], [0.36579849547697396, -0.2231744971894618, 0.43044376309880106], [5.034006482016548, -7.684693458589903, 6.191228243615761], [-5.807265361714289, 9.34891326072689, -7.170306246725003]]))
julia> CUDA.allowscalar(false)
julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(fill(-0.780010516490922), fill(SVector{3, Float64}[[0.845104396943803, 1.9836409707095195, 2.047662646035268], [-1.7434321992965327, -0.7401691262870027, -0.25745246549838496], [-0.24669189468086605, -0.0020683788027422434, 0.08506441265560705], [1.1450196970335957, -1.2414034656197745, -1.87527459319249]]))
```
PR #430 removes the `if hasproperty(project, :element)` branch, but I'm not certain it won't land us in dispatch hell. It would need careful thought, at least.