
Add fold.jl

spazewalker opened this issue 4 years ago · 6 comments

Adding Fold/Unfold functions, with reference to the PyTorch feature parity tracking here.

spazewalker avatar Mar 19 '21 14:03 spazewalker
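For context, the PyTorch-style unfold being ported is essentially im2col: it extracts sliding patches from a signal into the columns of a matrix. A minimal 1-D sketch in NumPy (illustrative only — the names and layout here are not NNlib's actual API):

```python
import numpy as np

def unfold1d(x, kernel, stride=1):
    # Extract sliding windows of length `kernel` from a 1-D signal.
    # Column j of the result is the patch starting at index j * stride.
    n = (len(x) - kernel) // stride + 1
    return np.stack([x[j * stride : j * stride + kernel] for j in range(n)], axis=1)

x = np.arange(5.0)
print(unfold1d(x, kernel=3))
# → [[0. 1. 2.]
#    [1. 2. 3.]
#    [2. 3. 4.]]
```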

From a quick glance this looks mostly OK, thanks. Convenience methods for the 1d and 2d cases are missing; then some tests and rrules.

CarloLucibello avatar Mar 19 '21 15:03 CarloLucibello

> Thanks so much for looking into this! I've left a couple of thoughts, but overall I'm wondering if you've checked this on the GPU as well.

I haven't tested it on the GPU yet; I'll check that.

spazewalker avatar Mar 19 '21 17:03 spazewalker

We might not have a layer in Flux, or maybe we just need to hold the config to construct the cdims object.

DhairyaLGandhi avatar Mar 19 '21 17:03 DhairyaLGandhi

> From a quick glance this looks mostly OK, thanks. Convenience methods for the 1d and 2d cases are missing; then some tests and rrules.

> Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?

Thanks for the PR! Tests and rrules need to be here. I'm not sure what exactly is meant by convenience methods, but if it makes sense to call them directly (as opposed to only as a helper for a layer struct), then they should be in NNlib as well.

ToucheSir avatar Mar 19 '21 17:03 ToucheSir

> We might not have a layer in Flux, or maybe we just need to hold the config to construct the cdims object.

Okay, so I'll try to implement those convenience wrappers.

> From a quick glance this looks mostly OK, thanks. Convenience methods for the 1d and 2d cases are missing; then some tests and rrules.

> Should I implement them here, or will they be taken care of while implementing a layer in Flux.jl?

> Thanks for the PR! Tests and rrules need to be here. I'm not sure what exactly is meant by convenience methods, but if it makes sense to call them directly (as opposed to only as a helper for a layer struct), then they should be in NNlib as well.

As per my understanding, those are wrappers around this implementation so that it can handle 1d and 2d inputs too. I'll work on tests, but I'm not sure what `rrule`s are.

spazewalker avatar Mar 19 '21 17:03 spazewalker
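The convenience wrappers under discussion follow the usual pattern for conv-style functions: the 1d/2d entry points just reshape the input into the general N-d layout, call the general method, and drop the singleton dims again. A hedged NumPy sketch of that pattern (the names `unfold3d` and `unfold1d_conv`, and the (width, channels, batch) layout, are assumptions for illustration, not NNlib's API):

```python
import numpy as np

def unfold3d(x, kernel, stride=1):
    # Stand-in for a general (width, channels, batch) implementation:
    # each column stacks one patch across all channels.
    w, c, b = x.shape
    n = (w - kernel) // stride + 1
    out = np.empty((kernel * c, n, b))
    for bi in range(b):
        for j in range(n):
            out[:, j, bi] = x[j * stride : j * stride + kernel, :, bi].reshape(-1)
    return out

def unfold1d_conv(x, kernel, stride=1):
    # Convenience wrapper: lift a plain vector to (w, 1, 1), call the
    # general method, then drop the singleton channel/batch dims.
    return unfold3d(x.reshape(-1, 1, 1), kernel, stride)[:, :, 0]

cols = unfold1d_conv(np.arange(5.0), kernel=3)
print(cols.shape)  # → (3, 3)
```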

`rrule`s are what allow the AD system to compute gradients; more details at https://juliadiff.org/ChainRulesCore.jl/stable/#frule-and-rrule. For functions like this that use mutation and inner loops, you'll probably need another function that manually performs the backwards pass. They're a little abstracted, but it may be worth looking at how the conv rrules work: https://github.com/FluxML/NNlib.jl/blob/7b8ae45c7fdcb331d63229a0a848e588b7d164a1/src/conv.jl#L223.

ToucheSir avatar Mar 19 '21 17:03 ToucheSir
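One fact that makes the rrule here tractable: fold (scatter-add of patches, with overlaps summed) is the adjoint of unfold as a linear map, so the backward pass of each operation is the other. A NumPy sketch of the adjoint identity ⟨unfold(x), ȳ⟩ = ⟨x, fold(ȳ)⟩, 1-D for brevity, with function names made up for illustration:

```python
import numpy as np

def unfold1d(x, kernel, stride=1):
    # Each column is one sliding patch of x (see earlier sketch).
    n = (len(x) - kernel) // stride + 1
    return np.stack([x[j * stride : j * stride + kernel] for j in range(n)], axis=1)

def fold1d(cols, output_len, kernel, stride=1):
    # Scatter-add each patch back into place; overlapping entries are
    # summed, which is exactly what makes fold the adjoint of unfold.
    x = np.zeros(output_len)
    for j in range(cols.shape[1]):
        x[j * stride : j * stride + kernel] += cols[:, j]
    return x

# The defining adjoint identity -- the pullback of unfold applies fold:
rng = np.random.default_rng(0)
x = rng.normal(size=10)
ybar = rng.normal(size=unfold1d(x, 3).shape)
assert np.isclose(np.vdot(unfold1d(x, 3), ybar), np.vdot(x, fold1d(ybar, 10, 3)))
```

The same identity is what a ChainRules pullback for unfold would encode (and symmetrically, the pullback of fold applies unfold).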

Replaced by #444

mcabbott avatar Nov 28 '22 17:11 mcabbott