Add fold.jl
From a quick glance this looks mostly OK, thanks. Convenience methods for the 1d and 2d cases are missing.
Then some tests and rrules.
Thanks so much for looking into this! I've left a couple of thoughts, but overall I'm wondering if you've checked this on the GPU as well.
I haven't tested it on the GPU yet; I'll check it.
We might not have a layer in or maybe we need to just hold the config to construct the cdims object.
> from a quick glance looks mostly ok, thanks. Convenience methods for the 1d and 2d cases are missing. Then some tests and rrules

Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?
Thanks for the PR! Tests and rrules need to be here. I'm not sure what exactly is meant by convenience methods, but if it makes sense to call directly (as opposed to only as a helper for a layer struct) then it should be in NNlib as well.
Okay, so I'll try to implement those convenience wrappers.
As I understand it, those are wrappers around this implementation so that it can handle 1d and 2d inputs too. I'll work on the tests, but I'm not sure what `rrule`s are.
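For illustration, one way such wrappers could work is to lift 1d and 2d inputs to the 5d layout by inserting singleton spatial dimensions, call the general implementation, and then drop the extra dimensions again. This is only a sketch: `to_5d`, `core_5d`, and `core` are hypothetical names, `core_5d` is a stand-in for the real kernel, and the `(spatial..., channels, batch)` layout is an assumption. It is not the code from this PR.

```julia
# Hypothetical helper: reshape a (spatial..., C, N) array to 5d by
# inserting singleton spatial dimensions before the channel dimension.
function to_5d(x::AbstractArray{T,N}) where {T,N}
    3 <= N <= 5 || throw(ArgumentError("expected a 3d, 4d, or 5d array, got $(N)d"))
    spatial = size(x)[1:N-2]            # existing spatial dimensions
    pad = ntuple(_ -> 1, 5 - N)         # singleton dimensions to add
    return reshape(x, spatial..., pad..., size(x, N - 1), size(x, N))
end

# Hypothetical stand-in for a kernel that is only written for 5d input.
core_5d(x::AbstractArray{T,5}) where {T} = 2 .* x

# Convenience wrapper: lift to 5d, run the 5d kernel, then drop the
# singleton spatial dimensions that were added.
function core(x::AbstractArray{T,N}) where {T,N}
    y = core_5d(to_5d(x))
    return reshape(y, size(y)[1:N-2]..., size(y, 4), size(y, 5))
end
```

With this shape of wrapper, the 1d and 2d cases need no separate kernel; they reuse the 5d code path and only pay for a `reshape`.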
rrules are what allow for the AD system and gradients to work, more details at https://juliadiff.org/ChainRulesCore.jl/stable/#frule-and-rrule. For functions like this that use mutation and inner loops, you'll probably need another function that manually performs the backwards pass. They're a little abstracted, but it may be worth looking at how the conv rrules work: https://github.com/FluxML/NNlib.jl/blob/7b8ae45c7fdcb331d63229a0a848e588b7d164a1/src/conv.jl#L223.
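As a rough sketch of what that looks like, here is a toy 1d example with a hand-written backward pass. `toy_unfold` and `toy_fold` are hypothetical functions invented for this illustration (stride 1, no padding) and are not the implementation in this PR; only `rrule`, `NoTangent`, and `unthunk` come from ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical toy 1d "unfold": gathers sliding windows of length k
# (stride 1, no padding) into the columns of a matrix.
function toy_unfold(x::AbstractVector, k::Integer)
    n = length(x) - k + 1
    y = similar(x, k, n)
    for j in 1:n, i in 1:k
        y[i, j] = x[i + j - 1]
    end
    return y
end

# Hypothetical toy 1d "fold": scatters the columns back into a vector of
# length len, summing where windows overlap. This is the adjoint of toy_unfold.
function toy_fold(y::AbstractMatrix, len::Integer)
    k, n = size(y)
    x = zeros(eltype(y), len)
    for j in 1:n, i in 1:k
        x[i + j - 1] += y[i, j]
    end
    return x
end

# The rrule returns the primal result plus a pullback that maps the output
# gradient to one tangent per argument (the function itself, x, and k).
function ChainRulesCore.rrule(::typeof(toy_unfold), x::AbstractVector, k::Integer)
    y = toy_unfold(x, k)
    function toy_unfold_pullback(dy)
        dx = toy_fold(unthunk(dy), length(x))  # manual backward pass
        return NoTangent(), dx, NoTangent()
    end
    return y, toy_unfold_pullback
end
```

The AD system never differentiates through the mutating loops; it calls the pullback instead, which is why the backward pass has to be written by hand as a separate function.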
Replaced by #444