einops

[Feature] JAX/Flax layers, especially EinMix?

Open • lkhphuc opened this issue 3 years ago • 6 comments

The EinMix layer looks great, but it cannot be used with JAX because it is a layer, not a function. It would be great to have EinMix in a JAX-based framework like Flax.

lkhphuc • Dec 05 '21

Hi @lkhphuc, the situation is:

  • I'd like to have JAX layers, but realistically I can support one, maybe two, layer systems, no more.
  • There are a dozen (maybe more) layer systems for JAX, with different levels of stability and features, which calls into question the value of supporting any specific one.

If you're settled on Flax, I'm ready to accept a PR for it. It shouldn't take much; the PyTorch layers are a good reference: https://github.com/arogozhnikov/einops/blob/master/einops/layers/torch.py#L30-L62

arogozhnikov • Dec 06 '21

I made a quick attempt, but found that Flax currently has a bug that prevents multiple inheritance with a Mixin class (https://github.com/google/flax/discussions/1390). I will try a different approach later.

lkhphuc • Dec 06 '21

It's simple enough to use einops with Flax. The solution I worked out after raising https://github.com/google/flax/discussions/1390 was to create a separate Flax nn.Module for the layer, which holds an einops mixin instead of inheriting from it:

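```python
import flax.linen as nn
from einops.layers import RearrangeMixin  # defined in einops/einops/layers/__init__.py


class Rearrange(nn.Module):
    """
    Flax Module to act as a Rearrange layer (from einops)
    """

    pattern: str

    def setup(self):
        # Composition instead of inheritance: hold a RearrangeMixin instance,
        # which sidesteps the Flax multiple-inheritance problem.
        self.rearranger = RearrangeMixin(self.pattern)

    # No @nn.compact: state is set up in setup(), and the layer has no parameters.
    def __call__(self, x):
        # _apply_recipe is a private einops helper and may change between releases.
        return self.rearranger._apply_recipe(x)
```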

SauravMaheshkar • Mar 06 '22

@arogozhnikov, can I open a PR with this structure?

SauravMaheshkar • Mar 09 '22

Gentle ping @arogozhnikov

SauravMaheshkar • Jun 23 '22

@SauravMaheshkar yes, sure

arogozhnikov • Jun 27 '22

Added an experimental version of Flax layers; feedback is welcome: pip install einops==0.5.0

arogozhnikov • Oct 03 '22