flowjax
Consider adding orthogonal transformations
Instead of permutations, one could use learnable linear transformations (see section 3.2 of the Papamakarios et al. review). As far as I understand, in theory this would allow the model to learn during training which conditionals are easier to fit. This obviously comes at a cost, so it may not be practical in all cases.
I have had some decent results with simple versions of this, albeit in low dimensions (< 5), so it could be worth adding.
e.g. see https://github.com/bayesiains/nflows/blob/master/nflows/transforms/orthogonal.py
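As a rough illustration of the idea (not flowjax's actual API), here is a minimal JAX sketch of a learnable orthogonal transformation parameterised as a product of Householder reflections. One reflection vector per row of `vs` is the learnable parameter; the composed map is orthogonal, so it is norm-preserving, has unit absolute Jacobian determinant (zero log-det contribution), and its inverse is just the reflections applied in reverse order. All function and variable names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def householder(x, v):
    # Apply one Householder reflection H = I - 2 v v^T / ||v||^2 to x.
    # H is orthogonal and its own inverse.
    v = v / jnp.linalg.norm(v)
    return x - 2.0 * v * jnp.dot(v, x)


def forward(x, vs):
    # Compose reflections; a product of Householder matrices is orthogonal,
    # so |det(Jacobian)| = 1 for this transform.
    for v in vs:
        x = householder(x, v)
    return x


def inverse(y, vs):
    # Each reflection inverts itself, so undo them in reverse order.
    for v in vs[::-1]:
        y = householder(y, v)
    return y


# Hypothetical usage: vs would be trained alongside the flow's other parameters.
dim = 4
vs = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
x = jax.random.normal(jax.random.PRNGKey(1), (dim,))
y = forward(x, vs)
x_rec = inverse(y, vs)
```

Compared to a fixed permutation, this adds O(dim^2) work per reflection pass but keeps the log-det term trivial, unlike a general (e.g. LU-parameterised) linear layer.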
That would definitely be interesting to have.