Make Distributions.jl GPU friendly
An increasing number of ML applications require sampling, and reusing Distributions.jl instead of rolling your own sampler would be a major step forward. I have already forked this package to fix issues where parametric types are hardcoded to Float64 instead of generic (to support an ML project I am working on). At this point, being able to sample on the GPU instead of moving data back and forth is pretty important to me. I am willing to put some effort into making this happen, but Distributions.jl covers a lot of stuff that I am not familiar with.
To that end, I'd like some help coming up with the list of changes needed to reach this GPU-friendly goal. Here's what I have so far:
- [ ] Replace hardcoded types (e.g. Float64) with generics (e.g. AbstractFloat/Real) (a rough sketch of this change follows the list)
- [ ] Replicate the tests currently intended for the CPU on the GPU
- [ ] A system to pack sampling over an array of distributions into a single operation
- [ ] Remove scalar indexing from sampling code
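To make the first checklist item concrete, here is a rough before/after sketch with made-up struct names (not actual Distributions.jl types):
using Distributions
# "Before": the parameter type is pinned to Float64, so Float32 (or other
# GPU-friendly) parameters are promoted or rejected.
struct MyScaleHardcoded <: ContinuousUnivariateDistribution
    θ::Float64
end
# "After": parametric over Real, so the element type of the parameters
# (Float32, Float64, dual numbers, ...) flows through unchanged.
struct MyScaleGeneric{T<:Real} <: ContinuousUnivariateDistribution
    θ::T
end
MyScaleGeneric(θ::Real) = MyScaleGeneric{typeof(θ)}(θ)
MyScaleGeneric(1f0)   # keeps Float32 parameters as Float32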
> I have already forked this package to fix issues where parametric types are hardcoded to Float64 instead of generic (to support an ML project I am working on).
Many distributions are already generic. It would be great if you could prepare PRs for the ones that aren't.
> A system to pack sampling over an array of distributions into a single operation
Can't you use broadcasting for this?
> Many distributions are already generic. It would be great if you could prepare PRs for the ones that aren't.
Yeah I'll start looking into this.
> Can't you use broadcasting for this?
Yes, we'll need to define broadcasting styles to do this packing.
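To illustrate the broadcasting idea, here is a rough sketch; note that the GPU path below side-steps Distributions' samplers entirely and uses CUDA's native RNG, so it only stands in for what a proper broadcast style would eventually do:
using Distributions, CUDA
# CPU: broadcasting already packs elementwise sampling into a single call.
mu, sigma = rand(10), rand(10)
samples = rand.(Normal.(mu, sigma))
# GPU sketch (assumption): reparameterize by hand with CUDA's RNG, since
# rand(::Normal) is not currently safe to call inside a GPU kernel.
Mu, Sigma = cu(mu), cu(sigma)
samples_gpu = Mu .+ Sigma .* CUDA.randn(length(Mu))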
Also, another thing I thought of yesterday: the sampling code will need to be modified to not use scalar indexing. This would be very slow on the GPU, and most people have scalar indexing disabled when loading CuArrays.jl.
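For context, this is the failure mode (assuming a recent CUDA.jl, where scalar indexing errors by default in non-interactive code):
using CUDA
CUDA.allowscalar(false)   # many users also set this explicitly
x = cu(rand(10))
sum(x)                    # fine: a single GPU reduction
x[1]                      # throws: scalar indexing is disallowed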
Curious if this is likely to be prioritized any time soon!
I got pulled into other projects that prevent me from focusing on this issue right now. But I know @srinjoyganguly is working on a DistributionsGPU.jl package.
Thanks for the prompt response!
Including a link to srinjoyganguly's repo for anyone who ends up here: DistributionsGPU.jl
Are the GPU features from DistributionsGPU.jl eventually going to be ported into Distributions.jl? I really hope so.
Hello @darsnack @johnurbanik I am sorry for the late response. I got busy with some assignments. I will start working on the issue soon. @azev77 I truly hope these features are added to Distributions.jl package. It will help a lot in computation speeds during sampling. Thanks so much.
https://github.com/JuliaStats/Distributions.jl/pull/1126
Just want to point out that we don't have to solve all these GPU things at once, and the abstract types would help integrate with other libraries regardless of the GPU implications.
I recently ran into trouble too, trying to integrate Distributions into GPU code. Maybe we could make Distributions.jl more GPU friendly step-by-step? Currently, even simple things like
using Distributions, CUDA
Mu = cu(rand(10))
Sigma = cu(rand(10))
D = Normal.(Mu, Sigma)
fail with:
ERROR: InvalidIRError: compiling kernel broadcast_kernel [...] Reason: unsupported dynamic function invocation (call to string(a::Union{Char, SubString{String}, String}...) [...]
This works, however:
D = Normal{eltype(Mu)}.(Mu, Sigma)
But then we run into the same error as above again with
logpdf.(D, X)
ERROR: InvalidIRError: compiling kernel broadcast_kernel [...] Reason: unsupported dynamic function invocation (call to string(a::Union{Char, SubString{String}, String}...) [...]
This is not surprising, since logpdf(::Normal, ::Real) contains a z = zval(Normal(...), x) call.
If we override the Normal constructor (the replacement below is just an example implementation and lacks check_args), we can get things to work:
Distributions.Normal(mu::T, sigma::T) where {T<:Real} = Normal{eltype(mu)}(mu, sigma)
D = Normal.(Mu, Sigma)
logpdf.(D, X) isa CuArray
That's just one specific case, of course, but I have a feeling that with a bit of tweaking here and there Distributions could support GPU operations quite well.
Is this a Distributions or a CUDA issue? What exactly is the problem?
The Normal constructor is not special at all; it is basically just the standard constructor with a check of the standard deviation parameter.
I think it's because most distribution default ctors (like Normal()) check arguments and can throw an exception (involving string generation), which CUDA doesn't like. One can work around that by using alternative ctors like Normal{eltype(Mu)}.(Mu, Sigma).
But logpdf and logcdf call Normal() internally as well, and that the user can't work around. It should be easy to change though; there's no reason why logpdf should trigger check_args of the ctors.
Yeah, usually (e.g. in constructors without sigma, or in convert) we use Normal(mu, sigma; check_args=false). However, in this specific case I think we should just call https://github.com/JuliaStats/StatsFuns.jl/blob/master/src/distrs/norm.jl; I've never understood why the code is duplicated.
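For reference, a sketch of what calling the StatsFuns kernels directly would look like; this assumes normlogpdf stays pure Julia and therefore broadcasts cleanly over GPU arrays:
using StatsFuns, CUDA
Mu, Sigma, X = cu(rand(10)), cu(rand(10) .+ 1), cu(rand(10))
logp = normlogpdf.(Mu, Sigma, X)   # scalar kernel, no Normal constructor involved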
I think we'll find a few CUDA/GPU-incompatible things like that scattered throughout Distributions, but I hope many will be easy to fix (except calls to Rmath, obviously, though luckily there are a lot fewer of those than in the early days).
I have some GPU-related statistics tasks coming up, so I can do PRs along the way when I hit issues like the above.
I opened https://github.com/JuliaStats/Distributions.jl/pull/1487.
Thanks!
Is there currently a way to transfer an arbitrary Distribution to the GPU (like cu(dist))? For example, the following works for MvNormal:
using CUDA, Distributions, LinearAlgebra
"Transfer the distribution to the GPU."
function CUDA.cu(dist::D) where {D <: MvNormal}
    MvNormal(cu(dist.μ), cu(dist.Σ))
end
dist = MvNormal(1f0 * I(2))
cu_dist = cu(dist)
gradlogpdf(cu_dist, cu(ones(2)))
Maybe something like this:
function CUDA.cu(ρ::D) where {D}
    D(cu.(getfield(ρ, field) for field in fieldnames(D))...)
end
Given that there are no constraints or specifications on fields, parameters, and constructors I don't think any such implementation will generally work. The cleanest approach would be to implement https://github.com/JuliaGPU/Adapt.jl but that has to be done for each distribution separately to ensure it is correct (possibly one could use ConstructionBase but I guess it is not even needed for this task).
Hence I think it's much easier to just construct distributions with GPU-compatible parameters instead of trying to move an existing distribution to the GPU.
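For anyone who wants to explore the Adapt.jl route anyway, a deliberately naive sketch for the MvNormal case could look like this; the dense-Matrix round trip and the reconstruction via the public constructor are shortcuts, and a correct version would have to handle the PDMats covariance type properly:
using Adapt, CUDA, Distributions, LinearAlgebra
# Naive sketch: adapt each parameter, then rebuild the distribution.
# Converting the covariance to a dense Matrix drops its PDMat structure.
Adapt.adapt_structure(to, d::MvNormal) =
    MvNormal(adapt(to, d.μ), Symmetric(adapt(to, Matrix(d.Σ))))
d = MvNormal(zeros(Float32, 2), Matrix{Float32}(I, 2, 2))
d_gpu = adapt(CuArray, d)   # whether this succeeds depends on MvNormal's
                            # constructor accepting GPU arrays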
Perhaps WebGPU, being an abstraction over GPU-specific details, could enable a more hardware-agnostic and forward-compatible approach? https://github.com/cshenton/WebGPU.jl and https://github.com/JuliaWGPU/WGPUNative.jl could serve as examples, or be used for end-to-end testing to automatically check whether Distributions.jl works on the GPU. Then you could handle AMD/NVIDIA/Intel/whatever, and the code would potentially last longer than if we focus on some particular version of CUDA.
Another idea would be to enforce a linter rule whereby all the functions in this library must act on abstract number types instead of concrete ones like Float64 (a hypothetical sketch of what such a rule would flag is below). Otherwise it's difficult to go the last mile to the GPU, because you might need a different (usually smaller) element type in your arrays to actually fit your problem on your GPU in practice.
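As a hypothetical illustration of what such a rule would flag (the function below is made up, not Distributions.jl API):
# Flagged: the signature is pinned to Float64, so Float32 inputs are
# promoted silently (or miss the method entirely).
toy_logkernel_hardcoded(μ::Float64, σ::Float64, x::Float64) = -abs2((x - μ) / σ) / 2 - log(σ)
# Preferred: generic over Real, so the input element type is preserved.
toy_logkernel(μ::Real, σ::Real, x::Real) = -abs2((x - μ) / σ) / 2 - log(σ)
toy_logkernel(1f0, 2f0, 0.5f0) isa Float32   # true: Float32 all the way through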
Does Julia have a concept of traits? You could probably modularize this better:
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Sub, SubAssign, Mul, MulAssign, Div, DivAssign, Neg};

/// A field is a set on which addition, subtraction, multiplication, and division
/// are defined and behave as the corresponding operations on rational and real numbers do.
pub trait Field:
Clone
+ Default
+ std::fmt::Debug
+ std::fmt::Display
+ Sized
+ Eq
+ PartialEq
+ Add<Output = Option<Self>>
+ AddAssign<Self>
+ Sub<Output = Option<Self>>
+ SubAssign<Self>
+ Mul<Output = Option<Self>>
+ MulAssign<Self>
+ Div<Output = Option<Self>>
+ DivAssign<Self>
+ Neg<Output = Option<Self>>
{
const FIELD_NAME: &'static str;
const DATA_TYPE_STR: &'static str;
type Data: Add + Sub + Mul + Div + Default + Debug + Display + Clone;
type Shape;
// Basic Info
fn field_name(&self) -> &'static str {
Self::FIELD_NAME
}
fn data_type_str(&self) -> &'static str {
Self::DATA_TYPE_STR
}
fn shape(&self) -> &Self::Shape;
// Construction and value access
fn of<X>(value: X) -> Self
where
Self::Data: From<X>;
fn from_dtype(value: Self::Data) -> Self {
Self::of(value)
}
fn try_from_n_or_0<T>(value: T) -> Self
where
Self::Data: TryFrom<T>,
{
match value.try_into() {
Ok(converted) => Self::of(converted),
Err(_e) => {
println!("[from_n_or_0] failed to convert a value!");
Self::zero()
}
}
}
fn try_from_n_or_1<T>(value: T) -> Self
where
Self::Data: TryFrom<T>,
{
match value.try_into() {
Ok(converted) => Self::of(converted),
Err(_) => {
println!("[from_n_or_1] failed to convert a value!");
Self::one()
}
}
}
fn get_value(&self) -> Option<&Self::Data>;
fn set_value(&mut self, value: Self::Data);
// Additive & Multiplicative Identity
fn is_zero(&self) -> bool;
fn is_one(&self) -> bool;
fn zero() -> Self;
fn one() -> Self;
// Basic operations with references
fn add_ref(&self, other: &Self) -> Option<Self>;
fn sub_ref(&self, other: &Self) -> Option<Self>;
fn mul_ref(&self, other: &Self) -> Option<Self>;
fn safe_div_ref(&self, other: &Self) -> Option<Self>;
fn rem_ref(&self, other: &Self) -> Option<Self>;
// Field has a multiplicative inverse -- requires rational number field for unsigned integers
// fn mul_inv(&self) -> Option<Ratio<Self, Self>>;
} // thus, this trait should be modularized into smaller traits in a progression of capability
Anyway, I'll spare you the implementation details; the design could probably be much better if informed mathematically and modularized further. If you think about it, when Socrates was talking about Forms, he was talking about traits; when we think about Float or Number, that's a trait, not a concrete type. So if you really want to make it easy to work with GPUs in Distributions.jl, maybe the solution is to write a test with WebGPU and some abstract number type that works independently of the hardware and precision. Then our current implicit data type Float { processor: CPU, dtype: Float64 } can become Float { processor: WebGPU, dtype: Bfloat16 }, and as long as a Bfloat16 implements the same required functionality as a Float64, you can swap them out like Indiana Jones!
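On the traits question: Julia has no trait keyword, but the usual workaround is the "Holy traits" pattern, dispatching on a trait value computed from the type. A minimal sketch of how that could express GPU capability (all names here are illustrative, not existing Distributions.jl API):
using Distributions
# Trait values describing whether a distribution's code paths are GPU-kernel safe.
abstract type GPUCompatibility end
struct GPUSafe <: GPUCompatibility end
struct CPUOnly <: GPUCompatibility end
# Default trait: assume CPU-only unless a distribution opts in.
gpucompat(::Type{<:Distribution}) = CPUOnly()
gpucompat(::Type{<:Normal{<:Real}}) = GPUSafe()   # hypothetical opt-in
# Dispatch on the trait value instead of hardcoding per-type behaviour.
describe_sampling(d::Distribution) = describe_sampling(gpucompat(typeof(d)), d)
describe_sampling(::GPUSafe, d) = "can run inside a GPU kernel"
describe_sampling(::CPUOnly, d) = "falls back to CPU sampling"
describe_sampling(Normal(0f0, 1f0))   # "can run inside a GPU kernel"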