
PyTorch implementation of Markov RNNs

Markov Recurrent Neural Networks

This repository is the PyTorch implementation of Markov Recurrent Neural Networks, with two temporal datasets as quick demonstrations.

Architecture

  • Heuristic explanation: MRNN is a deep learning model for time series, such as NLP, stock price prediction, or gravitational wave detection. The main idea is to run several parallel RNNs (LSTMs) that learn the temporal dependencies of the data simultaneously. If the data has complex temporal structure, a single RNN may not be enough to capture the pattern. k parallel RNNs (k = 1, 2, 3, ...) read the same input signal at the same time, each learning a different characteristic of the data. A latent variable z (also trained by the network) then determines when and which LSTM should be listened to for the learning task (see the q_z(t) and z(t) figures below). The selection mechanism driven by z is a stochastic model of transitions between the k LSTMs based on the Markov property, hence the name MRNN.

  • Note: The transition variable z between the k LSTMs can be regarded as an attention mechanism over the individual LSTM hidden states.
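The two bullets above can be condensed into a minimal sketch: k parallel LSTM cells run side by side, a small transition network produces q_z(t) from the current input and previous hidden state, and a Gumbel-softmax sample z(t) picks which cell's state to keep. The class name, layout, and sizes here are illustrative assumptions, not the repository's actual kLSTM API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MarkovRNNSketch(nn.Module):
    """Illustrative sketch of the Markov RNN idea, not the repo's kLSTM module:
    k parallel LSTM cells plus a transition network that stochastically selects
    which cell's state to carry forward at each time step."""

    def __init__(self, input_size, hidden_size, k):
        super().__init__()
        self.k = k
        self.cells = nn.ModuleList(
            nn.LSTMCell(input_size, hidden_size) for _ in range(k)
        )
        # Transition network: (current input, previous hidden) -> logits over k cells
        self.transition = nn.Linear(input_size + hidden_size, k)

    def forward(self, x, tau=1.0):
        # x: (batch, time, input_size)
        B, T, _ = x.shape
        H = self.cells[0].hidden_size
        h = x.new_zeros(B, H)
        c = x.new_zeros(B, H)
        q_z, z = [], []
        for t in range(T):
            xt = x[:, t]
            logits = self.transition(torch.cat([xt, h], dim=-1))
            q = F.softmax(logits, dim=-1)                       # q_z(t): soft transition probs
            zt = F.gumbel_softmax(logits, tau=tau, hard=True)   # z(t): one-hot sample
            # Run all k cells in parallel, then let z select one state
            hs, cs = zip(*(cell(xt, (h, c)) for cell in self.cells))
            h = (zt.unsqueeze(-1) * torch.stack(hs, dim=1)).sum(dim=1)
            c = (zt.unsqueeze(-1) * torch.stack(cs, dim=1)).sum(dim=1)
            q_z.append(q)
            z.append(zt)
        return h, torch.stack(q_z, dim=1), torch.stack(z, dim=1)
```

The hard Gumbel-softmax sample keeps the forward pass discrete (one LSTM "listened to" per step) while staying differentiable through the soft probabilities, which is what lets the transition variable be trained end to end.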

Datasets

  • MNIST images viewed as sequential input:

  • Artificial alien signals: imagine that, as in SETI, we are able to recognize radio signals sent by aliens from the sky; I generated two kinds of waveforms for the Markov RNN to distinguish:
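The notebook's actual signal generator is not shown here, but a minimal stand-in could produce two distinguishable waveform classes like so: a smooth, roughly periodic "non-alien" sinusoid versus an "alien" signal with irregular bursts injected. The function name and parameters are illustrative assumptions.

```python
import numpy as np

def make_signal(alien: bool, length: int = 200, rng=None):
    """Illustrative stand-in for the notebook's generator (assumed, not the
    repo's code): non-alien signals are noisy sinusoids; 'alien' signals add
    a few irregular bursts that break the periodicity."""
    rng = rng if rng is not None else np.random.default_rng()
    t = np.linspace(0, 4 * np.pi, length)
    x = np.sin(t) + 0.1 * rng.standard_normal(length)
    if alien:
        # Inject a few bursts at random positions with random frequencies
        for _ in range(3):
            start = int(rng.integers(0, length - 20))
            x[start:start + 20] += np.sin(rng.uniform(5.0, 15.0) * t[:20])
    return x
```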

Prerequisites

Usage

No installation is required beyond the prerequisites. The file kLSTM.py contains all the modules needed to run the Markov RNN. Two examples are provided as Jupyter notebooks:

MRNN_MNIST.ipynb
MRNN_detect_alien_signal.ipynb

Results & Interpretations

  1. Take k=4 LSTMs for MNIST.
  • In the figure of q_z(t), the horizontal axis is time and the vertical axis shows which LSTM is attended to. The color palette indicates the probability of each LSTM being used.

  • In the figure of z(t), yellow indicates which LSTM was actually used (sampled via Gumbel softmax).

  • Overall probability of each LSTM being chosen. [Left]: digit 7, [right]: digit 9
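The "overall probability" bars above can be recovered from the per-step transition probabilities by averaging over time (and over examples of a given class). A small sketch, assuming q_z is the (batch, time, k) tensor of soft transition weights produced by the model:

```python
import torch

# Hypothetical q_z tensor of shape (batch, time, k), standing in for the
# per-step transition probabilities the model outputs; each row sums to 1.
torch.manual_seed(0)
q_z = torch.rand(8, 28, 4)
q_z = q_z / q_z.sum(dim=-1, keepdim=True)

# "Overall probability of each LSTM being chosen": average over batch and time.
overall = q_z.mean(dim=(0, 1))   # shape (k,); still sums to 1
```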

  2. Take k=4 LSTMs for alien-signal (binary) classification
  • Non-alien signal

  • The q_z(t) and z(t) figures show slight periodicity

  • Alien signal (perhaps a greeting, perhaps an invasion)

  • The q_z(t) and z(t) figures show all 4 LSTMs being used to detect the irregular waveform from the aliens.

  • Overall probability of each LSTM being chosen. [Left]: non-alien signal, [right]: alien signal

Improvements

This PyTorch code is extended so that each LSTM in the Markov RNN can have more than one hidden layer, whereas the original TensorFlow implementation was restricted to a single hidden layer per LSTM.
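In PyTorch, the extension amounts to stacking layers within each of the k parallel recurrent tracks, e.g. via the num_layers argument of nn.LSTM. A minimal sketch with illustrative sizes (not the repository's actual configuration):

```python
import torch
import torch.nn as nn

# Each of the k parallel tracks is a stacked LSTM with num_layers > 1,
# instead of a single-layer cell. All sizes here are illustrative.
k, input_size, hidden_size, num_layers = 4, 8, 16, 2
tracks = nn.ModuleList(
    nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
    for _ in range(k)
)

x = torch.randn(3, 10, input_size)       # (batch, time, features)
outs = [lstm(x)[0] for lstm in tracks]   # each: (batch, time, hidden_size)
```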