
Redesign lstm

Status: Open · NeilGirdhar opened this issue 3 years ago · 3 comments

Redesign the recurrent module classes so that initialization happens in one place rather than two. This makes dtype, shape, and batch-dimension initialization simpler and less error-prone. Currently, the user has to make sure that dtypes, shapes, and batch dimensions match between carry initialization and iteration. With this change, all initialization parameters are object attributes, initialize_carry is the only method passed to init, and __call__ is only passed to apply.

This probably needs more discussion, but the goal is to complete the fix for the dtype assumption problem described in the FLIP.

Fixes #1776

Checklist

  • [x] This change is discussed in a GitHub issue/discussion (please add a link).
  • [x] The documentation and docstrings adhere to the documentation guidelines.
  • [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)

NeilGirdhar, Feb 16 '22 12:02

@marcvanzee Thanks for the review!

You asked: "In order to make your PR a bit easier to review, do you think it is possible to keep your change as self-contained as possible?"

I'm not sure whether you noticed, but this change was simply rebased onto https://github.com/google/flax/pull/1803, because it depends on that PR. (You could always run git diff HEAD~1 to see only these changes, but I understand that's annoying.)

To make your review easier, I've rebased it onto main. However, _canonicalize_dtype won't be defined, since it comes from #1803 (defined here).

NeilGirdhar, Feb 20 '22 13:02

Hey @NeilGirdhar! I became interested in this when I started working on #2126.

I read through it a bit and agree that making initialize_carry a regular method and passing it the inputs can help determine the proper dtype according to the recent FLIP. However, this PR introduces a breaking change to initialize_carry, which may be a tough sell.

Alternatively we could create a new initialize_carry_with_inputs that is a regular method so it has access to self.dtype, and doesn't introduce a breaking change.

Things I am not too sure of in this PR:

  • What are the implications of initializing the whole module through this method?
  • Should it use self.make_rng, or take an explicit rng argument?

Maybe @jheek / @levskaya could comment a bit about the design choices of the current API.

cgarciae, May 17 '22 18:05

@cgarciae The motivation for this pull request was to avoid repeating initialization parameters, which are currently passed in two places:

  • as parameters to the static method initialize_carry, and
  • as member variables visible to the Flax method __call__.

Passing repeated parameters is poor design because it relies on the caller to pass the same values in both places. Checking that the same values were passed both times would be only a minor improvement, and it isn't easy to do.

The idea of this pull request is to have all of the parameters in one place: as member variables, which are visible to the Flax methods initialize_carry (no longer static) and __call__. This is how nearly all other Flax modules accept parameters.

I suspect it was done this way to keep the pattern in which Flax modules are initialized by calling init on the __call__ method. Once that was decided, the carry initializer had to be a static method, since it must run first. This pull request uses a different pattern: initialize_carry is the initializer (called with init), and __call__ is the update method (called with apply).

This was originally part of the dtype pull request because dtypes are one of the repeated parameters that I wanted to consolidate.

Good luck!

@cgarciae wrote: "Alternatively we could create a new initialize_carry_with_inputs that is a regular method so it has access to self.dtype, and doesn't introduce a breaking change."

That's a great idea!

NeilGirdhar, May 17 '22 19:05