ChainRulesCore.jl

Projections do not play well with GPUCompiler

Open: p-zubieta opened this issue 4 years ago (2 comments)

Here's an example that does not play well with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4

A possible, but probably incomplete, fix discussed with @mcabbott is to add the following specializations:

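# Float arrays: build the element projector from the eltype alone, with no runtime branching
function ProjectTo(x::AbstractArray{T}) where {T <: AbstractFloat}
    return ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
end

# Bool arrays are non-differentiable, so their projector discards the tangent
ProjectTo(x::AbstractArray{T}) where {T <: Bool} = ProjectTo{NoTangent}()

# Return dx untouched when its eltype already fits; otherwise project each entry
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S}) where {S <: Number}
    T = ChainRulesCore.project_type(project.element)
    return S <: T ? dx : map(project.element, dx)
end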

p-zubieta, Aug 09 '21 23:08

I think we should do something like the first fix. The whole constructor function could use dispatch rather than branching, without any loss; it just happened to get written that way: https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L181-L188

Edit -- like this commit, maybe: https://github.com/JuliaDiff/ChainRulesCore.jl/pull/427/commits/7e5ae8edf5661f360f793982813fbbd2d3536dfd
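As a rough illustration of the branching-versus-dispatch point, here is a minimal toy sketch in plain Julia. NoTangentProj, EltProj, ArrayProj, proj_branching and proj_dispatch are hypothetical names used only for this sketch; they are not ChainRulesCore's types, whose real constructors are in the projection.jl file and commit linked above.

```julia
# Toy stand-ins for projectors (hypothetical, not ChainRulesCore's ProjectTo)
struct NoTangentProj end            # discards the tangent, like a Bool projector
struct EltProj{T} end               # projects a single number to type T
struct ArrayProj{E,A}
    element::E                      # projector applied to each entry
    axes::A                         # axes the tangent must match
end

# Branching style: one method that inspects the eltype inside its body
function proj_branching(x::AbstractArray)
    T = eltype(x)
    if T <: Bool
        return NoTangentProj()
    elseif T <: Number
        return ArrayProj(EltProj{T}(), axes(x))
    else
        throw(ArgumentError("eltype $T is not covered by this sketch"))
    end
end

# Dispatch style: each case is a separate method, chosen from the argument type
proj_dispatch(x::AbstractArray{T}) where {T<:Number} = ArrayProj(EltProj{T}(), axes(x))
proj_dispatch(x::AbstractArray{Bool}) = NoTangentProj()   # more specific than the method above
```

Both functions build the same projectors; the dispatch version only moves the case analysis into method selection. This toy does not show whether that change alone is enough for GPUCompiler.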

The second seems trickier: it avoids the if hasproperty(project, :element) check (https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L216). But also:

  • It avoids the reshape completely. Can we bypass this in more cases?
  • It assumes that the projector being applied to dx::AbstractArray{<:Number} has .element, but that need not be true, e.g. for x = Any[1,2,3].
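To make the second bullet concrete, here is a minimal toy in plain Julia. It assumes an array projector that keeps its fields in a NamedTuple and holds either one shared element projector or an array of per-entry projectors; ToyArrayProj and the elements field are hypothetical names, not ChainRulesCore API. The hasproperty check picks between the two cases, while code that reads .element unconditionally, as the proposal does, would throw for the per-entry case.

```julia
# Hypothetical stand-in for an array projector that keeps its fields in a
# NamedTuple; a toy, not ChainRulesCore's ProjectTo.
struct ToyArrayProj{D<:NamedTuple}
    info::D
end

# Forward property access to the NamedTuple, so p.element / p.elements work
Base.getproperty(p::ToyArrayProj, name::Symbol) = getproperty(getfield(p, :info), name)
Base.propertynames(p::ToyArrayProj) = propertynames(getfield(p, :info))

function (p::ToyArrayProj)(dx::AbstractArray)
    if hasproperty(p, :element)
        # One projector shared by every entry
        return map(p.element, dx)
    else
        # One projector per entry, e.g. for a heterogeneous x like Any[1, 2, 3]
        return map((f, d) -> f(d), p.elements, dx)
    end
end

shared = ToyArrayProj((element = Float32, axes = (Base.OneTo(3),)))
per_entry = ToyArrayProj((elements = [identity, x -> 2x, identity], axes = (Base.OneTo(3),)))
```

Here shared([1.0, 2.0, 3.0]) returns a Vector{Float32}, while per_entry has no element property at all, so only the branch keeps it working.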

It's possible that we should insist that every array projector has .element, if necessary a trivial one. That might help in-place stuff too.
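A minimal sketch of that idea, using toy types TrivialElement and ArrayProjAlways (hypothetical names, not ChainRulesCore API): every array projector stores an element, and a do-nothing one is used when there is nothing useful to project. The apply method then never needs hasproperty, and the trivial case can be special-cased by dispatch.

```julia
# Do-nothing element projector: returns each entry unchanged
struct TrivialElement end
(::TrivialElement)(d) = d

# Toy array projector that always carries an `element` field
struct ArrayProjAlways{E,A}
    element::E
    axes::A
end

function (p::ArrayProjAlways)(dx::AbstractArray)
    axes(dx) == p.axes || throw(DimensionMismatch("expected axes $(p.axes), got $(axes(dx))"))
    return map(p.element, dx)
end

# With a trivial element, skip the map and return dx itself
function (p::ArrayProjAlways{TrivialElement})(dx::AbstractArray)
    axes(dx) == p.axes || throw(DimensionMismatch("expected axes $(p.axes), got $(axes(dx))"))
    return dx
end

float_proj = ArrayProjAlways(Float32, axes([1.0, 2.0]))
any_proj = ArrayProjAlways(TrivialElement(), axes(Any[1, 2, 3]))
```

The trade-off in this toy is that the trivial element forgets any per-entry information, so it only fits cases where passing dx through unchanged is acceptable.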

It's also possible that we should mark the two cases in some way that is easier to dispatch on?
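One way to mark the two cases, sketched with a hypothetical TaggedArrayProj type that carries a Symbol tag (:shared or :per_entry) as its first type parameter. The choice between shared and per-entry projection is then made by method dispatch rather than by a hasproperty check at run time. This is an illustration under those assumptions, not a proposal for ChainRulesCore's actual types.

```julia
# Toy projector tagged by kind; f is either one function or an array of functions
struct TaggedArrayProj{Kind, F, A}
    f::F
    axes::A
end
TaggedArrayProj{Kind}(f::F, ax::A) where {Kind, F, A} = TaggedArrayProj{Kind, F, A}(f, ax)

# One projector shared by all entries
(p::TaggedArrayProj{:shared})(dx::AbstractArray) = map(p.f, dx)
# One projector per entry, stored in an array with the same shape as dx
(p::TaggedArrayProj{:per_entry})(dx::AbstractArray) = map((g, d) -> g(d), p.f, dx)

ax = (Base.OneTo(3),)
shared = TaggedArrayProj{:shared}(Float32, ax)
per_entry = TaggedArrayProj{:per_entry}([identity, x -> 2x, identity], ax)
```

In this toy the tag costs nothing at run time, but every method that handles array projectors now has to consider the tag as well, which is one way the number of method combinations can grow.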

One iteration of this code had ProjectTo{AbstractArray{Float32}}(...), always encoding the eltype. For every number type, the projector is in fact fully described by the element type. Still, this design of literally storing the inner projector seemed simpler. For instance, would you make ProjectTo{Diagonal{Float32}}(...) too, or would wrapper types always delegate this?
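For comparison, a toy version of the earlier design that encodes the eltype in the type, sketched with a hypothetical EltypeArrayProj{T} rather than ChainRulesCore's ProjectTo. Because the target eltype is a type parameter, the pass-through check S <: T uses only information the compiler already has, so it can typically be folded away.

```julia
# Toy projector whose target eltype T lives in the type, not in a stored field
struct EltypeArrayProj{T<:Number, A}
    axes::A
end
EltypeArrayProj{T}(ax::A) where {T<:Number, A} = EltypeArrayProj{T, A}(ax)
EltypeArrayProj(x::AbstractArray{T}) where {T<:Number} = EltypeArrayProj{T}(axes(x))

function (p::EltypeArrayProj{T})(dx::AbstractArray{S}) where {T, S<:Number}
    # S and T are both static type parameters, so this is a type-level check
    return S <: T ? dx : map(d -> convert(T, d), dx)
end

p = EltypeArrayProj(rand(Float32, 3))
dx32 = ones(Float32, 3)
```

Here p(dx32) returns dx32 itself, while p([1.0, 2.0, 3.0]) converts to Float32. This sketch covers only plain arrays of numbers; it does not answer the question of how wrapper types such as Diagonal would fit in.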

(It's also possible that we should @inline many things. I have no idea whether this would help here; I was reluctant to add clutter until finding at least one example where it does.)

mcabbott, Aug 09 '21 23:08

ChainRulesCore v1.3.1, the latest release, has the branch-free construction of projectors described above, but still uses if hasproperty(project, :element) when applying them.

If I try the linked gist on the latest versions of everything (CUDA v3.4.2), I get a warning on the first run, but subsequent runs are fine. Can you confirm what you see, and whether you think there are still problems here?

julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f5e3878d710.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
...[etc]...
└ @ GPUArrays ~/.julia/packages/GPUArrays/UBzTm/src/host/indexing.jl:56
fill(-0.33681546256995853), fill(SVector{3, Float64}[[0.4074603842207669, -1.441045304947525, 0.5486342400104408], [0.36579849547697396, -0.2231744971894618, 0.43044376309880106], [5.034006482016548, -7.684693458589903, 6.191228243615761], [-5.807265361714289, 9.34891326072689, -7.170306246725003]]))

julia> CUDA.allowscalar(false)

julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(fill(-0.780010516490922), fill(SVector{3, Float64}[[0.845104396943803, 1.9836409707095195, 2.047662646035268], [-1.7434321992965327, -0.7401691262870027, -0.25745246549838496], [-0.24669189468086605, -0.0020683788027422434, 0.08506441265560705], [1.1450196970335957, -1.2414034656197745, -1.87527459319249]]))

PR #430 removes the if hasproperty(project, :element) branch, but I'm not certain it won't land us in dispatch hell. It would need careful thought, at least.

mcabbott, Sep 05 '21 21:09