NNlib.jl

Add norm functions

Open ToucheSir opened this issue 2 years ago • 17 comments

These roughly correspond to Flux's *Norm layers.

Constitutes the bulk of the work in https://github.com/FluxML/NNlib.jl/issues/19. Dropout is a different beast and deserving of a separate discussion/its own PR. In the meantime, this should give NNlib{CUDA,ROC} a common interface to implement and allow Flux to start divesting its own norm function helpers.

Design Notes

The affine parameters are included as part of the function signature here because backends like cuDNN can fuse them into the main layer computation. This differs from Flux's use of Scale in LayerNorm, so how do we best reconcile the two? Activation functions are not included in this API for a similar reason (backends don't handle them).

Another thing not addressed in this PR is alternative memory layouts/dimension orderings such as channels-first. Hopefully once we figure out a way to represent inputs with different layouts, we can add internal extension points that dispatch on those. For now, everything is assumed to be spatial dims... x channels [x batch].
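To make the assumed layout concrete, here is a minimal sketch of a batch-norm-style function that takes the affine parameters in its signature and reduces over every dimension except the channel dimension (the second-to-last one). The name batchnorm_sketch, its argument order, and the ϵ keyword are hypothetical illustrations, not the API added in this PR, and a batch dimension is assumed to be present.

```julia
using Statistics: mean, var

# Hypothetical sketch, not the PR's API: normalise per channel, assuming the
# layout spatial dims... x channels x batch (channel dim is N - 1).
function batchnorm_sketch(x::AbstractArray{T,N}, γ, β; ϵ = T(1e-5)) where {T,N}
    C = size(x, N - 1)
    # All dimensions except the channel dimension, e.g. (1, 2, 4) for N = 4.
    dims = ntuple(i -> i < N - 1 ? i : i + 1, N - 1)
    μ = mean(x; dims)
    σ² = var(x; dims, mean = μ, corrected = false)
    # Reshape the affine parameters so they broadcast along the channel dim.
    shape = ntuple(i -> i == N - 1 ? C : 1, N)
    return reshape(γ, shape) .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ reshape(β, shape)
end

x = reshape(Float32.(1:24), 2, 2, 3, 2)   # width x height x channels x batch
y = batchnorm_sketch(x, ones(Float32, 3), zeros(Float32, 3))
```

Most of the function body is dimension bookkeeping (dims and shape), which is the part that a different memory layout would change.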

Relatedly, juggling dims is an absolute pain and was both the most painful and the most time-consuming part of developing this PR. I wish there were a way to use named dims without requiring users or backend implementors to do the same. Something for future work.

PR Checklist

  • [x] Tests are added
  • [x] Documentation, if applicable

ToucheSir avatar Jan 03 '23 00:01 ToucheSir

Once the documentation build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/NNlib.jl/previews/PR452/

github-actions[bot] avatar Jan 03 '23 01:01 github-actions[bot]

My thought was that this PR (and any future functions which advertise ::AbstractArray) should be general enough to work on all array types. We then add overloads for native Arrays (in-repo and in extensions), GPU arrays and others (e.g. StaticArrays?) as applicable. So it's correct to say there's plenty of perf left on the table (I don't do a tenth of the tricks in https://github.com/chengchingwen/NeuralAttentionlib.jl/blob/b418c0d2a9e99c960e88879a5fd879d47d8e4c22/src/functional/layernorm.jl, for example), but generality comes first.

ToucheSir avatar Jan 03 '23 04:01 ToucheSir

Ah I had not seen that. It looks optimised but I don't immediately see what it's actually doing! On the CPU it seems to give correct answers but is not especially efficient. However, I see some multi-arg mapreduce which I know hits a fallback defn, so I assume it's aimed at GPU arrays.

The counter to trying to bolt on speed afterwards is that it may want different organisation. I haven't absorbed how this PR structures things, but for me, absorbing the mean & var pullbacks completely seemed to be important.

mcabbott avatar Jan 03 '23 05:01 mcabbott

It's likely optimized routines will want to replace everything and just fulfill the high-level *norm interface. The helper functions that do exist are meant to be internal, but implementors do have the choice of overriding specific parts. I suspect most will opt for the full replacement route, which is why functions like norm_helper don't have rrules defined and no equivalents to ∇conv_* are present—there aren't good one-size-fits-all solutions for either. To illustrate the extremes of the spectrum, cuDNN/oneDNN may want everything fused while XLA/LibTorch want everything decomposed into primitive ops internally.

ToucheSir avatar Jan 03 '23 05:01 ToucheSir

> Ah I had not seen that. It looks optimised but I don't immediately see what it's actually doing! On the CPU it seems to give correct answers but is not especially efficient. However, I see some multi-arg mapreduce which I know hits a fallback defn, so I assume it's aimed at GPU arrays.

It's focused on GPU arrays. The idea is to use GPUArrays' broadcast/mapreduce kernels for simple kernel fusion and to reduce GPU allocations. For example, in the layer norm forward function:

function layer_norm(epsilon, alpha, beta, x)
    # [...]
    sum_sum2 = mapreduce(_x_x2, .+, x; dims=1, init = (zero(T), zero(T)))
    return fma.(α, _normalize.(convert(T, 1//N), ϵ, x, sum_sum2), β)
end

This takes only two GPU kernels to compute the layer norm (one mapreduce and one fused broadcast), and only one intermediate array is created (sum_sum2). We then do the same in the pullback function. It's not the most optimized implementation, since it reads the reduced array from global memory, but it should be a reasonable trade-off.
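For readers who want to try the pattern on the CPU, here is a self-contained sketch of the same two-step structure. The bodies of _x_x2 and _normalize below are assumptions written for illustration (they are not copied from NeuralAttentionlib.jl), and layer_norm_sketch is a hypothetical name.

```julia
# Map each element to the pair (x, x^2) so a single reduction yields both sums.
_x_x2(x) = (x, x * x)

# Assumed helper: normalise one element given its column's (Σx, Σx²).
function _normalize(inv_N, ϵ, x, sum_sum2)
    s, s2 = sum_sum2
    μ = s * inv_N
    σ² = max(s2 * inv_N - μ * μ, zero(μ))  # clamp small negative round-off
    return (x - μ) / sqrt(σ² + ϵ)
end

function layer_norm_sketch(x::AbstractMatrix{T}, α, β; ϵ = T(1e-5)) where {T}
    N = size(x, 1)
    # Step 1: one reduction over the feature dimension, producing tuples.
    sum_sum2 = mapreduce(_x_x2, .+, x; dims = 1, init = (zero(T), zero(T)))
    # Step 2: one fused broadcast; sum_sum2 is the only temporary array.
    return fma.(α, _normalize.(convert(T, 1 // N), ϵ, x, sum_sum2), β)
end

x = Float32[1 2; 2 4; 3 9]                 # features x batch
y = layer_norm_sketch(x, ones(Float32, 3), zeros(Float32, 3))
```

Because the reduction carries a tuple per column, the mean and the mean of squares come out of the same pass, which is what keeps the kernel count at two.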

chengchingwen avatar Jan 03 '23 15:01 chengchingwen

Cool. I've updated my gist above to time things... this is indeed fast.

On the GPU, I discovered that var(x; mean, dims) seems to cost an entire copy. Despite that, the gradients for layer_norm and normal_new seem to be tied in memory allocation, suggesting there's one copy which could be avoided.

On the CPU it does some work N^2 times, but that cost can be largely removed with @fastmath on sqrt and max.

What all this means for this PR, I don't know. I suppose I think the basic implementation should (1) try to be efficient on Array and CuArray without doing anything exotic, (2) not error on SArray, FillArray, etc., though there's little point optimising for those, and (3) try to line up with the cuDNN etc. routines to make such overloads easy, and perhaps make it easy to hook in other overloads, such as ones using LoopVectorization.

Whether this mapreduce trick to avoid calling mean, var is too exotic, I don't know. Zygote would hate it but shouldn't see it. It doesn't mean that there shouldn't be a function norm_stats, as BatchNorm wants to store that, but the fast path isn't then a fast version of that function at all.

mcabbott avatar Jan 03 '23 21:01 mcabbott

Thanks for timing it. Were the reported numbers for layer_norm on gpu obtained with or without @fastmath?

I don't know much about fastmath, but I vaguely remember it is considered evil?

chengchingwen avatar Jan 03 '23 23:01 chengchingwen

Not 100% sure but I think it made no difference on GPU. Weird beasts but computation often seems free compared to memory anyway.

Not sure I have an opinion on how evil it is. Here it seems pretty safe: it may fail to propagate NaN via the variance, but you'll get it again from x itself. But I didn't check very carefully.

mcabbott avatar Jan 03 '23 23:01 mcabbott

If one really wants to go wild with optimizations, it's possible to fuse the computation of sum_sum2 into the forward pass as well and write out to mean + var/inverse std as you go along. https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html does this and is the fastest LayerNorm implementation I know of.

> It doesn't mean that there shouldn't be a function norm_stats, as BatchNorm wants to store that, but the fast path isn't then a fast version of that function at all.

norm_stats is a kludge and wouldn't exist in an ideal world. The code reuse is nice, but its main purpose is to support maybe_norm_stats, which only exists to make the pullback for batch/instancenorm somewhat type stable. And we only care about type stability because Zygote gets cranky if a pullback in a deeply nested model isn't type stable. Indeed, a lot of code in this PR was written mainly to keep AD happy.

ToucheSir avatar Jan 04 '23 00:01 ToucheSir

> If one really wants to go wild with optimizations, it's possible to fuse the computation of sum_sum2 into the forward pass as well and write out to mean + var/inverse std as you go along. https://triton-lang.org/master/getting-started/tutorials/05-layer-norm.html does this and is the fastest LayerNorm implementation I know of.

The issue with that approach is that we would need to go down to kernel programming with either CUDA.jl or KernelAbstractions.jl. The input size would also affect performance, which is why PyTorch implements multiple kernels and dispatches on the input type and size. I guess part of Triton's success also comes from their compiler?

chengchingwen avatar Jan 04 '23 00:01 chengchingwen

It would be interesting to know, because the kernels are so much shorter too. Wonder what performance we'd get with Julia versions. A Triton-like library would be a nice force multiplier in this ecosystem.

ToucheSir avatar Jan 04 '23 01:01 ToucheSir

I think this sum_sum2 trick has catastrophic cancellation problems. Not sure if this matters in practice, but:

julia> data32 = rand(Float32, 100) .+ 10^4 |> cu;

julia> normal_now(data32) |> mean  # using mean, var, like Flux
-0.0031580066f0

julia> layer_norm(nothing, nothing, data32) |> mean
-87.890625f0

julia> hcat(normal_now(data32), layer_norm(nothing, nothing, data32))
100×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
 -1.28074    -35644.5
  1.02109     28418.0
 -1.64216    -45703.1
 -1.19302    -33203.1
  0.635108    17675.8
...

On CPU without @fastmath this gives an error from sqrt(negative).
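The effect can be reproduced on the CPU without a GPU or the gist. The sketch below uses made-up deterministic data (not the random data above) and compares the one-pass formula E[x²] − E[x]² with a two-pass computation. Near 10^8, adjacent Float32 values are 8 apart, so the difference of two such numbers cannot resolve a variance of roughly 0.08, and it can even come out negative.

```julia
# Deterministic Float32 data with a large offset (~1e4) and a small spread.
x = 1f4 .+ Float32.(0:99) ./ 100
n = length(x)

# One-pass formula E[x²] - E[x]², as in the sum_sum2 approach.
μ = sum(x) / n
naive_var = sum(abs2, x) / n - μ * μ

# Two-pass formula: subtract the mean first, then average the squares.
twopass_var = sum(abs2, x .- μ) / n

# Float64 reference computed from the same Float32 inputs.
x64 = Float64.(x)
ref_var = sum(abs2, x64 .- sum(x64) / n) / n

@show naive_var twopass_var ref_var
```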

mcabbott avatar Jan 04 '23 02:01 mcabbott

That doesn't look good. Maybe it's worth switching to Welford's algorithm for that.

chengchingwen avatar Jan 04 '23 05:01 chengchingwen

It sounds like that's the done thing. Can it use mapreduce or do you need lower-level things?

mcabbott avatar Jan 04 '23 06:01 mcabbott

It should be doable with mapreduce: just replace .+ with the correct update rule. I didn't do that because at the beginning I thought that, if the input is a large array with small values, the division-then-addition in the update rule would introduce more error. But it looks like the catastrophic cancellation is more troublesome.
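As a rough illustration of what such an update rule could look like, the sketch below carries a (count, mean, M2) state through mapreduce and merges partial states with the pairwise formula usually attributed to Chan et al. All names here (welford_lift, welford_merge, mean_var_welford) are hypothetical, and only CPU behaviour is considered; whether this maps well onto a GPUArrays reduction kernel is not tested here.

```julia
# Per-element state: (count, mean, M2), where M2 = Σ (x - mean)².
welford_lift(x) = (1, x, zero(x))

# Merge two partial states; this replaces `.+` as the reduction operator.
function welford_merge(a, b)
    na, μa, m2a = a
    nb, μb, m2b = b
    n = na + nb
    n == 0 && return a            # both states empty, nothing to merge
    δ = μb - μa
    μ = μa + δ * nb / n
    m2 = m2a + m2b + δ * δ * na * nb / n
    return (n, μ, m2)
end

function mean_var_welford(x::AbstractArray{T}; dims = 1) where {T}
    init = (0, zero(T), zero(T))
    stats = mapreduce(welford_lift, welford_merge, x; dims, init)
    # Uncorrected (population) variance, as used by normalisation layers.
    return map(s -> s[2], stats), map(s -> s[3] / s[1], stats)
end

# Column 1 has a large offset, column 2 does not.
x = hcat(1f4 .+ Float32.(mod.((1:100) .* 37, 100)) ./ 100, Float32.(1:100))
μ, σ² = mean_var_welford(x)
```

The division inside the merge is the division-then-addition step mentioned above; the state never holds anything of the order of x², which is what sidesteps the cancellation.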

chengchingwen avatar Jan 04 '23 06:01 chengchingwen

What's the status here? Can we defer optimization/generalization to future PRs and focus on approximately porting existing functionality, if that's what's needed to move on?

CarloLucibello avatar Jan 24 '23 04:01 CarloLucibello

Kind ping on the status of this PR, as it will unblock support for the AMDGPU backend (and we can then specialize batchnorm using MIOpen, for which I plan to open a PR).

pxl-th avatar Feb 13 '23 21:02 pxl-th