
Allow `Parallel(+, f)(x, y, z)` to work like broadcasting, and enable `Chain(identity, Parallel(+, f))(x, y, z)`

Open · mcabbott opened this pull request 1 year ago · 5 comments

At present Parallel allows multiple layers with one input, but not one layer with multiple inputs. This PR extends it to allow both, much like broadcasting in connection((inputs .|> layers)...).

julia> Parallel(+, inv)(1, 2, 3)  # was an error
1.8333333333333333

julia> (1,2,3) .|> (inv,)
(1.0, 0.5, 0.3333333333333333)
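
For comparison, here is that broadcast fed to the combiner by hand, i.e. the connection((inputs .|> layers)...) form the description refers to (plain Base Julia, written out just to show the equivalence):

julia> +(((1, 2, 3) .|> (inv,))...)  # same result as Parallel(+, inv)(1, 2, 3) under this PR
1.8333333333333333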

Does this have any unintended side-effects?

PR Checklist

  • [x] Tests are added
  • [ ] Entry in NEWS.md
  • [x] Documentation, if applicable

mcabbott · Mar 10 '24 20:03

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 74.03%. Comparing base (eb6492c) to head (0544711).

Note: current head 0544711 differs from the pull request's most recent head 9ee2c69. Consider uploading reports for commit 9ee2c69 to get more accurate results.

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2393       +/-   ##
===========================================
+ Coverage   43.04%   74.03%   +30.98%     
===========================================
  Files          32       32               
  Lines        1856     1918       +62     
===========================================
+ Hits          799     1420      +621     
+ Misses       1057      498      -559     


codecov[bot] · Mar 10 '24 20:03

Here's the complete run-down on where Flux does & doesn't splat at present:

julia> using Flux

julia> pr(x) = begin println("arg: ", x); x end;  # one-arg method: print the argument, pass it through

julia> pr(x...) = begin println(length(x), " args: ", join(x, " & "), " -> tuple"); x end;  # vararg method: report how many args arrived, return them as a Tuple

julia> c1 = Chain(pr, pr); ########## simple chain

julia> c1(1)
arg: 1
arg: 1
1

julia> c1((1, 2))
arg: (1, 2)
arg: (1, 2)
(1, 2)

julia> c1(1, 2)
ERROR: MethodError:
Closest candidates are:
  (::Chain)(::Any)

julia> p1 = Parallel(pr, a=pr);  ########## combiner + one layer

julia> p1(1)
arg: 1
arg: 1
1

julia> p1((1, 2))  # one 2-Tuple is NOT accepted, always splatted  --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs

julia> p1(1, 2)  # more obvious error  --> changed by PR
ERROR: ArgumentError: Parallel with 1 sub-layers can take one input or 1 inputs, but got 2 inputs

julia> p1((a=1, b=2))  # one NamedTuple is ok
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
(a = 1, b = 2)

julia> p1((((1,),),))  # splatted many times
arg: 1
arg: 1
1

julia> p2 = Parallel(pr, a=pr, b=pr);  ########## combiner + two layers

julia> p2(1)  # one non-tuple arg is broadcasted
arg: 1
arg: 1
2 args: 1 & 1 -> tuple
(1, 1)

julia> p2(1, 2)  # 2 args sent to 2 layers
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p2((1, 2))  # one tuple splatted
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p2((a=1, b=2))  # one NamedTuple sent to both
arg: (a = 1, b = 2)
arg: (a = 1, b = 2)
2 args: (a = 1, b = 2) & (a = 1, b = 2) -> tuple
((a = 1, b = 2), (a = 1, b = 2))

julia> p2(((1,2), ((3,4),)))  # only splatted once
arg: (1, 2)
arg: ((3, 4),)
2 args: (1, 2) & ((3, 4),) -> tuple
((1, 2), ((3, 4),))

julia> Chain(pr, p2, pr)((1, 2))  # here earlier layers cannot pass p2 two arguments
arg: (1, 2)
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
arg: (1, 2)
(1, 2)

This PR changes the two error cases above:

julia> p1((1, 2))  # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

julia> p1(1, 2)  # changed by PR
arg: 1
arg: 2
2 args: 1 & 2 -> tuple
(1, 2)

You could argue that p1((1, 2)) already has a plausible meaning: apply the one layer to the one input Tuple. But that use of Parallel is really just Chain (or, in this order, ∘). And it's an error at present.
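
To spell out why the single-layer case is just composition (my own illustration, not from the PR):

julia> Parallel(+, inv)(2)  # one input, one layer: +(inv(2))
0.5

julia> Chain(inv, +)(2)     # the same computation written as a Chain
0.5

julia> (+ ∘ inv)(2)         # or as plain composition
0.5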

I think p1(1, 2) has no other plausible meaning.

The rule after this PR is:

  1. (p::Parallel)(input::Tuple) always splats to p(input...)
  2. return connection((inputs .|> layers)...)

Step 1 is unchanged, but step 2 previously allowed only broadcasting of the input. And today, I have a use where I want to broadcast the layer instead (easier than sharing it). That's in fact the 3rd case mentioned here: https://github.com/FluxML/Flux.jl/issues/1685#issuecomment-890562799 but I think it never worked.
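
To make the two-step rule concrete, here is a toy sketch. MiniParallel is hypothetical, not Flux's actual implementation; it stores the layers as a plain Tuple and skips the NamedTuple and keyword constructors:

struct MiniParallel{F,T}
    connection::F   # the combining function, e.g. + or hcat
    layers::T       # here, a plain Tuple of callables
end

(p::MiniParallel)(input::Tuple) = p(input...)                            # step 1: a lone Tuple always splats
(p::MiniParallel)(inputs...) = p.connection((inputs .|> p.layers)...)    # step 2: broadcast layers over inputs, then combine

MiniParallel(+, (inv,))(1, 2, 3)   # one layer broadcast over three inputs, ≈ 1.8333
MiniParallel(+, (inv, sqrt))(4)    # one input broadcast over two layers, == 2.25

If the lengths mismatch and neither is 1, the tuple broadcast errors, which is roughly the role of Parallel's argument check.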

mcabbott · Mar 13 '24 02:03

Reading old threads... around here https://github.com/FluxML/Flux.jl/pull/2101#issuecomment-1306061980 it was agreed that adding (c::Chain)(xs...) = c(xs) would make sense, but there was never a PR.

That's the first MethodError in my list above. I would like this too, and perhaps should just add it here.
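
For concreteness, combining that proposed Chain method with this PR would make the title example work. A sketch, assuming this PR's Parallel behaviour and defining the proposed method by hand (it is not in Flux today):

julia> (c::Chain)(xs...) = c(xs)  # the method discussed in the linked #2101 thread, defined ad hoc

julia> Chain(identity, Parallel(+, inv))(1, 2, 3)  # the example from this PR's title
1.8333333333333333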

mcabbott · Mar 13 '24 03:03

Anyone remember why we allow Parallel(hcat)? You can write Returns(hcat()) if you really want that...

julia> Parallel(hcat)()
Any[]

julia> Parallel(hcat)(NaN)  # ignores input, but this case is tested
Any[]

julia> Parallel(hcat)(1,2,3)
ERROR: ArgumentError: Parallel with 0 sub-layers can take one input or 0 inputs, but got 3 inputs

Can we just make this an error on construction? I think that's basically what was agreed in https://github.com/FluxML/Flux.jl/issues/1685
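
For reference, the Returns spelling mentioned above behaves the same way as the zero-layer Parallel (plain Base Julia, shown only for comparison):

julia> Returns(hcat())(NaN)  # a constant callable: ignores its input, always returns hcat(), i.e. Any[]
Any[]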

mcabbott · Mar 13 '24 14:03

I put this on the 0.15 milestone... I still think it's the right thing to do, but perhaps a breaking release is the right time to merge it.

mcabbott · Oct 16 '24 14:10