Fix scalar indexing of ProjectTo for wrappers of GPU arrays
This PR attempts to remove scalar indexing for ProjectTo involving GPUArrays by restricting the depth of wrappers to 1 at all times.
Fixes #624. Additionally, as a first step, I copied many of the standard Array tests and found several other cases where scalar indexing occurred or projected types were incorrect (i.e. nested wrappers rather than plain arrays).
In summary, the changes consist of:
- Adding JLArrays as a test dependency and copying / adapting many Array tests to try and find cases that are incorrect / produce scalar indexing (it is certainly possible some cases have been missed, but I've tried to be thorough).
- Adding GPUArraysCore as a source dependency so that overloads could be added to ProjectTo. This allows limiting wrapper depth for GPUArrays without hampering CPU performance (I hope).
- Adding overloads for projections related to adjoints and transposes of GPUArrays.
While this isn't exactly ideal or elegant, it does allow AD of depth-1 wrappers of GPUArrays to run without triggering scalar indexing, bringing it to parity with the forward pass. Hopefully the additional dependencies are acceptable and I've added them to Project.toml correctly.
I seem to have formatted a few lines in the projection.jl tests file by accident. I can revert those changes if needed.
The ChainRules test failure is actually a previously broken test that is now fixed. I have no idea what is causing Diffractor to fail. Judging from other PRs, the failures from ChainRulesOverloadGeneration, StatsFuns and LogExpFunctions seem expected.
I would rather have this handled as an extension package that lives in GPUArraysCore. Adding this directly to CRC adds GPUArraysCore as a dependency to well over 1000 downstream packages.
Can we instead work out what interfaces we would need to expose for GPUArraysCore to hook into?
It looks like the only thing that wasn't done simply by adding an overload was the change to (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
and that could be done with an overload too; it would just involve a little more copy-paste than ideal.
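To illustrate the dispatch mechanism only, here is a minimal sketch of how a more specific method on a callable struct takes precedence over a generic AbstractArray method, without touching the generic definition. ToyProject and ToyDeviceArray are hypothetical stand-ins, not the real ProjectTo or any GPUArraysCore type, and the method bodies just return symbols instead of doing any projection.

```julia
# Hypothetical stand-in for ProjectTo; carries only a type parameter.
struct ToyProject{P} end

# Hypothetical stand-in for a device array type, backed by a plain Array.
struct ToyDeviceArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::ToyDeviceArray) = size(x.data)
Base.getindex(x::ToyDeviceArray{T,N}, i::Vararg{Int,N}) where {T,N} = x.data[i...]

# Generic method, mirroring the signature quoted in the thread.
(::ToyProject{AbstractArray})(dx::AbstractArray{S,M}) where {S,M} = :generic

# More specific overload; in a real package its body would repeat
# (copy-paste) most of the generic logic with device-friendly changes.
(::ToyProject{AbstractArray})(dx::ToyDeviceArray{S,M}) where {S,M} = :device

project = ToyProject{AbstractArray}()
project(zeros(2, 2))                  # returns :generic
project(ToyDeviceArray(ones(2, 2)))   # returns :device
```

Because the overload lives entirely on the device array type, it could in principle be defined outside the package that owns the generic method.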
Indeed, it could all be done with overloads, I was just trying to avoid the copy-pasting as you say.
Aside from ProjectTo, which is already exported, I think the only things from CRC needed are project_type and _projection_mismatch, a convenience error function. I suppose project_type could be exported, or alternatively, perhaps the getproperty overload could be modified to avoid the need to export, i.e. if the symbol is :project_type then return the first type parameter. Not sure what the most appropriate way to handle the error function would be.
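The getproperty idea could look roughly like the sketch below. ToyProjectTo is a hypothetical stand-in that assumes a layout of one type parameter plus a NamedTuple of info; it is not the actual ChainRulesCore definition, and the special-casing of :project_type is only the suggestion above, not existing behaviour.

```julia
# Hypothetical stand-in: first type parameter P is the projected type,
# `info` holds whatever extra data the projector needs.
struct ToyProjectTo{P,D<:NamedTuple}
    info::D
end
ToyProjectTo{P}(info::D) where {P,D<:NamedTuple} = ToyProjectTo{P,D}(info)
ToyProjectTo{P}(; kw...) where {P} = ToyProjectTo{P}((; kw...))

# Return the first type parameter for :project_type; forward every
# other property name to the stored NamedTuple.
function Base.getproperty(p::ToyProjectTo{P}, name::Symbol) where {P}
    name === :project_type && return P
    return getproperty(getfield(p, :info), name)
end

p = ToyProjectTo{Matrix{Float64}}(; axes=(Base.OneTo(2), Base.OneTo(3)))
p.project_type   # Matrix{Float64}
p.axes           # (Base.OneTo(2), Base.OneTo(3))
```

One trade-off of this approach is that an info field that happened to be named project_type would be shadowed by the special case.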
I actually forgot about extensions, and I'm not too familiar with how they work. This does seem like an ideal use case for them. What is the benefit of extending GPUArraysCore rather than ChainRulesCore?
> What is the benefit of extending GPUArraysCore rather than ChainRulesCore?
The maintainers of GPUArraysCore are much more familiar with what a GPU array represents and what is allowable on it than the maintainers of ChainRulesCore.