Bijectors.jl

RationalQuadraticSpline parameter ordering issue

Open mmikhasenko opened this issue 5 months ago • 2 comments

Thanks for the package, good that it exists. I'm bridging it with Flux.jl and noticed convention differences.

Currently, RationalQuadraticSpline uses (D, K) parameter ordering, which conflicts with the Flux.jl convention, where the first dimension is the "feature" dimension.

Current Behavior

# D=2 dimensions, K=5 knots per spline
block_25 = randn(2, 5)    # (D, K)
b = RationalQuadraticSpline(block_25, block_25, block_25)

Expected (Flux.jl-like)

# K=5 knots, D=2 dimensions  
block_52 = randn(5, 2)    # (K, D) - complexity first, batch second
b = RationalQuadraticSpline(block_52, block_52, block_52)
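Until the ordering question is settled, one workaround is to keep parameters in the (K, D) layout on the Flux side and transpose them just before constructing the bijector. A minimal sketch in plain Julia; `kd_to_dk` is a hypothetical helper name, not something Bijectors.jl provides:

```julia
# Hypothetical helper (not part of Bijectors.jl): convert a (K, D) parameter
# block (knots first, Flux-like) into the (D, K) layout described above.
kd_to_dk(block::AbstractMatrix) = permutedims(block)

K, D = 5, 2
block_52 = randn(K, D)         # (K, D)
block_25 = kd_to_dk(block_52)  # (D, K)

size(block_25)                    # (2, 5)
block_25[2, 3] == block_52[3, 2]  # true: knot 3 of dimension 2 in both layouts
```

`permutedims` returns a copy rather than a lazy wrapper, so the result is an ordinary `Matrix`.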

UI

The ordering also determines how the input argument must be shaped:

D = 15
K = 4
#
b = Bijectors.RationalQuadraticSpline(rand(D, K), rand(D, K), rand(D, K))

b(rand(D, 1)) # does not work
b(rand(1, D)) # does not work; this is what I would want for (K, D)

Currently, for vector input it does something unintuitive and risky: it processes the dimensions only partially.

b(rand(3))  # works, 3 < 15
b(rand(15)) # works
b(rand(16)) # breaks, since 16 > 15
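A stricter behaviour would reject any vector whose length differs from D instead of silently accepting shorter ones. A minimal sketch of such a guard in plain Julia; `check_input_length` is a hypothetical name and is not part of Bijectors.jl:

```julia
# Hypothetical guard (not part of Bijectors.jl): accept a vector only if its
# length matches the number of spline dimensions D.
function check_input_length(x::AbstractVector, D::Integer)
    length(x) == D || throw(DimensionMismatch(
        "expected a vector of length $D, got length $(length(x))"))
    return x
end

D = 15
check_input_length(rand(15), D)     # returns the input unchanged
# check_input_length(rand(3), D)    # throws DimensionMismatch
# check_input_length(rand(16), D)   # throws DimensionMismatch
```

Calling the guard before applying `b` would make the length-3 case fail loudly, the same way the length-16 case already does.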

Question

Which ordering should Bijectors.jl adopt for future extensions? I'm thinking of opening a PR.

  1. Current style: (Dn, ..., D1, K) - batch dimensions first
  2. Flux.jl style: (K, D1, ..., Dn) - complexity first

Btw

distrax uses the last dimension for the features.

mmikhasenko, Jul 19 '25 18:07

Not just Flux: Bijectors.PartitionMask uses the same convention, (features, batch...).

mmikhasenko, Jul 19 '25 19:07

Here is the distrax logic:

  This bijector is applied elementwise. Given some input `x`, the parameters
  `params` and the input `x` are broadcast against each other. For example,
  suppose `x` is of shape `[N, D]`. Then:
  - If `params` is of shape `[3 * num_bins + 1]`, the same spline is identically
    applied to each element of `x`.
  - If `params` is of shape `[D, 3 * num_bins + 1]`, the same spline is applied
    along the first axis of `x` but a different spline is applied along the
    second axis of `x`.
  - If `params` is of shape `[N, D, 3 * num_bins + 1]`, a different spline is
    applied to each element of `x`.
  - If `params` is of shape `[M, N, D, 3 * num_bins + 1]`, `M` different splines
    are applied to each element of `x`, and the output is of shape `[M, N, D]`.

What would be the corresponding logic in Bijectors?
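One possible feature-first translation, purely as a sketch: put the spline parameters on the first axis of `params` and broadcast the remaining axes against `x`, so the distrax shapes appear in reverse order. The helper below only computes the resulting output size; `spline_output_size` is a hypothetical name, and `Base.Broadcast.broadcast_shape` is a non-exported Base function.

```julia
# Hypothetical helper (not part of Bijectors.jl or distrax).
# Assumption: params carries the spline parameters along its FIRST axis
# (length P) and its remaining axes broadcast against x, mirroring the
# distrax rules with the axis order reversed for column-major arrays.
function spline_output_size(params_size::Tuple, x_size::Tuple)
    return Base.Broadcast.broadcast_shape(params_size[2:end], x_size)
end

P, D, N, M = 16, 3, 10, 4   # P stands in for 3 * num_bins + 1
x_size = (D, N)

spline_output_size((P,), x_size)          # (3, 10): same spline for every element
spline_output_size((P, D), x_size)        # (3, 10): one spline per feature
spline_output_size((P, D, N), x_size)     # (3, 10): one spline per element
spline_output_size((P, D, N, M), x_size)  # (3, 10, 4): M splines per element
```

Under this convention the extra `M` axis ends up last in the output, which matches the (features, batch...) layout mentioned earlier in the thread.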

Context

using Flux

NN = Dense(randn(12, 3), randn(12)) # 3 => 12 transformation

x1 = rand(3)
x2 = rand(3, 10)
x3 = rand(3, 10, 4)
x4 = rand(3, 10, 4, 8)

size(NN(x1)) # (12,)
size(NN(x2)) # (12, 10)
size(NN(x3)) # (12, 10, 4)
size(NN(x4)) # (12, 10, 4, 8)

m = Bijectors.PartitionMask(3, [2], [1])

Bijectors.partition(m, x1) |> first |> size # correct (1, )
Bijectors.partition(m, x2) |> first |> size # correct (1, 10) 
Bijectors.partition(m, x3) # breaks 
Bijectors.partition(m, x4) # breaks
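A possible way to support the higher-dimensional cases is to flatten all batch axes into one, apply the matrix method, and restore the batch axes afterwards. A minimal sketch in plain Julia; `apply_flattened` is a hypothetical helper, and the row-selecting function below is only a stand-in for `Bijectors.partition`, not its real implementation:

```julia
# Hypothetical helper (not part of Bijectors.jl): apply a function that only
# understands (features, batch) matrices to a (features, batch...) array.
# Assumption: f maps a matrix to a matrix with the same number of columns.
function apply_flattened(f, x::AbstractArray)
    batch_dims = size(x)[2:end]          # all trailing batch axes
    x2d = reshape(x, size(x, 1), :)      # (features, prod(batch_dims))
    y2d = f(x2d)
    return reshape(y2d, size(y2d, 1), batch_dims...)
end

# Stand-in for a partition step: keep only feature row 2.
select_row2(x::AbstractMatrix) = x[2:2, :]

x3 = rand(3, 10, 4)
x4 = rand(3, 10, 4, 8)

size(apply_flattened(select_row2, x3))  # (1, 10, 4)
size(apply_flattened(select_row2, x4))  # (1, 10, 4, 8)
```

Because `reshape` shares memory with the original array, the flattening step itself does not copy the input.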

mmikhasenko, Jul 19 '25 20:07