
[Question] Using eqx.Linear with inferred input size?

Open • phinate opened this issue 1 year ago • 1 comment

Hi Patrick! I'm currently converting a Haiku codebase to use equinox, and there's a lot of boilerplate that tries to construct MLPs before knowing the full size of the input data.

Without going into too many details, there's a lot of scaffolding involved in embedding these MLPs in a graph structure; the codebase does all of this scaffolding at init time, before the shapes of the input features are known. For Haiku linear layers that's not a problem: hk.Linear only needs an output size, and the input size is inferred from the first input at runtime. In other frameworks (such as Equinox), the input size has to be known when the layer is constructed.

I have the gut feeling that we're starting to get into the territory of dynamic shapes under jit, but I was wondering if you'd have any advice on the following options:

  • Creating a new MLP every time the main network is called, such that the shapes are correctly inferred? (feels sub-optimal, especially since the input shape will be the same for every runtime instance of the MLP)
  • Patch eqx.nn.Linear somehow to infer the input shape at runtime, when the MLP object has already been instantiated?
  • Make a PR to equinox to allow input sizes to be inferred, provided this doesn't invoke any funny dynamic shape business?

Looking forward to your thoughts!

phinate • Nov 28 '23 14:11

Heyhey! It's great to hear from you. Okay, so I don't think this has anything to do with dynamic shapes under JIT or anything like that.

So this kind of thing just isn't really supported in JAX, and it's one of the things that always made frameworks like Haiku a bit of an odd fit. You need to initialise your parameters before runtime! You can't really have JIT'd regions create parameters for you on the fly: a JIT'd region is basically just a big block of math compiled in XLA, nothing else.

More generally, it's not really clear how to make this kind of "deferred building of the computation graph" work with transforms that manipulate the computation graph, like jax.vmap or jax.grad, or higher-order operations like jax.lax.scan or diffrax.diffeqsolve.

(TL;DR: this kind of approach tends to work in simple-ish use cases but falls down in more general ones.)

So my first recommendation is really just to try and write things in idiomatic Equinox, re-organising the scaffold as necessary.

But if you want a quick fix, would it be possible to just wrap the whole init-time procedure in a function that takes some example args (and extracts the necessary input size from their shape)? Then perform your initialisation just before your first run, when you have some example arguments available.

patrick-kidger • Nov 28 '23 15:11