[Question] Using eqx.Linear with inferred input size?

Open • phinate opened this issue 7 months ago • 1 comment

Hi Patrick! I'm currently converting a Haiku codebase to use Equinox, and it contains a lot of boilerplate that constructs MLPs before the full size of the input data is known.

Without going into too many details, there's a lot of scaffolding involved in embedding these MLPs in a graph structure, and the codebase does all of this scaffolding at init time, before the shapes of the input features are known. For Haiku linear layers, that's not a problem: by default, the input size is inferred from the data the first time the layer is called. For other frameworks (such as Equinox), the input size needs to be known when the object is instantiated.

I have the gut feeling that we're starting to get into the territory of dynamic shapes under jit, but I was wondering if you'd have any advice on the following actions:

  • Create a new MLP every time the main network is called, so that the shapes are inferred correctly? (This feels sub-optimal, especially since the input shape will be the same for every runtime instance of the MLP.)
  • Patch eqx.nn.Linear somehow to infer the input shape at runtime, when the MLP object is already instantiated?
  • Open a PR against Equinox to allow input sizes to be inferred, provided this doesn't invoke any funny dynamic shape business?

Looking forward to your thoughts!

phinate, Nov 28 '23 14:11