AbstractGPs.jl
GPU Minimal Working Example
NOT INTENDED FOR MERGE
This PR is meant to give an idea of what work is needed to allow AbstractGPs to run on the GPU (although with no regard for performance).
The main changes required are:
In AbstractGPs
- `mean` functions need to return lazy `Fill` arrays for `ZeroMean` and `ConstMean` (should probably do this anyway?).
- Broadcasting over `Diagonal{<:Real,<:FillArrays.Fill}` will produce an `Array`, not a `CuArray` (i.e. when computing `cov(fx) = cov(fx.f, fx.x) + fx.Σy`). I don't think there's an easy way to fix this in general, so either we don't use a lazy `Fill` array here (which is what I've implemented), or we define a custom `_add_broadcasted` function which overrides the default broadcasting logic.
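The dense-vector workaround for `fx.Σy` can be sketched roughly as follows. `dense_noise_diag` and `noisy_cov` are hypothetical helpers, not AbstractGPs API. The idea is that `similar(K, n)` allocates a vector of the same array type as the kernel matrix `K` (a `Vector` for a `Matrix`, and presumably a `CuVector` for a `CuMatrix`), so the noise term does not silently fall back to a CPU `Array` the way a lazy `Fill` does. Only the CPU path is exercised here; that the broadcast stays on the device for a `CuArray` is an assumption.

```julia
using LinearAlgebra

# Hypothetical helper (not part of AbstractGPs): build the observation-noise
# covariance Σy as a Diagonal backed by a *dense* vector allocated with
# `similar`, so it has the same array type as the kernel matrix K instead of
# being a lazy FillArrays.Fill.
function dense_noise_diag(K::AbstractMatrix{T}, σ²::Real) where {T<:Real}
    v = similar(K, size(K, 1))   # same array type and eltype as K
    fill!(v, T(σ²))              # constant noise variance on the diagonal
    return Diagonal(v)
end

# Hypothetical analogue of cov(fx) = cov(fx.f, fx.x) + fx.Σy.
# On the CPU this returns a Matrix; for a CuMatrix K the intent (assumed,
# untested here) is that the broadcast produces a CuMatrix.
noisy_cov(K::AbstractMatrix, σ²::Real) = K .+ dense_noise_diag(K, σ²)
```

For example, `noisy_cov([2.0 0.5; 0.5 3.0], 0.1)` adds `0.1` to the diagonal only. The trade-off is an extra `n`-element allocation in exchange for not having to override the broadcasting machinery.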
In KernelFunctions
- `Distances.jl` is not at all GPU compatible, so custom implementations of `pairwise`/`colwise` are needed (obviously, this should be done in `KernelFunctions`, but I've included it here to keep it in one place).
- https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/299
See also https://github.com/rossviljoen/SparseGPs.jl/issues/15
This honestly isn't as bad as I was expecting -- very nice work.
Looking at `GPUArrays`, it seems to be an incredibly lightweight dependency (it depends only on the stdlib and `Adapt.jl`, which itself depends only on the stdlib), so I wouldn't be opposed to making it a dependency throughout JuliaGPs.
I'll try to take a look at this in more depth at some point. I'm quite busy at the minute, so unless it's blocking other work, I'll hold off for now.
Codecov Report
Merging #197 (945ae43) into master (dd513f1) will decrease coverage by 4.27%. The diff coverage is 23.80%.
@@ Coverage Diff @@
## master #197 +/- ##
==========================================
- Coverage 97.98% 93.71% -4.28%
==========================================
Files 10 11 +1
Lines 348 366 +18
==========================================
+ Hits 341 343 +2
- Misses 7 23 +16
| Impacted Files | Coverage Δ |
|---|---|
| src/util/gpu.jl | 11.11% <11.11%> (ø) |
| src/finite_gp_projection.jl | 100.00% <100.00%> (ø) |
| src/mean_function.jl | 100.00% <100.00%> (ø) |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update dd513f1...945ae43.