tf-seq2seq
Sequence to sequence learning using TensorFlow.
The core building blocks are RNN encoder-decoder architectures and attention mechanisms.
The package is largely implemented with the tf.contrib.seq2seq modules of TensorFlow 1.2 (the latest release at the time of writing):
- AttentionWrapper
- Decoder
- BasicDecoder
- BeamSearchDecoder
The package supports:
- Multi-layer GRU/LSTM
- Residual connections
- Dropout
- Attention and input feeding
- Beam search decoding
- Writing n-best lists
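The residual connection and dropout options can be pictured with a minimal NumPy sketch. It assumes simple tanh layers as stand-ins for the GRU/LSTM cells, equal input and hidden widths so the residual sum is defined, and dropout applied to each layer's output before the residual addition. `stacked_forward` is a hypothetical name for illustration, not part of the package's API.

```python
import numpy as np


def stacked_forward(x, weights, use_residual=True, dropout_rate=0.3,
                    training=True, rng=None):
    """Pass x through a stack of tanh layers (stand-ins for RNN cells).

    Hypothetical helper: each layer computes tanh(h @ W), applies inverted
    dropout to that output when training, then (if use_residual) adds the
    layer's input back. Residual sums require equal input/output widths.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    h = x
    for w in weights:  # one weight matrix per layer, i.e. depth = len(weights)
        out = np.tanh(h @ w)
        if training and dropout_rate > 0.0:
            keep_prob = 1.0 - dropout_rate
            mask = rng.random(out.shape) < keep_prob
            out = out * mask / keep_prob  # rescale so the expected value is unchanged
        h = h + out if use_residual else out
    return h


# Toy example: batch of 2 inputs, hidden size 4, depth 2, inference mode.
rng = np.random.default_rng(42)
x = rng.standard_normal((2, 4))
weights = [0.1 * rng.standard_normal((4, 4)) for _ in range(2)]
print(stacked_forward(x, weights, training=False).shape)  # (2, 4)
```

With residual connections, a layer only needs to learn a correction to its input, which tends to make deeper stacks easier to train.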
Dependencies
- NumPy >= 1.11.1
- TensorFlow >= 1.2
History
- June 5, 2017: Major update
- June 6, 2017: Supports batched beam search decoding
- June 11, 2017: Separated training and decoding
- June 22, 2017: Supports TensorFlow 1.2 (contrib.rnn -> python.ops.rnn_cell)
Usage Instructions
Data Preparation
To preprocess the raw parallel data in sample_data.src and sample_data.trg, run:
cd data/
./preprocess.sh src trg sample_data ${max_seq_len}
This script performs preprocessing steps widely used in machine translation (MT):
- Normalizing punctuation
- Tokenizing
- Byte-pair encoding (BPE) with 30,000 merge operations (Sennrich et al., 2016)
- Removing sequences longer than ${max_seq_len}
- Shuffling
- Building dictionaries
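The byte-pair encoding step can be illustrated with the merge-learning loop described by Sennrich et al. (2016). The sketch below is a simplified version for illustration only; preprocess.sh relies on the subword-nmt scripts, which handle many more details. Words are stored as space-separated symbols with an end-of-word marker `</w>`, and each iteration merges the most frequent adjacent symbol pair (ties are broken by first occurrence in this sketch). The function names are hypothetical.

```python
import collections
import re


def get_pair_counts(vocab):
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs = collections.Counter()
    for word, freq in vocab.items():
        symbols = word.split()
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    return pairs


def merge_pair(pair, vocab):
    """Replace each whole-symbol occurrence of `pair` with the merged symbol."""
    pattern = re.compile(r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)")
    merged = "".join(pair)
    return {pattern.sub(lambda _m: merged, word): freq for word, freq in vocab.items()}


def learn_bpe(vocab, num_merges):
    """Greedily learn up to num_merges merge operations (ties: first seen)."""
    merges = []
    for _ in range(num_merges):
        pairs = get_pair_counts(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_pair(best, vocab)
        merges.append(best)
    return merges, vocab


# Toy vocabulary: each word split into characters plus an end-of-word marker.
toy_vocab = {"l o w </w>": 5, "l o w e r </w>": 2,
             "n e w e s t </w>": 6, "w i d e s t </w>": 3}
merges, segmented = learn_bpe(toy_vocab, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('est', '</w>')]
```

The learned merges are then applied in order to segment new text; the preprocessing script learns 30,000 of them instead of three.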
Training
To train a seq2seq model, run:
$ python train.py --cell_type 'lstm' \
--attention_type 'luong' \
--hidden_units 1024 \
--depth 2 \
--embedding_size 500 \
--num_encoder_symbols 30000 \
--num_decoder_symbols 30000 ...
Decoding
To decode with a trained model, run:
$ python decode.py --beam_width 5 \
--decode_batch_size 30 \
--model_path $PATH_TO_A_MODEL_CHECKPOINT \
--max_decode_step 300 \
--write_n_best False \
--decode_input $PATH_TO_DECODE_INPUT \
--decode_output $PATH_TO_DECODE_OUTPUT
Here $PATH_TO_A_MODEL_CHECKPOINT is a checkpoint path such as model/translate.ckpt-100. If --beam_width is 1, greedy decoding is performed: the most probable token is selected at each time step.
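Greedy decoding, as used when --beam_width is 1, can be sketched in a few lines. The sketch assumes a hypothetical `step_fn(prev_token, state)` that returns next-token log-probabilities and a new state (standing in for the decoder cell and output projection), plus hypothetical GO/EOS token ids; it is not the package's decode.py logic. Decoding stops at EOS or after `max_decode_step` steps, mirroring the --max_decode_step limit.

```python
import numpy as np

GO_ID, EOS_ID = 0, 1  # hypothetical special-token ids


def greedy_decode(step_fn, init_state, max_decode_step):
    """Greedy decoding: take the argmax token at every time step.

    step_fn(prev_token, state) -> (log_probs, new_state) is a hypothetical
    stand-in for one decoder step (RNN cell + output projection).
    """
    tokens, prev, state = [], GO_ID, init_state
    for _ in range(max_decode_step):
        log_probs, state = step_fn(prev, state)
        prev = int(np.argmax(log_probs))  # feed the best token back as input
        if prev == EOS_ID:
            break
        tokens.append(prev)
    return tokens


# Toy decoder over a 4-token vocabulary: emits 2, then 3, then EOS.
def toy_step(prev, t):
    next_token = {GO_ID: 2, 2: 3, 3: EOS_ID}[prev]
    log_probs = np.full(4, -10.0)
    log_probs[next_token] = 0.0
    return log_probs, t + 1  # the "state" here is just a step counter


print(greedy_decode(toy_step, 0, max_decode_step=300))  # [2, 3]
```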
Arguments
Data params
--source_vocabulary: Path to source vocabulary
--target_vocabulary: Path to target vocabulary
--source_train_data: Path to source training data
--target_train_data: Path to target training data
--source_valid_data: Path to source validation data
--target_valid_data: Path to target validation data
Network params
--cell_type: RNN cell to use for encoder and decoder (default: lstm)
--attention_type: Attention mechanism (bahdanau, luong) (default: bahdanau)
--hidden_units: Number of hidden units in each layer
--depth: Number of layers in each of the encoder and decoder (default: 2)
--embedding_size: Embedding dimensions of encoder and decoder inputs (default: 500)
--num_encoder_symbols: Source vocabulary size to use (default: 30000)
--num_decoder_symbols: Target vocabulary size to use (default: 30000)
--use_residual: Use residual connections between layers (default: True)
--attn_input_feeding: Use the input feeding method in the attentional decoder (Luong et al., 2015) (default: True)
--use_dropout: Apply dropout to RNN cell outputs (default: True)
--dropout_rate: Dropout probability for cell outputs (0.0: no dropout) (default: 0.3)
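The two --attention_type options differ mainly in how a decoder state is scored against each encoder output. The NumPy sketch below shows the textbook forms, assuming Luong's "general" (bilinear) score and Bahdanau's additive score; the parameter names `W_a`, `U_a`, `v_a`, `W_c` are illustrative, and TensorFlow's LuongAttention and BahdanauAttention differ in details (for example, which decoder state is used as the query). It also builds the Luong-style attentional vector that input feeding (--attn_input_feeding) concatenates with the next decoder input.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()


def luong_scores(h_t, enc_outputs, W_a):
    """Luong 'general' score h_t^T W_a h_s for every encoder output h_s."""
    return enc_outputs @ (W_a.T @ h_t)


def bahdanau_scores(h_t, enc_outputs, W_a, U_a, v_a):
    """Bahdanau additive score v_a^T tanh(W_a h_t + U_a h_s)."""
    return np.tanh(h_t @ W_a.T + enc_outputs @ U_a.T) @ v_a


def attend(scores, enc_outputs, h_t, W_c):
    """Alignment weights and Luong's attentional vector tanh(W_c [c_t; h_t])."""
    alignment = softmax(scores)          # weights over source positions
    context = alignment @ enc_outputs    # weighted sum of encoder outputs
    attn_vector = np.tanh(W_c @ np.concatenate([context, h_t]))
    return alignment, attn_vector


rng = np.random.default_rng(0)
d, src_len, emb = 4, 5, 3                # hidden size, source length, embedding size
enc = rng.standard_normal((src_len, d))  # encoder outputs h_s
h_t = rng.standard_normal(d)             # current decoder state
W_c = rng.standard_normal((d, 2 * d))
align, attn_vec = attend(luong_scores(h_t, enc, np.eye(d)), enc, h_t, W_c)

# Input feeding: the next decoder input concatenates the next target-token
# embedding with the current attentional vector.
next_input = np.concatenate([rng.standard_normal(emb), attn_vec])
print(next_input.shape)  # (7,)
```

Input feeding lets the decoder see its previous alignment decision, at the cost of a wider RNN input (embedding size plus hidden size).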
Training params
--learning_rate: Learning rate (default: 0.0002)
--max_gradient_norm: Clip gradients to this norm (default: 1.0)
--batch_size: Batch size
--max_epochs: Maximum number of training epochs
--max_load_batches: Maximum number of batches to prefetch at one time
--max_seq_length: Maximum sequence length
--display_freq: Display training status every this many iterations
--save_freq: Save a model checkpoint every this many iterations
--valid_freq: Evaluate the model every this many iterations (requires validation data)
--optimizer: Optimizer for training (adadelta, adam, rmsprop) (default: adam)
--model_dir: Path to save model checkpoints
--model_name: File name used for model checkpoints
--shuffle_each_epoch: Shuffle the training dataset for each epoch (default: True)
--sort_by_length: Sort pre-fetched minibatches by their target sequence lengths (default: True)
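--max_gradient_norm bounds the size of each update. A common approach, and the formula used by TensorFlow's `tf.clip_by_global_norm`, rescales all gradients jointly when their combined L2 norm exceeds the threshold: g_i * max_norm / max(global_norm, max_norm). Whether train.py uses exactly this variant is an assumption here; the NumPy function below is a hypothetical re-implementation for illustration.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm <= max_norm.

    Follows the tf.clip_by_global_norm formula:
    g_i * max_norm / max(global_norm, max_norm).
    Returns the clipped gradients and the norm before clipping.
    """
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = max_norm / max(global_norm, max_norm)  # 1.0 when already small enough
    return [g * scale for g in grads], global_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # joint norm = 5
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)  # 5.0
```

Scaling all gradients by one shared factor preserves the update direction, unlike clipping each tensor independently.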
Decoding params
--beam_width: Beam width used in beam search (default: 1)
--decode_batch_size: Batch size used in decoding
--max_decode_step: Maximum time step limit in decoding (default: 500)
--write_n_best: Write the beam search n-best list (n=beam_width) (default: False)
--decode_input: Input file path to decode
--decode_output: Output file path for the decoding output
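Beam search keeps the --beam_width highest-scoring partial hypotheses at each step, and the finished hypotheses form the n-best list that --write_n_best can write out. The sketch below is a simplified, single-sentence version: it assumes a hypothetical `step_fn(prefix)` returning next-token log-probabilities, hypothetical GO/EOS ids, summed log-probability scores without a length penalty, and stopping once beam_width hypotheses have finished. TensorFlow's BeamSearchDecoder batches this and differs in such details.

```python
import numpy as np

GO_ID, EOS_ID = 0, 1  # hypothetical special-token ids


def beam_search(step_fn, beam_width, max_decode_step):
    """Return up to beam_width (tokens, score) pairs, best first (an n-best list).

    step_fn(prefix) -> next-token log-probabilities is a hypothetical stand-in
    for the decoder. Scores are summed log-probabilities (no length penalty).
    Search stops once beam_width hypotheses have emitted EOS.
    """
    beams, finished = [([GO_ID], 0.0)], []
    for _ in range(max_decode_step):
        candidates = []
        for prefix, score in beams:
            log_probs = step_fn(prefix)
            for tok in np.argsort(log_probs)[::-1][:beam_width]:
                candidates.append((prefix + [int(tok)], score + float(log_probs[tok])))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_width]:
            if prefix[-1] == EOS_ID:
                finished.append((prefix[1:-1], score))  # drop GO and EOS
            else:
                beams.append((prefix, score))
        if not beams or len(finished) >= beam_width:
            break
    if len(finished) < beam_width:  # fall back to unfinished hypotheses
        finished.extend((prefix[1:], score) for prefix, score in beams)
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_width]


# Toy decoder over tokens {0: GO, 1: EOS, 2, 3}; rows depend on the last token.
TOY_PROBS = {GO_ID: [0.0, 0.0, 0.6, 0.4],
             2: [0.0, 0.2, 0.1, 0.7],
             3: [0.0, 0.9, 0.1, 0.0]}


def toy_step(prefix):
    return np.log(np.array(TOY_PROBS[prefix[-1]]) + 1e-12)


for tokens, score in beam_search(toy_step, beam_width=2, max_decode_step=10):
    print(tokens, round(score, 3))
# [2, 3] -0.973
# [3] -1.022
```

With beam_width=1 the same function reduces to greedy decoding and returns only [2, 3].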
Runtime params
--allow_soft_placement: Allow device soft placement
--log_device_placement: Log placement of ops on devices
Acknowledgements
The implementation is based on the following projects:
- nematus: Theano implementation of neural machine translation; the major reference for this project
- subword-nmt: Included subword-unit scripts to preprocess input data
- moses: Included preprocessing scripts to preprocess input data
- tf.seq2seq_legacy: Legacy TensorFlow seq2seq tutorial
- tf_tutorial_plus: Nice tutorials for the tf.contrib.seq2seq API
For comments and feedback, please email me or open an issue on the project repository.