deq-jax

Support arbitrarily shaped inputs in broyden

[Open] zaccharieramzi opened this issue 5 years ago • 2 comments

Hi there,

Glad to be writing the first issue of this repo. In the documentation of broyden, you mention that it doesn't yet accept arbitrarily shaped inputs.

I figured from the "yet" that there was a plan to support it at some point, do you have in mind a timeline for this? I'd be happy to contribute as well if needed.

zaccharieramzi, Dec 10 '20 17:12

Actually, for images at least, what @jerrybai1995 did in mdeq was define a function that translates (a.k.a. reshapes) a list of image-like tensors into a single tensor with the correct dimensions.
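A minimal sketch of that idea in JAX, assuming every tensor in the list has layout (batch, ...) with a shared batch size. The helper names flatten_list and unflatten_list are hypothetical (not mdeq's or deq-jax's API); the per-tensor shapes are returned alongside the flat array so the reshape can be undone after the solver runs.

```python
import math
from itertools import accumulate

import jax.numpy as jnp


def flatten_list(tensors):
    """Concatenate (batch, ...) tensors into one (batch, features) array (hypothetical helper)."""
    batch = tensors[0].shape[0]
    shapes = [t.shape[1:] for t in tensors]  # remember per-tensor shapes for the inverse
    flat = jnp.concatenate([t.reshape(batch, -1) for t in tensors], axis=1)
    return flat, shapes


def unflatten_list(flat, shapes):
    """Split a (batch, features) array back into tensors of the given shapes (hypothetical helper)."""
    batch = flat.shape[0]
    sizes = [math.prod(s) for s in shapes]
    offsets = list(accumulate(sizes))[:-1]  # split points along the feature axis
    pieces = jnp.split(flat, offsets, axis=1)
    return [p.reshape((batch,) + tuple(s)) for p, s in zip(pieces, shapes)]


# Two resolutions of a multiscale feature map, batch size 2.
xs = [jnp.ones((2, 4, 8, 8)), jnp.zeros((2, 8, 4, 4))]
flat, shapes = flatten_list(xs)
print(flat.shape)  # (2, 384)
```

The solver only ever sees the flat (batch, features) array; the caller is responsible for keeping `shapes` around to rebuild the list.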

I guess it's not very modular, but at least it's a start. If I find some time, I can try to contribute that, if it's something you'd consider.

zaccharieramzi, Dec 11 '20 15:12

Hey @zaccharieramzi, this is a great feature request!

For JAX to compile to XLA, shapes need to be known ahead of time (at trace time; each new input shape triggers a retrace and recompile). So instead of supporting any shape, we could define a fixed API.
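A small illustration of why shapes matter under jax.jit: the Python body is traced once per distinct input shape and dtype, and the compiled executable is cached and reused for matching inputs. The function name and the global counter are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def sum_of_squares(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(x ** 2)


sum_of_squares(jnp.ones((2, 3)))  # first call with this shape: trace + compile
sum_of_squares(jnp.ones((2, 3)))  # same shape and dtype: cached executable is reused
sum_of_squares(jnp.ones((4, 5)))  # new shape: traced and compiled again
print(trace_count)  # 2
```

A solver that accepted "any shape" would still compile one specialized program per shape it encounters, which is why a fixed layout is the natural API.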

The broyden method finds a root of a function g(x, *args) with respect to x (i.e. it solves g(x, *args) = 0 for x).
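For reference, a minimal sketch of Broyden's method as a quasi-Newton root-finder on a 1-D vector x. This is the textbook "good" Broyden update in Sherman-Morrison form with a dense inverse-Jacobian estimate H, not deq-jax's implementation; the toy problem (the fixed point of tanh(W x + b), so g(x) = f(x) - x) and all names are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def broyden_root(g, x0, max_iter=50, tol=1e-5):
    """Minimal "good" Broyden root-finder for g(x) = 0, with x a 1-D vector.

    H is a dense estimate of the inverse Jacobian of g, refined after each
    step with the Sherman-Morrison form of Broyden's rank-one update.
    """
    x, gx = x0, g(x0)
    # With H = -I, the first step is x + g(x); for g(x) = f(x) - x that is f(x).
    H = -jnp.eye(x0.shape[0])
    for _ in range(max_iter):
        if jnp.linalg.norm(gx) < tol:
            break
        dx = -H @ gx  # quasi-Newton step
        x_new = x + dx
        g_new = g(x_new)
        dg = g_new - gx
        Hdg = H @ dg
        denom = dx @ Hdg
        if jnp.abs(denom) < 1e-12:  # update would be numerically meaningless
            break
        # H <- H + (dx - H dg) (dx^T H) / (dx^T H dg)
        H = H + jnp.outer(dx - Hdg, dx @ H) / denom
        x, gx = x_new, g_new
    return x


# Toy DEQ-style problem: find the fixed point of f(x) = tanh(W x + b).
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
W = 0.3 * jax.random.normal(key_w, (8, 8)) / jnp.sqrt(8.0)
b = jax.random.normal(key_b, (8,))
g = lambda x: jnp.tanh(W @ x + b) - x
x_star = broyden_root(g, jnp.zeros(8))
print(jnp.linalg.norm(g(x_star)))
```

Everything here assumes x is a flat vector; batching and other layouts are exactly what the matvec/rmatvec shapes have to encode.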

In our implementation, this was used to solve for the hidden state of a Transformer or TrellisNet: the input shape is (batch_size, hidden_size, sequence_length) and g is the network layer. So the shape of x is really determined by g.

Instead of supporting any shape or our current arbitrary choice, we could define a fixed API, e.g. one where the input shape is (batch_size, feature_size), and let users handle their own reshaping.
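Under that proposal, user-side reshaping could be as small as a wrapper. A sketch, assuming g(x, *args) expects x of shape (batch, hidden, seq) as in the Transformer/TrellisNet case; as_flat_fn and the toy layer are hypothetical names, not part of deq-jax.

```python
import jax.numpy as jnp


def as_flat_fn(g, hidden_size, seq_len):
    """Wrap g(x, *args), which expects x of shape (batch, hidden, seq), so that
    the wrapped function takes and returns x of shape (batch, hidden * seq)."""
    def g_flat(x_flat, *args):
        batch = x_flat.shape[0]
        x = x_flat.reshape(batch, hidden_size, seq_len)
        return g(x, *args).reshape(batch, hidden_size * seq_len)
    return g_flat


def layer(x, w):  # toy stand-in for a network layer (hypothetical)
    return jnp.tanh(jnp.einsum("hk,bks->bhs", w, x)) - x


w = 0.1 * jnp.ones((4, 4))
x = jnp.ones((2, 4, 5))
g_flat = as_flat_fn(layer, hidden_size=4, seq_len=5)
out = g_flat(x.reshape(2, 20), w)
print(out.shape)  # (2, 20)
```

The solver then only needs to support a single 2-D layout, and the layer keeps its natural shape internally.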

This should be an easy change: I'd start with the linear algebra inside matvec and rmatvec, which hard-code the shape. The rest, I believe, is plumbing.
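To make that concrete, a sketch of matvec and rmatvec for the proposed (batch_size, feature_size) layout. It assumes the inverse-Jacobian estimate is stored in low-rank form as -I + U V^T, with Us of shape (batch, feature, m) and VTs of shape (batch, m, feature), where m is the number of stored updates; this layout is an assumption and may not match deq-jax's internals.

```python
import jax
import jax.numpy as jnp


def matvec(Us, VTs, x):
    """Compute (-I + U V^T) x per batch element.

    Us: (batch, feature, m), VTs: (batch, m, feature), x: (batch, feature).
    """
    VTx = jnp.einsum("bmd,bd->bm", VTs, x)
    return -x + jnp.einsum("bdm,bm->bd", Us, VTx)


def rmatvec(Us, VTs, x):
    """Compute x^T (-I + U V^T) per batch element (same shapes as matvec)."""
    xTU = jnp.einsum("bd,bdm->bm", x, Us)
    return -x + jnp.einsum("bm,bmd->bd", xTU, VTs)


# Check against the explicit matrices for a small random example.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
Us = jax.random.normal(k1, (3, 6, 2))
VTs = jax.random.normal(k2, (3, 2, 6))
x = jax.random.normal(k3, (3, 6))
M = -jnp.eye(6) + Us @ VTs  # (3, 6, 6): the dense matrices the low-rank form represents
```

Only the einsum subscripts encode the layout, so supporting a different fixed shape mostly means rewriting these strings.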

I'd also be interested to see whether this affects performance. XLA can make reshapes very cheap, but I'd assume the extra flexibility comes at the cost of some added computation!

akbir, Dec 12 '20 14:12