KernelAbstractions.jl
User-facing API like vmap
@ChrisRackauckas suggests that this package provides many of the utilities that would make broadcasting over specified axes efficient. This can be seen in DiffEqGPU.jl.
Can we discuss a user-facing API so we can directly compare against JAX's vmap?
For instance, if I have a function
f(x::Scalar, y::Vector, A::Array) = linalg...
how can I efficiently broadcast it over inputs stored in containers with axes, like multidimensional arrays ("tensors")?
# Broadcast over rows of second argument
vmap(f, in_axes=(nothing, 1, nothing))(scalar, array, array)
# Broadcast each argument over its own specified axis
vmap(f, in_axes=(1, 1, 3))(vector, array, tensor)
Further, is it possible to provide these as defaults for something like eachslice, so that broadcasting Just Works?
f.(scalar, eachrow(array), array)
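(As far as I know, with plain broadcasting today the arguments that should stay whole have to be wrapped, e.g. in Ref, or they get iterated too. A minimal sketch:)
# Today: rows of `array` are iterated; `scalar` broadcasts as a scalar, and
# Ref keeps the third argument as a single object instead of iterating its elements
f.(scalar, eachrow(array), Ref(array))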
using KernelAbstractions

function vmap!(f, out, args...)
    # Elementwise kernel: apply f to the I-th element of every argument
    @kernel function kernelF(f::F, out, args...) where F
        I = @index(Global, Linear)
        elems = ntuple(Val(length(args))) do i
            Base.@_inline_meta
            @inbounds args[i][I]
        end
        @inbounds out[I] = f(elems...)
    end

    # Pick the device from the output array's type
    if out isa Array
        device = CPU()
    # elseif out isa CuArray
    #     device = CUDADevice()
    else
        error("unsupported output array type: $(typeof(out))")
    end

    kernel = kernelF(device, (256,))                     # workgroup size 256
    event  = kernel(f, out, args..., ndrange=size(out))  # one work item per output element
    wait(event)
    return out
end
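A minimal CPU usage of this sketch (elementwise case) would be something like:
# Apply sqrt elementwise; the kernel reads x[I] and writes out[I]
x   = rand(Float32, 1024)
out = similar(x)
vmap!(sqrt, out, x)   # out[i] == sqrt(x[i])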
What would be a good API for fusing matmuls and dots with the vmap!, like in this post?
This allows vmap!(sqrt, cu(rand(3)), cu(rand(1:10,3).^2)), but things like vmap!(sum, rand(3), [rand(Int8,4) for _ in 1:3]) will only work on the CPU, right?
I never thought to try, but kernels that take slices of larger arrays seem to work fine on the CPU. Somehow the linked DiffEqGPU.jl source is able to do this on the GPU too; what's the trick? Edit: https://github.com/mcabbott/Tullio.jl/pull/20
I posted my long form comment on discourse, but here are the parts related to this discussion:
- I think to get the performance of fused BLAS operations, the user has to specify f appropriately to vmap?
- Unlike Jax.vmap, this doesn't have any "depth," right? It can't do anything about what's inside f. Ultimately, I think this would require a custom compiler pass or a significant change to function-calling behavior to pass the axis info all the way through the call stack (bad idea imo).
KA kernels are all about changing what's inside of f to build an SPMD function? I'm not sure what you mean by "depth" in this context.
Something like this where "depth" refers to the depth in the call stack. Does KA do this? That's awesome!
I believe depth here refers to LA op/kernel fusion and (where applicable) reordering/lifting of operations.
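To make that concrete, here is a rough sketch of what a "deep" transform would have to do (W, X, and g are hypothetical names, not anything from KA):
W = rand(4, 3);  X = rand(3, 16)               # hypothetical weight matrix and batch of inputs
g(x::AbstractVector) = sum(abs2, W * x)        # a matvec hidden inside g

# "Shallow" vmap: just map g over the columns of X
shallow = map(g, eachcol(X))

# "Deep" vmap: rewrite the matvec inside g into a single matmul
deep = vec(sum(abs2.(W * X), dims=1))          # == shallow, up to floating point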
I am off on vacation, so I won't partake in this discussion for the next two weeks.
KA is built on top of Cassette, so it can indeed change the entire call graph, but I would argue that this is the wrong level of abstraction. For fusing, one might build a DSL that in the end lowers to KernelAbstractions as an execution engine.
Apologies for reviving this thread with yet more questions, but what would be the most appropriate place to define such a DSL (if indeed one exists at all)? In Python land one would likely pick up XLA or TVM, but such monolithic frameworks seem like a poor fit given that all of CUDA, GPUCompiler, KernelAbstractions and Dagger(GPU) exist.
I think the design space is still quite open. Tullio.jl is something like that for fans of Einstein notation. I have my own playground where I explore infrastructure ideas. I might also be convinced that it is a value add for KA, but in general we have orthogonal packages in Julia.
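For a rough flavor of the Einstein-notation route, a batched matmul in Tullio might look something like this (a sketch, not taken from the package docs):
using Tullio
A = rand(4, 3, 8);  B = rand(3, 5, 8)          # 8 batched problems
# The repeated index k is summed over; b is a free batch index
@tullio C[i, j, b] := A[i, k, b] * B[k, j, b]  # C[:, :, b] == A[:, :, b] * B[:, :, b]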
I wanted to write a short message that there is definitely user demand for some flavor of vmap. There are at least two reasons vmap is interesting.
- Improved performance. For instance, fusing matrix-vector multiplications into a matrix-matrix multiplication. For more general cases where no such fusion is possible, one can run a for loop in parallel by, say, launching multiple CUDA kernels simultaneously rather than one at a time.
- Improved syntax, code readability, and user experience. I have found that vmap significantly reduces the amount of unreadable batching code and boilerplate that one needs to write, and in fact this is one of the main reasons to have vmap in the first place. In this sense, vmap could be considered to perform a similar function to TensorCast.jl, which more or less provides different ways of expressing existing functionality rather than adding new functionality. (A short sketch follows this list.)
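To illustrate the second point, a sketch contrasting today's manual batching with the hypothetical in_axes API proposed at the top of this thread (f, scalar, xs, A are the objects from that example; the vmap name and keyword are assumptions):
# Manual batching over the columns of xs (today):
ys = [f(scalar, xs[:, i], A) for i in 1:size(xs, 2)]

# With a hypothetical vmap, broadcasting only the second argument over its columns:
ys = vmap(f, in_axes=(nothing, 2, nothing))(scalar, xs, A)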
I have not myself seen point 2 discussed much and would like to add that I believe there is great value here from the users' perspective, particularly for those who are either newer to the language or aren't interested in getting into too many details. JAX's main difficulty from my perspective is the significant quantity of boilerplate mess the surrounding ecosystem generates (think repeating Haiku's hk.get_parameter hundreds of times, and similar), and Zygote/Flux do much better in general, but not in all cases. vmap is one of the things JAX gets very right, and I think it would benefit the Julia ecosystem to have a well-thought-out analogue.