Flux.jl
Allow `Parallel(+, f)(x, y, z)` to work like broadcasting, and enable `Chain(identity, Parallel(+, f))(x, y, z)`
At present Parallel allows multiple layers and one input, but not the reverse. This PR extends it to allow both ways... much like broadcasting in connection((inputs .|> layers)...).
julia> Parallel(+, inv)(1, 2, 3) # was an error
1.8333333333333333
julia> (1,2,3) .|> (inv,)
(1.0, 0.5, 0.3333333333333333)
Does this have any unintended side-effects?
PR Checklist
- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
Codecov Report
All modified and coverable lines are covered by tests.
Project coverage is 74.03%. Comparing base (eb6492c) to head (0544711).
Note: current head 0544711 differs from the pull request's most recent head 9ee2c69. Consider uploading reports for commit 9ee2c69 to get more accurate results.
Additional details and impacted files
@@            Coverage Diff            @@
##           master    #2393      +/-  ##
==========================================
+ Coverage   43.04%   74.03%   +30.98%
==========================================
  Files          32       32
  Lines        1856     1918       +62
==========================================
+ Hits          799     1420      +621
+ Misses       1057      498      -559
Here's the complete run-down on where Flux does & doesn't splat at present:
julia> using Flux
julia> pr(x) = begin println("arg: ", x); x end;
julia> pr(x...) = begin println(length(x), " args: ", join(x, " & "), " -> tuple"); x end;
julia> c1 = Chain(pr, pr); ########## simple chain
julia> c1(1)
arg: 1
arg: 1
1
julia> c1((1, 2))
arg: (1, 2)
arg: (1, 2)
(1, 2)
julia> c1(1, 2)
ERROR: MethodError:
Closest candidates are:
(::Chain)(::Any)
julia> p1 = Parallel(pr, a=pr); ########## combiner + one layer
julia> p1(1)
arg: 1
arg: 1
1
julia> p1((1, 2)) # one 2-Tuple is NOT accepted, always splatted --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs
julia> p1(1, 2) # more obvious error --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs
julia> p1((a=1, b=2)) # one NamedTuple is ok
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
(a = 1, b = 2)
julia> p1((((1,),),)) # splatted many times
arg: 1
arg: 1
1
julia> p2 = Parallel(pr, a=pr, b=pr); ########## combiner + two layers
julia> p2(1) # one non-tuple arg is broadcasted
arg: 1
arg: 1
2 args: 1 & 1 -> tuple
(1, 1)
julia> p2(1, 2) # 2 args sent to 2 layers
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)
julia> p2((1, 2)) # one tuple splatted
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)
julia> p2((a=1, b=2)) # one NamedTuple sent to both
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
2 args: (a = 1, b = 2) & (a = 1, b = 2) -> tuple
((a = 1, b = 2), (a = 1, b = 2))
julia> p2(((1,2), ((3,4),))) # only splatted once
arg: (1, 2)
arg: ((3, 4),)
2 args: (1, 2) & ((3, 4),) -> tuple
((1, 2), ((3, 4),))
julia> Chain(pr, p2, pr)((1, 2)) # here earlier layers cannot pass p2 two arguments
arg: (1, 2)
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
arg: (1, 2)
(1, 2)
This PR changes the two error cases above:
julia> p1((1, 2)) # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)
julia> p1(1, 2) # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)
You could argue that p1((1, 2)) already has a plausible meaning: apply the one layer to the one input Tuple. But this use of Parallel is really just Chain (or, in this order, ∘). And it's an error at present.
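To illustrate with plain functions: under that reading, Parallel(string, sum)((1, 2)) would mean string(sum((1, 2))), which you can already write as:
julia> Chain(sum, string)((1, 2))   # apply the layer, then the combiner
"3"
julia> (string ∘ sum)((1, 2))       # equivalently, composition in this order
"3"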
I think p1(1, 2) has no other plausible meaning.
The rule after this PR is:
1. (p::Parallel)(input::Tuple) always splats to p(input...)
2. return combine((inputs .|> layers)...)
Step 1 is unchanged, but step 2 previously allowed only broadcasting of the input. And today I have a use where I want to broadcast the layer instead (easier than sharing it). That's in fact the 3rd case mentioned in https://github.com/FluxML/Flux.jl/issues/1685#issuecomment-890562799, but I think it never worked.
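A minimal standalone sketch of that two-step rule, with a hypothetical MyParallel (not Flux's actual source, which has more structure and error checking):
julia> struct MyParallel{F, T<:Tuple}
           combine::F
           layers::T
       end

julia> (p::MyParallel)(input::Tuple) = p(input...);   # step 1: one Tuple is always splatted

julia> (p::MyParallel)(inputs...) = p.combine((inputs .|> p.layers)...);   # step 2: broadcast either side

julia> MyParallel(+, (inv,))(1, 2, 3)   # one layer broadcast over three inputs
1.8333333333333333

julia> MyParallel(+, (inv, sqrt, identity))(4)   # one input broadcast over three layers
6.25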
Reading old threads... around here https://github.com/FluxML/Flux.jl/pull/2101#issuecomment-1306061980 it was agreed that adding (c::Chain)(xs...) = c(xs) would make sense, but there was never a PR.
That's the first MethodError in my list above. I would like this too, and perhaps should just add it here.
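For reference, that would be the one-line method below; together with this PR it makes the example in the title work (output assumed from the rule above):
julia> (c::Chain)(xs...) = c(xs);   # bundle several arguments into one Tuple

julia> Chain(identity, Parallel(+, inv))(1, 2, 3)
1.8333333333333333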
Anyone remember why we allow Parallel(hcat)? You can write Returns(hcat()) if you really want that...
julia> Parallel(hcat)()
Any[]
julia> Parallel(hcat)(NaN) # ignores input, but this case is tested
Any[]
julia> Parallel(hcat)(1,2,3)
ERROR: ArgumentError: Parallel with 0 sub-layers can take one input or 0 inputs, but got 3 inputs
Can we just make this an error on construction? I think that's basically what was agreed in https://github.com/FluxML/Flux.jl/issues/1685
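A hedged sketch of such a construction-time check (hypothetical strict_parallel helper, not Flux's current behaviour):
julia> function strict_parallel(connection, layers...)
           isempty(layers) && throw(ArgumentError("Parallel needs at least one sub-layer; use Returns(...) for a constant"))
           return Parallel(connection, layers...)
       end;

julia> strict_parallel(hcat)
ERROR: ArgumentError: Parallel needs at least one sub-layer; use Returns(...) for a constant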
I put this on the 0.15 milestone... I still think it's the right thing to do, and perhaps a breaking release is the right time to merge it.