
User-facing API like vmap

Open jessebett opened this issue 4 years ago • 11 comments

@ChrisRackauckas suggests that this package provides many of the utilities needed to make broadcasting over specified axes efficient. This can be seen in DiffEqGPU.jl.

Can we discuss a user-facing API so we can directly compare against JAX's vmap?

For instance, if I have a function

f(x::Scalar, y::Vector, A::Array) = linalg...

How can I efficiently broadcast this over collections of inputs stored in containers with axes, such as multidimensional arrays ("tensors")?

# Broadcast over rows of second argument
vmap(f, in_axes=(nothing, 1, nothing))(scalar, array, array)

# Broadcast over axes for all arguments
vmap(f, in_axes=(1, 1, 3))(vector, array, tensor)

Further, is it possible to provide these as defaults for something like eachslice so that broadcasting Just Works?

f.(scalar, eachrow(array), array)
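
For comparison, here is roughly the map-over-slices boilerplate this proposal would hide (a rough sketch; f, scalar, and array are the placeholders from the examples above):

# Today: broadcast over rows of the second argument by hand
map(row -> f(scalar, row, array), eachrow(array))

The slice iterators return views, so no copies are made, but the looping is explicit and there is no fusion across slices.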

jessebett avatar Jul 29 '20 21:07 jessebett

using KernelAbstractions
# using CUDA  # uncomment for GPU support

function vmap!(f, out, args...)
    # One work item per output element: gather the I-th element of
    # every input and apply f.
    @kernel function kernelF(f::F, out, args...) where F
        I = @index(Global, Linear)
        elems = ntuple(Val(length(args))) do i
            Base.@_inline_meta
            @inbounds args[i][I]
        end
        @inbounds out[I] = f(elems...)
    end

    # Pick the device from the output array's type.
    if out isa Array
        device = CPU()
    # elseif out isa CuArray
    #     device = CUDADevice()
    else
        error("unsupported output array type $(typeof(out))")
    end

    kernel = kernelF(device, (256,))
    event = kernel(f, out, args..., ndrange=size(out))
    wait(event)
    return out
end
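
For completeness, a minimal CPU exercise of the sketch above (array names are just placeholders; the GPU branch stays commented out unless CUDA.jl is loaded):

x = rand(1024)
out = similar(x)
vmap!(sqrt, out, x)    # one work item per element, out[i] = sqrt(x[i])
out ≈ sqrt.(x)         # should hold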

vchuravy avatar Jul 29 '20 21:07 vchuravy

What would be a good API for fusing matmuls and dots with vmap!, like in this post?

chriselrod avatar Jul 29 '20 22:07 chriselrod

This allows vmap!(sqrt, cu(rand(3)), cu(rand(1:10,3).^2)), but things like vmap!(sum, rand(3), [rand(Int8,4) for _ in 1:3]) will only work on the CPU, right?

I never thought to try, but kernels that take slices of larger arrays seem to work fine on the CPU. Somehow the linked DiffEqGPU.jl source is able to do this for GPU too; what's the trick? Edit: https://github.com/mcabbott/Tullio.jl/pull/20
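
For context, one workaround that avoids arrays-of-slices entirely (not necessarily the trick DiffEqGPU.jl uses; see the linked PR) is to pass the parent array to the kernel and let each work item index its own column, roughly:

using KernelAbstractions

@kernel function colsum_kernel!(out, A)
    j = @index(Global, Linear)
    acc = zero(eltype(out))
    # Each work item reduces one column of the parent array,
    # instead of receiving a pre-made slice.
    for i in axes(A, 1)
        @inbounds acc += A[i, j]
    end
    @inbounds out[j] = acc
end

A = rand(4, 3)
out = zeros(3)
kernel = colsum_kernel!(CPU(), 16)   # swap CPU() for a GPU device to run there
wait(kernel(out, A, ndrange=length(out)))
out ≈ vec(sum(A, dims=1))            # should hold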

mcabbott avatar Jul 30 '20 11:07 mcabbott

I posted my long form comment on discourse, but here are the parts related to this discussion:

  • I think to get the performance of fused BLAS operations, the user has to specify f appropriately to vmap?
  • Unlike Jax.vmap, this doesn't have any "depth," right? It can't do anything about what's inside f. Ultimately, I think this would require a custom compiler pass or a significant change to function calling behavior to pass the axis info all the way through the call stack (bad idea imo).

darsnack avatar Jul 30 '20 16:07 darsnack

KA kernels are all about changing what's inside of f to build an SPMD function? I'm not sure what you mean by "depth" in this context.

ChrisRackauckas avatar Jul 30 '20 16:07 ChrisRackauckas

Something like this where "depth" refers to the depth in the call stack. Does KA do this? That's awesome!

darsnack avatar Jul 30 '20 16:07 darsnack

I believe depth here refers to linear-algebra op/kernel fusion and (where applicable) reordering/lifting of operations.

ToucheSir avatar Jul 30 '20 16:07 ToucheSir

I am off on vacation, so I won't partake in this discussion for the next two weeks.

KA is built on top of Cassette, so it can indeed change the entire call graph, but I would argue that this is the wrong level of abstraction. For fusing, one might build a DSL that in the end lowers to KernelAbstractions as an execution engine.


vchuravy avatar Jul 31 '20 11:07 vchuravy

Apologies for reviving this thread with yet more questions, but what would be the most appropriate place to define such a DSL (if indeed one exists at all)? In Python land one would likely pick up XLA or TVM, but such monolithic frameworks seem like a poor fit given that all of CUDA, GPUCompiler, KernelAbstractions and Dagger(GPU) exist.

ToucheSir avatar Aug 25 '20 01:08 ToucheSir

I think the design space is still quite open. Tullio.jl is something like that for fans of Einstein notation. I have my own playground where I explore infrastructure ideas. I might also be convinced that it is a value add for KA, but in general we have orthogonal packages in Julia.
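
For anyone who hasn't seen it, the Einstein-notation style Tullio.jl offers looks roughly like this (a small illustrative example, not from this thread):

using Tullio

A = rand(3, 4); B = rand(4, 5)
@tullio C[i, k] := A[i, j] * B[j, k]   # sum over the repeated index j, i.e. C = A * B
C ≈ A * B                              # should hold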

vchuravy avatar Aug 25 '20 16:08 vchuravy

I wanted to write a short message that there is definitely user demand for some flavor of vmap.

There are at least two reasons vmap is interesting.

  1. Improved performance. For instance, fusing matrix-vector multiplications into a matrix-matrix multiplication (see the sketch after this list). For more general cases where no such fusion is possible, one can run a for loop in parallel by, say, launching multiple CUDA kernels simultaneously rather than one at a time.
  2. Improved syntax, code readability, and user experience. I have found that vmap significantly reduces the amount of unreadable batching code and boilerplate that one needs to write, and in fact this is one of the main reasons to have vmap in the first place. In this sense, vmap could be considered to perform a similar function to TensorCast.jl, which more-or-less provides different ways of expressing existing functionality rather than adding new functionality.
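
As a concrete illustration of point 1, this is the kind of rewrite one currently does by hand and that a sufficiently clever vmap could do automatically (a hedged sketch with placeholder names, not an existing API):

W = rand(8, 16)
xs = [rand(16) for _ in 1:32]

# Unfused: 32 separate matrix-vector products
ys = [W * x for x in xs]

# Fused: one matrix-matrix product over the stacked batch
X = reduce(hcat, xs)        # 16×32
Y = W * X                   # 8×32
hcat(ys...) ≈ Y             # should hold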

I have not myself seen point 2 discussed much and would like to add that I believe there is great value here from the users' perspective, particularly for those who are either newer to the language or aren't interested in getting into too many details. JAX's main difficulty from my perspective is the significant quantity of boilerplate mess the surrounding ecosystem generates (think repeating Haiku's hk.get_parameter hundreds of times, and similar); Zygote/Flux do much better in general, but not in all cases. vmap is one of the things JAX gets very right, and I think it would benefit the Julia ecosystem to have a well-thought-out analogue.

aterenin avatar Feb 28 '21 14:02 aterenin