[Question] Using eqx.nn.Linear with inferred input size?
Hi Patrick! I'm currently converting a Haiku codebase to Equinox, and there's a lot of boilerplate that tries to construct MLPs before the size of the input data is known.
Without going into too many details, embedding these MLPs in a graph structure involves a lot of scaffolding, and the codebase does all of it at init time, before the shapes of the input features are known. For Haiku linear layers that's not a problem: hk.Linear only takes an output size and infers the input size from the first input it sees when its parameters are initialised. Other frameworks (such as Equinox) need the input shape to be known when the layer object is instantiated.
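To make the contrast concrete, here is a minimal pure-JAX sketch of the two styles. It is not Haiku's or Equinox's actual code: the parameter dict, the function names and the uniform init with bound 1/sqrt(in_size) are illustrative assumptions. The eager initializer needs in_size up front, as eqx.nn.Linear(in_features, out_features, key=...) does, while the lazy variant reads it off an example input, roughly as hk.Linear does during init.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_size, out_size):
    """Eager style (like eqx.nn.Linear): in_size must be known up front."""
    w_key, b_key = jax.random.split(key)
    # Illustrative uniform init with bound 1/sqrt(in_size); an assumption, not either library's code.
    bound = 1.0 / jnp.sqrt(in_size)
    w = jax.random.uniform(w_key, (out_size, in_size), minval=-bound, maxval=bound)
    b = jax.random.uniform(b_key, (out_size,), minval=-bound, maxval=bound)
    return {"w": w, "b": b}


def init_linear_from_example(key, example_x, out_size):
    """Lazy style (like hk.Linear): in_size is read off an example input."""
    return init_linear(key, example_x.shape[-1], out_size)


def apply_linear(params, x):
    # Features live on the last axis; works for a single vector or a batch.
    return x @ params["w"].T + params["b"]


key = jax.random.PRNGKey(0)
x = jnp.ones((4, 5))  # batch of 4; the feature size 5 is only known here
params = init_linear_from_example(key, x, out_size=3)
y = apply_linear(params, x)
print(params["w"].shape, y.shape)  # (3, 5) (4, 3)
```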
I have a gut feeling that we're getting into the territory of dynamic shapes under jit, but I was wondering if you had any advice on the following options:
- Creating a new MLP every time the main network is called, so that the shapes are inferred correctly? (This feels suboptimal, especially since the input shape will be the same every time the MLP is called.)
- Patching eqx.nn.Linear somehow so that it infers the input shape at runtime, after the MLP object has already been instantiated?
- Making a PR to Equinox that allows input sizes to be inferred, provided this doesn't involve any funny dynamic-shape business?
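For comparison, a middle ground sketched under stated assumptions: keep the scaffolding at init time, but hand it an example input (or only its shape). Shapes are static under jit, so jax.eval_shape can work out the feature size without running any computation, and each layer is then built once with a concrete integer. Here node_features is a hypothetical stand-in for the graph scaffolding, and a plain parameter dict stands in for an Equinox module to keep the sketch dependency-free; in Equinox the inferred integer would be passed as in_features to eqx.nn.Linear (or as in_size to eqx.nn.MLP).

```python
import jax
import jax.numpy as jnp


def node_features(x):
    # Hypothetical stand-in for the graph scaffolding that produces MLP inputs;
    # it concatenates x with its square, doubling the feature size.
    return jnp.concatenate([x, x ** 2], axis=-1)


def build_head(key, example_x, out_size):
    # jax.eval_shape traces node_features abstractly: nothing is computed,
    # it returns a ShapeDtypeStruct whose .shape is a tuple of Python ints.
    feat = jax.eval_shape(node_features, example_x)
    in_size = feat.shape[-1]
    w = jax.random.normal(key, (out_size, in_size)) / jnp.sqrt(in_size)
    return {"w": w}


@jax.jit
def forward(params, x):
    # Inside jit, x.shape is still static, so no shape here depends on data values.
    return node_features(x) @ params["w"].T


key = jax.random.PRNGKey(0)
example_x = jax.ShapeDtypeStruct((8, 3), jnp.float32)  # shape and dtype only, no data
params = build_head(key, example_x, out_size=2)
print(params["w"].shape)  # (2, 6)
print(forward(params, jnp.ones((8, 3))).shape)  # (8, 2)
```

Because in_size comes out as a plain Python int, no dynamic shapes are involved; the trade-off is that the scaffolding needs an example input (or a ShapeDtypeStruct) at construction time.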
Looking forward to your thoughts!