
Add RNN

Open wdphy16 opened this issue 3 years ago • 4 comments

RNNs are a subclass of ARNNs, and they reuse ARNN's sampler. They support complex parameters using the same normalization as ARNNs, and I'll leave the mod-phase version to a more general implementation of ModPhase.

When using an RNN on a 2D (or higher-dimensional) lattice, we usually need a non-trivial autoregressive order (such as the snake order) to exploit the locality of the lattice. For a 1D RNN this is straightforward to implement, because the RNN cell only needs to access the previous site in the autoregressive order, without knowing the lattice geometry. The 1D RNN currently supports arbitrary orderings.

However, for a 2D RNN the ordering is complicated to implement, because the RNN cell needs to access the previous spatial neighbors, which are defined by both the ordering ('previous') and the lattice geometry. Currently the 2D RNN only supports the snake ordering on a square lattice (and there is actually no check that the lattice is square), which is hard-coded in nk.nn.rnn_2d._get_h_xy and nk.models.rnn.LSTMNet2D.setup. If we really want to support arbitrary orderings, we need a fast and easy-to-use way to access the previous spatial neighbors given the ordering and the lattice, and we must ensure that the number of neighbors is the same at every step (if it is not, that's a job for a graph RNN).
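To make the 'previous spatial neighbors' problem concrete, here is a minimal sketch of the snake ordering on an L x L square lattice and of each site's earlier-in-order neighbors. The function names and the pad-to-2-with--1 convention are my own for illustration, not NetKet's API:

```python
import numpy as np

def snake_order(L):
    """Snake (boustrophedon) ordering of an L x L grid: order[k] is the
    flat site index visited at autoregressive step k; even rows go
    left-to-right, odd rows right-to-left."""
    order = []
    for y in range(L):
        xs = range(L) if y % 2 == 0 else range(L - 1, -1, -1)
        order.extend(y * L + x for x in xs)
    return np.array(order)

def prev_spatial_neighbors(L, order):
    """For each site, the lattice neighbors visited *earlier* in the
    given order, padded with -1 so every site has the same count
    (under the snake order a site has at most 2 earlier neighbors)."""
    inv = np.empty(L * L, dtype=int)
    inv[order] = np.arange(L * L)  # inv[site] = step at which site is visited
    neighbors = []
    for site in range(L * L):
        y, x = divmod(site, L)
        prev = []
        for ny, nx in [(y, x - 1), (y, x + 1), (y - 1, x), (y + 1, x)]:
            if 0 <= ny < L and 0 <= nx < L:
                n = ny * L + nx
                if inv[n] < inv[site]:
                    prev.append(n)
        prev += [-1] * (2 - len(prev))  # pad missing neighbors with -1
        neighbors.append(prev)
    return np.array(neighbors)
```

An arbitrary-ordering implementation would take such a table as input instead of hard-coding the snake pattern.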

The methods reorder and inverse_reorder have been added to AbstractARNN because ARDirectSampler needs them, but currently there is no way to specify the ordering for the previously existing ARNNs.
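In terms of a site permutation, reorder and inverse_reorder can be understood as applying the permutation and its inverse along the site axis. A minimal numpy sketch (illustrative semantics, not NetKet's exact signatures):

```python
import numpy as np

# Illustrative stand-ins for reorder / inverse_reorder, where order[k]
# is the site visited at autoregressive step k.

def reorder(inputs, order):
    """Map an array from lattice-site order to autoregressive order."""
    return inputs[..., order]

def inverse_reorder(inputs, order):
    """Map an array from autoregressive order back to site order."""
    inv = np.empty_like(order)
    inv[order] = np.arange(order.size)  # inverse permutation
    return inputs[..., inv]
```

The two are exact inverses, which is what the sampler relies on when it generates sites in autoregressive order but returns configurations in site order.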

wdphy16 avatar Aug 26 '22 12:08 wdphy16

Codecov Report

Attention: 35 lines in your changes are missing coverage. Please review.

Comparison is base (202cab6) 82.82% compared to head (6c55ed7) 83.10%.

:exclamation: Current head 6c55ed7 differs from pull request most recent head 671ec10. Consider uploading reports for the commit 671ec10 to get more accurate results

| Files | Patch % | Lines |
|-------|---------|-------|
| netket/experimental/nn/rnn/ordering.py | 80.99% | 11 Missing and 12 partials :warning: |
| netket/models/autoreg.py | 68.42% | 3 Missing and 3 partials :warning: |
| netket/utils/array.py | 70.00% | 3 Missing :warning: |
| netket/experimental/nn/rnn/layers.py | 96.96% | 1 Missing and 1 partial :warning: |
| netket/sampler/autoreg.py | 90.90% | 1 Missing :warning: |
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1305      +/-   ##
==========================================
+ Coverage   82.82%   83.10%   +0.27%     
==========================================
  Files         279      287       +8     
  Lines       16913    17311     +398     
  Branches     3279     3334      +55     
==========================================
+ Hits        14009    14387     +378     
- Misses       2280     2291      +11     
- Partials      624      633       +9     

:umbrella: View full report in Codecov by Sentry.

codecov[bot] avatar Aug 26 '22 13:08 codecov[bot]

I guess the type object RNN.Layer is not hashable in Python 3.7. Let me try to use a method instead...

wdphy16 avatar Aug 26 '22 13:08 wdphy16

I think there should be a way to make it work. Check how flax implements the Sequential model

PhilipVinc avatar Aug 26 '22 13:08 PhilipVinc

That's different. The arguments of Sequential are module instances, not module types, but here I wanted to reuse some code by passing only the type.

wdphy16 avatar Aug 26 '22 13:08 wdphy16

I've added a way to specify prev_neighbors in RNN, and to generate prev_neighbors from any graph in models/rnn.py. Therefore, LSTMNet supports any geometry, and there is no need to implement LSTMNet1D and LSTMNet2D separately. I didn't see any performance regression in some benchmarks. What do you think?
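The idea of generating prev_neighbors from any graph can be sketched as follows, using a plain edge list in place of a nk.graph object (the function name and the -1 padding are assumptions for illustration, not the actual models/rnn.py code):

```python
import numpy as np

def get_prev_neighbors(edges, n_sites, order):
    """Given an undirected edge list and an autoregressive order,
    return for each site its graph neighbors that appear earlier in
    the order, padded with -1 to a common width so the result is a
    rectangular array the RNN cell can consume."""
    inv = np.empty(n_sites, dtype=int)
    inv[np.asarray(order)] = np.arange(n_sites)  # site -> step
    adj = [[] for _ in range(n_sites)]
    for i, j in edges:
        adj[i].append(j)
        adj[j].append(i)
    prev = [[n for n in adj[s] if inv[n] < inv[s]] for s in range(n_sites)]
    width = max((len(p) for p in prev), default=0)
    return np.array([p + [-1] * (width - len(p)) for p in prev])
```

Because the neighbor table is derived from the graph, the same network class works for a chain, a square lattice with snake ordering, or any other geometry.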

wdphy16 avatar Nov 07 '22 16:11 wdphy16

@wdphy16 did you see how they do RNNs in flax with the lifted transforms? https://github.com/google/flax/pull/2604/files

I think this is much cleaner, don't you think?

PhilipVinc avatar Dec 01 '22 18:12 PhilipVinc

Hi @PhilipVinc, I've read Flax's RNN implementation, but it turns out we cannot reuse their code.

  1. They have reverse, which is mainly used in 1D bidirectional RNNs, while we need to do reordering in 2D (mainly the snake ordering, but arbitrary orderings are also supported). The bidirectional RNN is not autoregressive, so we cannot implement it in this RNN class, although it could be implemented in a higher-level class. There are even more ways to aggregate different orderings in 2D, such as leftward-rightward and quad-directional ones. They may or may not be autoregressive, and I think they are out of the scope of this RNN class.

  2. They have initial_carry and init_key, but in NetKet we have no way to pass a PRNG key to model.__call__. (Maybe that has some use in disordered systems, but I haven't thought about it yet.) For the purposes of this RNN class we just initialize the carry to zeros. It would be possible to initialize the carry to a constant vector or make it trainable, but I'd like to leave that until someone actually needs it.

  3. They have seq_lengths and some logic for end-of-sequence padding, but I think that is not used in NQS, and we always assume that there are N spatial sites in the inputs.

  4. The multiple previous neighbors at each site also prevent us from reusing their RNN cells. In our default LSTM cell there is some logic to concatenate the neighbors, and the user may define other ways to handle the different neighbors.

  5. In the RNN layer I tried to use Flax's lifted scan, but that is definitely not cleaner and imposes a huge mental burden on me. Currently, in scan_func inside RNNLayer.__call__, we need to first handle the ordering and then call the RNN cell. I'd like to put the ordering logic in the RNN layer rather than in every RNN cell, but it seems the lifted scan does not work well with this kind of closure. Also, in ARDirectSampler there is another scan...

  6. We need FastRNNLayer.update_site for fast AR sampling.

  7. Class names: they use RNN for an RNN layer, while in NetKet we use a network name like LSTMNet or ARNNDense for a model, and we explicitly write RNNLayer for a layer.

The current RNN class is enough to implement Mohamed's paper without the phase part. To implement my tensor-RNN paper we still need some changes to how the cache is initialized: some variables are independent of the model parameters (like carry), while others depend on the parameters (like gamma in the tensor-RNN) and must be initialized after Module.setup(). We can discuss that design after finishing this PR.
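For point 4 above, the 'concatenate the neighbors' idea can be sketched as a single LSTM step whose recurrent input stacks the hidden states of all previous neighbors. The weight shapes and the summed cell state are illustrative choices of mine, not the actual NetKet cell:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_multi_neighbor(x, neighbor_hs, neighbor_cs, W, b):
    """One LSTM step with multiple previous neighbors (a sketch).

    x:           (d_in,)  input at this site
    neighbor_hs: list of (d_h,) hidden states of the previous neighbors,
                 concatenated into the recurrent input
    neighbor_cs: list of (d_h,) cell states, summed into the new cell
    W:           (4*d_h, d_in + n_neighbors*d_h) gate weights
    b:           (4*d_h,) gate biases
    """
    z = np.concatenate([x] + list(neighbor_hs))
    gates = W @ z + b
    i, f, g, o = np.split(gates, 4)  # input, forget, candidate, output
    c = sigmoid(f) * sum(neighbor_cs) + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c
```

A standard Flax LSTMCell takes exactly one previous carry, which is why it cannot be reused directly here.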

wdphy16 avatar Oct 10 '23 10:10 wdphy16

Ok I've made the changes, mostly to the docs

wdphy16 avatar Oct 23 '23 13:10 wdphy16

Now we've found out how to use the lifted scan, and I think the code is clearer than before

Surprisingly, the plain jax scan in ARDirectSampler just works without lifting

I'll check again whether there is a performance regression compared to my rnn_old branch
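The closure pattern from the earlier discussion — keep the ordering logic in the layer's scan body, keep the cell ordering-agnostic — can be illustrated with a pure-Python stand-in for the scan (the names here are illustrative, not the actual RNNLayer code):

```python
import numpy as np

def scan(f, init_carry, xs):
    """A minimal pure-Python stand-in for a scan primitive:
    threads a carry through f over xs and stacks the outputs."""
    carry, ys = init_carry, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def make_scan_func(cell, inputs, order):
    """The layer builds the scan body as a closure: it maps the step
    index to a site via `order` and gathers that site's input, so the
    cell itself never needs to know the lattice geometry."""
    def scan_func(carry, k):
        site = order[k]                         # ordering handled in the layer
        carry, out = cell(carry, inputs[site])  # geometry-free cell
        return carry, out
    return scan_func
```

With a lifted scan the closure must instead be expressed through the scan's explicit carried and broadcast arguments, which is the friction described above.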

wdphy16 avatar Nov 13 '23 23:11 wdphy16

I did some profiling of 1D and 2D LSTMs, and I didn't see any performance regression on my laptop (CPU) or my workstation (GPU)

wdphy16 avatar Nov 14 '23 07:11 wdphy16

great, good to know!

PhilipVinc avatar Nov 15 '23 01:11 PhilipVinc