RationalQuadraticSpline parameter ordering issue
Thanks for the package; it's great that it exists. I'm bridging it with Flux.jl and noticed some differences in conventions.
The current `RationalQuadraticSpline` uses `(D, K)` parameter ordering. This conflicts with the Flux.jl convention, where the first dimension is the "feature" dimension.
Current Behavior
```julia
using Bijectors

# D = 2 dimensions, K = 5 knots per spline
block_25 = randn(2, 5)  # (D, K)
b = RationalQuadraticSpline(block_25, block_25, block_25)
```
Expected (Flux.jl-like)
```julia
# K = 5 knots, D = 2 dimensions
block_52 = randn(5, 2)  # (K, D): complexity first, batch second
b = RationalQuadraticSpline(block_52, block_52, block_52)
```
UI
This also affects how the input argument has to be shaped:
```julia
D = 15
K = 4
b = Bijectors.RationalQuadraticSpline(rand(D, K), rand(D, K), rand(D, K))
b(rand(D, 1))  # does not work
b(rand(1, D))  # does not work; this is what I want for (K, D)
```
Currently, vector input is handled in an unintuitive and risky way: inputs shorter than D are silently accepted, so only some of the dimensions are processed.
```julia
b(rand(3))   # works, even though 3 < 15
b(rand(15))  # works
b(rand(16))  # breaks, since 16 > 15
```
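For comparison, here is a minimal sketch of a stricter shape check, assuming the features-first convention used by Flux: the leading dimension must equal the number of splines D, and any trailing dimensions are treated as batch. `check_spline_input` is a hypothetical helper in plain Julia, not part of Bijectors.

```julia
# Hypothetical helper (not part of Bijectors): require the leading (feature)
# dimension of `x` to equal the number of splines D; any trailing dimensions
# are treated as batch dimensions.
function check_spline_input(D::Integer, x::AbstractArray)
    if size(x, 1) != D
        throw(DimensionMismatch("expected size(x, 1) == $D, got input of size $(size(x))"))
    end
    return x
end

D = 15
check_spline_input(D, rand(D))      # single sample: accepted
check_spline_input(D, rand(D, 1))   # Flux-style (D, batch) with batch = 1: accepted
check_spline_input(D, rand(D, 10))  # batch of 10: accepted
# check_spline_input(D, rand(3))    # throws DimensionMismatch instead of being silently accepted
```

With a check like this, `rand(3)` and `rand(16)` fail the same way, rather than one being partially processed and the other erroring.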
Question
Which ordering should Bijectors.jl adopt for future extensions? I'm thinking of opening a PR.
- Current style: `(Dn, ..., D1, K)`, batch dimensions first
- Flux.jl style: `(K, D1, ..., Dn)`, complexity first
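The two orderings are related by `Base.permutedims`. Below is a minimal sketch; `knots_last`, `knots_first` and `reverse_dims` are hypothetical helper names, not Bijectors API. One point worth noting when comparing with row-major libraries such as distrax: a column-major Julia array of size `(K, D1, ..., Dn)` has the same memory layout as a row-major array of shape `[Dn, ..., D1, K]`.

```julia
# Hypothetical helpers (not part of Bijectors): move the per-spline parameter
# dimension K between the orderings listed above.

# (K, D1, ..., Dn) -> (D1, ..., Dn, K): move K from first to last position
knots_last(A::AbstractArray) = permutedims(A, (2:ndims(A)..., 1))

# (D1, ..., Dn, K) -> (K, D1, ..., Dn): move K from last to first position
knots_first(A::AbstractArray) = permutedims(A, (ndims(A), 1:ndims(A)-1...))

# (K, D1, ..., Dn) -> (Dn, ..., D1, K): reverse the order of all dimensions
reverse_dims(A::AbstractArray) = permutedims(A, Tuple(ndims(A):-1:1))

# A Flux-style (K, D) = (5, 2) block becomes a (D, K) = (2, 5) block
block_52 = randn(5, 2)
block_25 = knots_last(block_52)
```

For a 2-d block, `knots_last` is just a transpose. A Flux output of size `(K, D)` could therefore be adapted to the `(D, K)` constructor from "Current Behavior" for as long as that ordering stays.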
It's not just Flux: `Bijectors.PartitionMask` also uses the `(features, batch...)` convention.
Btw, distrax uses the last dimension for the features. Here is the distrax logic:
This bijector is applied elementwise. Given some input `x`, the parameters
`params` and the input `x` are broadcast against each other. For example,
suppose `x` is of shape `[N, D]`. Then:
- If `params` is of shape `[3 * num_bins + 1]`, the same spline is identically
applied to each element of `x`.
- If `params` is of shape `[D, 3 * num_bins + 1]`, the same spline is applied
along the first axis of `x` but a different spline is applied along the
second axis of `x`.
- If `params` is of shape `[N, D, 3 * num_bins + 1]`, a different spline is
applied to each element of `x`.
- If `params` is of shape `[M, N, D, 3 * num_bins + 1]`, `M` different splines
are applied to each element of `x`, and the output is of shape `[M, N, D]`.
What would the corresponding logic be in Bijectors?
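One possible answer, sketched in plain Julia, transposes the distrax rules to the features-first layout: the per-spline parameter dimension comes first, and extra spline dimensions trail instead of lead. `apply_elementwise` and `affine` are hypothetical names. `affine` is only a stand-in for the actual rational quadratic spline, because the sketch is about the broadcasting, not the transform.

```julia
# Hypothetical sketch (not Bijectors API): distrax-style broadcasting of
# per-element spline parameters in Julia's features-first layout.
# `params` has size (P, s...), where P is the number of parameters of one
# spline (3 * num_bins + 1 in distrax); `s` must broadcast against size(x).
#
#   params size     x size    result
#   (P,)            (D, N)    (D, N)     same spline everywhere
#   (P, D)          (D, N)    (D, N)     one spline per feature
#   (P, D, N)       (D, N)    (D, N)     one spline per element
#   (P, D, N, M)    (D, N)    (D, N, M)  M splines per element
function apply_elementwise(f, params::AbstractArray, x::AbstractArray)
    spline_shape = size(params)[2:end]
    # One parameter vector per spline; for spline_shape == () this is a
    # 0-dimensional array, which broadcasts like a scalar.
    thetas = [params[:, I] for I in CartesianIndices(spline_shape)]
    return f.(thetas, x)
end

# Stand-in for the actual spline: an affine map using the first two parameters.
affine(theta, x) = theta[1] + theta[2] * x

x = rand(3, 10)  # (D, N) = (3, 10), features first
y = apply_elementwise(affine, rand(2, 3, 10, 4), x)
size(y)  # (3, 10, 4)
```

With `params` of size `(P, D, N, M)`, the extra spline dimension `M` ends up last in the output. That is the column-major mirror of distrax's `[M, N, D]`.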
Context
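```julia
using Flux

NN = Dense(randn(12, 3), randn(12))  # 3 => 12 transformation

x1 = rand(3)
x2 = rand(3, 10)
x3 = rand(3, 10, 4)
x4 = rand(3, 10, 4, 8)

# Dense acts on the first (feature) dimension; all trailing dimensions are batch
size(NN(x1))  # (12,)
size(NN(x2))  # (12, 10)
size(NN(x3))  # (12, 10, 4)
size(NN(x4))  # (12, 10, 4, 8)
```
`PartitionMask`, on the other hand, works for `x1` and `x2` but breaks for higher-rank inputs:
```julia
using Bijectors

m = Bijectors.PartitionMask(3, [2], [1])
Bijectors.partition(m, x1) |> first |> size  # correct: (1,)
Bijectors.partition(m, x2) |> first |> size  # correct: (1, 10)
Bijectors.partition(m, x3)  # breaks
Bijectors.partition(m, x4)  # breaks
```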
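A `PartitionMask`-style split could extend to arrays of any rank by indexing only the first dimension with `Base.selectdim`. The sketch below uses hypothetical helpers, `partition_features` and `combine_features`, which are not Bijectors API. The index sets mirror `PartitionMask(3, [2], [1])`, assuming the unlisted index 3 goes into the third block.

```julia
# Hypothetical helpers (not Bijectors API): split along the first (feature)
# dimension of an array of any rank, leaving trailing batch dimensions untouched.
function partition_features(x::AbstractArray, idx_A, idx_B, idx_C)
    return selectdim(x, 1, idx_A), selectdim(x, 1, idx_B), selectdim(x, 1, idx_C)
end

# Inverse: reassemble the n features from the three blocks.
function combine_features(n::Integer, idx_A, idx_B, idx_C, xA, xB, xC)
    y = similar(xA, n, size(xA)[2:end]...)
    selectdim(y, 1, idx_A) .= xA
    selectdim(y, 1, idx_B) .= xB
    selectdim(y, 1, idx_C) .= xC
    return y
end

n, idx_A, idx_B = 3, [2], [1]
idx_C = setdiff(1:n, union(idx_A, idx_B))  # [3]

x3 = rand(3, 10, 4)
xA, xB, xC = partition_features(x3, idx_A, idx_B, idx_C)
size(xA)  # (1, 10, 4)
```

Because `selectdim` returns views, the split itself does not copy; only `combine_features` allocates.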