einops
[Feature] Jax/Flax Layers, especially EinMix?
The EinMix layer looks great, but it can't be used with JAX since it's not a function. It would be great to have the EinMix layer in a JAX-based framework like Flax.
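For context, JAX transformations such as `jax.jit` and `jax.grad` work on pure functions whose parameters are explicit arguments, while EinMix is a layer that owns its weights. The sketch below assumes EinMix's pattern `'b t c -> b t d'` with weight shape `'c d'` and bias shape `'d'` boils down to an einsum plus a bias. It writes that computation as a pure function; `einmix_apply` and `init_params` are hypothetical helpers, not part of einops.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not part of einops): the computation of an EinMix-style
# layer with pattern 'b t c -> b t d', weight 'c d' and bias 'd', written as a
# pure function. Parameters are explicit inputs, as JAX transformations expect.
def einmix_apply(params, x):
    y = jnp.einsum("btc,cd->btd", x, params["weight"])  # mix channels c -> d
    return y + params["bias"]  # bias of shape (d,) broadcasts over b and t


# Hypothetical initializer: scaled normal weight, zero bias.
def init_params(key, c_in, c_out):
    return {
        "weight": jax.random.normal(key, (c_in, c_out)) / jnp.sqrt(c_in),
        "bias": jnp.zeros((c_out,)),
    }


params = init_params(jax.random.PRNGKey(0), c_in=8, c_out=16)
x = jnp.ones((2, 5, 8))  # (batch, tokens, channels)

y = jax.jit(einmix_apply)(params, x)
print(y.shape)  # (2, 5, 16)

# Because the weights are ordinary arguments, jax.grad can differentiate w.r.t. them.
grads = jax.grad(lambda p: einmix_apply(p, x).sum())(params)
```

A layer system like Flax is what bundles this parameter bookkeeping (creation, storage, passing) back into a module, which is why a Flax EinMix layer is the natural ask here.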
Hi @lkhphuc, the situation is:
- I'd like to have JAX layers, but realistically I can support one or maybe two layer systems, no more.
- there are a dozen (maybe more) layer systems for JAX, with different levels of stability and features, which calls into question the value of supporting any specific one.
If you've settled on Flax, I'm ready to accept a PR for it. It shouldn't take much: https://github.com/arogozhnikov/einops/blob/master/einops/layers/torch.py#L30-L62
I made a quick attempt, but found that Flax currently has a bug that prevents multiple inheritance with a Mixin (https://github.com/google/flax/discussions/1390). I'll try a different approach later.
It's simple enough to use einops with Flax. The solution I worked out after raising https://github.com/google/flax/discussions/1390 was to create a separate Flax nn.Module specifically for the layer, which holds the mixin instead of inheriting from it:
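import flax.linen as nn
from einops.layers import RearrangeMixin  # the class defined in einops/layers/__init__.py


class Rearrange(nn.Module):
    """
    Flax Module to act as a Rearrange layer (from einops)
    """
    pattern: str

    def setup(self):
        # Composition instead of inheritance: hold a RearrangeMixin instance
        self.rearranger = RearrangeMixin(self.pattern)

    # No submodules or parameters are created here, so @nn.compact is not needed
    def __call__(self, input):
        return self.rearranger._apply_recipe(input)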
@arogozhnikov can I open a PR with this structure?
Gentle ping @arogozhnikov
@SauravMaheshkar yes, sure
Added an experimental version of Flax layers; feedback is welcome:
pip install einops==0.5.0