AbstractGPs.jl

GPU Minimal Working Example

Open rossviljoen opened this issue 3 years ago • 2 comments

NOT INTENDED FOR MERGE

This PR is meant to give an idea of what work is needed to allow AbstractGPs to run on GPU (although with no regard for performance).

The main changes required are:

In AbstractGPs

  • mean functions need to return lazy Fill arrays for ZeroMean and ConstMean (should probably do this anyway?)
  • Broadcasting over a Diagonal{<:Real,<:FillArrays.Fill} produces an Array, not a CuArray (i.e. when computing cov(fx) = cov(fx.f, fx.x) + fx.Σy). I don't think there's an easy way to fix this in general, so either we don't use a lazy Fill array here (which is what I've implemented) or we define a custom _add_broadcasted function that overrides the default broadcasting logic.
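
As a rough illustration of the two points above (not the code in this PR), the sketch below uses hypothetical stand-in types `SketchZeroMean` and `SketchConstMean` and hypothetical helpers `mean_vector` and `add_noise`. It assumes FillArrays.jl is installed. The mean vectors are lazy, while the noise diagonal is materialised in the same array type as the kernel matrix; whether the final addition avoids scalar indexing on a particular GPU backend has not been checked here.

```julia
using FillArrays, LinearAlgebra

# Hypothetical stand-ins for ZeroMean / ConstMean; not the AbstractGPs definitions.
struct SketchZeroMean{T<:Real} end
struct SketchConstMean{T<:Real}
    c::T
end

# Lazy mean vectors: no storage is allocated, so nothing is pinned to CPU memory.
mean_vector(::SketchZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x))
mean_vector(m::SketchConstMean, x::AbstractVector) = Fill(m.c, length(x))

# Eager noise term: build the diagonal with `similar(K, ...)` so that it has the
# same array type as K (a CuArray if K is one) instead of a lazy Fill.
function add_noise(K::AbstractMatrix, σ²::Real)
    d = fill!(similar(K, size(K, 1)), σ²)
    return K + Diagonal(d)
end
```

On the CPU, `add_noise(K, σ²)` gives the same result as `K + σ² * I`; the point of going through `similar` is only to keep the noise term in whatever array type `K` already has.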

In KernelFunctions

  • Distances.jl is not at all GPU compatible, so custom implementations of pairwise / colwise are needed (obviously, this should be done in KernelFunctions, but I've included it here to keep it in one place).
  • https://github.com/JuliaGaussianProcesses/KernelFunctions.jl/issues/299
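
For a sense of what a GPU-friendly replacement can look like, here is a minimal sketch of squared Euclidean distances written only in terms of reductions, broadcasting and a matrix product, which are operations that generic GPU array types usually support. The names `pairwise_sqeuclidean` and `colwise_sqeuclidean` are hypothetical and are not the functions added in this PR; inputs are assumed to store one point per column.

```julia
using LinearAlgebra

# Hypothetical helper: squared Euclidean distance between every column of
# X (d × n) and every column of Y (d × m), returned as an n × m matrix.
function pairwise_sqeuclidean(X::AbstractMatrix, Y::AbstractMatrix)
    x2 = sum(abs2, X; dims=1)            # 1 × n squared norms
    y2 = sum(abs2, Y; dims=1)            # 1 × m squared norms
    D = x2' .+ y2 .- 2 .* (X' * Y)       # ‖x‖² + ‖y‖² - 2⟨x, y⟩
    return max.(D, zero(eltype(D)))      # clamp small negatives from rounding
end

# Hypothetical helper: squared distance between matching columns of X and Y.
colwise_sqeuclidean(X::AbstractMatrix, Y::AbstractMatrix) =
    vec(sum(abs2, X .- Y; dims=1))
```

The expanded form trades some numerical accuracy for a single matrix multiplication, which is why the result is clamped at zero.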

See also https://github.com/rossviljoen/SparseGPs.jl/issues/15

rossviljoen avatar Aug 13 '21 11:08 rossviljoen

This honestly isn't as bad as I was expecting -- very nice work.

Looking at GPUArrays, it seems to be an incredibly light-weight dependency (just depends on stdlib and Adapt.jl, which itself just depends on the stdlib), so I wouldn't be opposed to making it a dependency throughout JuliaGPs.

I'll try to take a look at this in more depth at some point, quite busy at the minute so unless it's blocking other work, I'll hold off for now.

willtebbutt avatar Aug 13 '21 11:08 willtebbutt

Codecov Report

Merging #197 (945ae43) into master (dd513f1) will decrease coverage by 4.27%. The diff coverage is 23.80%.

@@            Coverage Diff             @@
##           master     #197      +/-   ##
==========================================
- Coverage   97.98%   93.71%   -4.28%     
==========================================
  Files          10       11       +1     
  Lines         348      366      +18     
==========================================
+ Hits          341      343       +2     
- Misses          7       23      +16     

Impacted Files                 Coverage Δ
src/util/gpu.jl                11.11% <11.11%> (ø)
src/finite_gp_projection.jl    100.00% <100.00%> (ø)
src/mean_function.jl           100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update dd513f1...945ae43.

codecov-commenter avatar Aug 13 '21 14:08 codecov-commenter