
Same-same but different

SliceMap.jl


This package provides some mapslices-like functions, with gradients defined for Tracker and Zygote:

```julia
mapcols(f, M) ≈ mapreduce(f, hcat, eachcol(M))
MapCols{d}(f, M)         # where d=size(M,1), for SVector slices
ThreadMapCols{d}(f, M)   # using Threads.@threads

maprows(f, M) ≈ mapslices(f, M, dims=2)

slicemap(f, A; dims) ≈ mapslices(f, A, dims=dims) # only Zygote
```
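For comparison, the value side of these equivalences can be written out with Base Julia alone. The helpers naive_mapcols, naive_maprows and threaded_mapcols are hypothetical names used only for this sketch: they compute the same results as the SliceMap functions listed above, but define no custom gradients and use none of the StaticArrays/ForwardDiff machinery.

```julia
# Hypothetical Base-only reference versions of the functions listed above.
# They compute the same values, but define no custom gradients.
naive_mapcols(f, M::AbstractMatrix) = mapreduce(f, hcat, eachcol(M))
naive_maprows(f, M::AbstractMatrix) = permutedims(mapreduce(f, hcat, eachrow(M)))

# Threaded variant in the spirit of ThreadMapCols: the columns are split
# across the available threads with Threads.@threads.
function threaded_mapcols(f, M::AbstractMatrix)
    cols = collect(eachcol(M))
    results = Vector{Any}(undef, length(cols))
    Threads.@threads for j in eachindex(cols)
        results[j] = f(cols[j])   # each iteration writes only its own slot
    end
    return reduce(hcat, results)
end

M = [1.0 2.0 3.0; 4.0 5.0 6.0]
sq_shift(x) = x .^ 2 .+ sum(x)   # example slice function, same length in and out

naive_mapcols(sq_shift, M)                                   # → [6.0 11.0 18.0; 21.0 32.0 45.0]
naive_mapcols(sq_shift, M) ≈ mapslices(sq_shift, M, dims=1)  # → true
naive_maprows(sq_shift, M) ≈ mapslices(sq_shift, M, dims=2)  # → true
```

The point of the package is what these naive versions lack: gradients for Tracker and Zygote that do not have to differentiate through hcat and the slicing one column at a time.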

The capitalised functions differ in two ways: each slice is passed as a StaticArrays SVector, and the gradient of each slice is computed with ForwardDiff (forward mode) instead of with the same reverse-mode Tracker/Zygote used for the rest of the computation. For small slices this is often much faster, with or without gradients.
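The idea of a per-slice forward-mode gradient can be sketched without any packages: compute each column's Jacobian with d forward passes (one per input direction, d being the slice length), then turn the incoming cotangent Δ into that column's gradient as J' * Δ. This is a minimal illustration under stated assumptions, not SliceMap's implementation. The tiny Dual type stands in for ForwardDiff's dual numbers and supports only + and *, and column_jacobian and mapcols_pullback are hypothetical names.

```julia
# Hypothetical minimal dual number standing in for ForwardDiff's Dual.
struct Dual
    val::Float64   # primal value
    der::Float64   # derivative along the current seed direction
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# Jacobian of f at one slice x: one forward pass per input direction.
function column_jacobian(f, x::AbstractVector{<:Real})
    d = length(x)
    cols = map(1:d) do j
        seeded = [Dual(x[i], i == j ? 1.0 : 0.0) for i in 1:d]   # seed direction j
        [y.der for y in f(seeded)]                              # column j of J
    end
    return reduce(hcat, cols)   # J[i, j] = ∂f_i/∂x_j
end

# Pullback for a column map: the gradient of column j is J_j' * Δ[:, j].
function mapcols_pullback(f, M::AbstractMatrix, Δ::AbstractMatrix)
    grads = [column_jacobian(f, M[:, j])' * Δ[:, j] for j in axes(M, 2)]
    return reduce(hcat, grads)
end

pair_ops(x) = [x[1] * x[2], x[1] + x[2]]   # example slice function, length 2 -> 2
M = [1.0 2.0; 3.0 4.0]

# Gradient of sum(mapcols(pair_ops, M)) with respect to M, i.e. Δ = ones:
mapcols_pullback(pair_ops, M, ones(2, 2))   # → [4.0 5.0; 2.0 3.0]
```

Forward mode costs one pass per slice element, which is why this approach pays off for small slices: d stays small, and no reverse-mode tape is built for each slice.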

The package also defines Zygote gradients for the Slice/Align functions in JuliennedArrays, which are a good way to roll your own mapslices-like function (this is exactly how slicemap(f, A; dims) works). Similar gradients are also available in TensorCast and in LazyStack.
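The slice, map, align pattern behind slicemap can be sketched in Base Julia for the dims=1 case. The helper slicemap_dim1 is hypothetical: it reproduces the values of mapslices(f, A, dims=1), but it does not use the JuliennedArrays API and carries no custom gradient. It also assumes f returns vectors of the same length for every slice.

```julia
# Hypothetical Base-only sketch of "slice, map, align" along dims=1.
function slicemap_dim1(f, A::AbstractArray)
    rest = CartesianIndices(axes(A)[2:end])   # every index over the non-sliced dims
    slices = [view(A, :, I) for I in rest]    # 1. slice: one view per A[:, I]
    results = map(f, slices)                  # 2. map f over the slices
    n = length(first(results))                # assumes equal-length results
    out = similar(A, eltype(first(results)), n, size(rest)...)
    for (I, r) in zip(rest, results)          # 3. align results back into one array
        out[:, I] = r
    end
    return out
end

A = reshape(collect(1.0:24.0), 2, 3, 4)
stats(x) = [sum(x), prod(x), maximum(x)]      # length-2 slice -> length-3 result

B = slicemap_dim1(stats, A)
size(B)                                       # → (3, 3, 4)
B == mapslices(stats, A, dims=1)              # → true
```

Writing the three steps separately is what makes gradients tractable: each step (slicing, mapping, aligning) can be given its own adjoint, which is the role the Slice/Align gradients play.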

More details and examples are in docs/intro.md.