keras-xlnet
Implementation of XLNet that can load pretrained checkpoints
Keras XLNet
Unofficial implementation of XLNet. The embedding extraction and embedding extraction with memory demos show how to get the outputs of the last transformer layer using pre-trained checkpoints.
Install
pip install keras-xlnet
Usage
Fine-tuning on GLUE
Click a task name to see the demo with the base model:
| Task Name | Metrics | Approximate Results on Dev Set |
|---|---|---|
| CoLA | Matthews Corr. | 52 |
| SST-2 | Accuracy | 93 |
| MRPC | Accuracy/F1 | 86/89 |
| STS-B | Pearson Corr. / Spearman Corr. | 86/87 |
| QQP | Accuracy/F1 | 90/86 |
| MNLI | Accuracy (matched/mismatched) | 84/84 |
| QNLI | Accuracy | 86 |
| RTE | Accuracy | 64 |
| WNLI | Accuracy | 56 |
(Only 0s are predicted on the WNLI dataset.)
Load Pretrained Checkpoints
```python
import os
from keras_xlnet import Tokenizer, load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

checkpoint_path = '.../xlnet_cased_L-24_H-1024_A-16'

tokenizer = Tokenizer(os.path.join(checkpoint_path, 'spiece.model'))
model = load_trained_model_from_checkpoint(
    config_path=os.path.join(checkpoint_path, 'xlnet_config.json'),
    checkpoint_path=os.path.join(checkpoint_path, 'xlnet_model.ckpt'),
    batch_size=16,
    memory_len=512,
    target_len=128,
    in_train_phase=False,
    attention_type=ATTENTION_TYPE_BI,
)
model.summary()
```
The arguments batch_size, memory_len and target_len are the maximum sizes used to initialize the memories. If in_train_phase is True, the model for training a language model is returned; otherwise, a model for fine-tuning is returned.
About I/O
Note that shuffle should be False in either fit or fit_generator if memories are used.
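To illustrate why batch order matters when memories are used, here is a minimal sketch that splits one long token sequence into consecutive segments. The helper name iter_segments is hypothetical and not part of keras-xlnet. It also assumes that the memory-length input counts how many earlier positions are available, capped at memory_len; check the library's demos for the exact semantics.

```python
def iter_segments(token_ids, target_len, memory_len):
    """Yield (segment, memory_length) pairs in reading order.

    Assumption: the memory length for a segment is the number of tokens
    already seen, capped at memory_len.
    """
    for start in range(0, len(token_ids), target_len):
        segment = token_ids[start:start + target_len]
        yield segment, min(start, memory_len)


# Hypothetical token IDs; real IDs would come from the tokenizer.
segments = list(iter_segments(list(range(10)), target_len=4, memory_len=6))
for segment, memory_length in segments:
    print(segment, memory_length)
```

Under this assumption, each segment relies on memories written while processing the previous one, so feeding the segments in shuffled order would pair them with the wrong memories.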
in_train_phase is False
3 inputs:
- IDs of tokens, with shape (batch_size, target_len).
- IDs of segments, with shape (batch_size, target_len).
- Length of memories, with shape (batch_size, 1).

1 output:
- The feature for each token, with shape (batch_size, target_len, units).
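The three inputs listed above can be assembled roughly as follows. This is only a sketch: build_finetune_inputs and pad_id are hypothetical names, the padding ID of 0 is a placeholder rather than the real vocabulary value, and left-padding with a single segment and no memories is an assumption, not something the library requires.

```python
def build_finetune_inputs(batch_token_ids, target_len, pad_id=0):
    """Return [token_ids, segment_ids, memory_lengths] as nested lists.

    Assumptions: sequences are left-padded with pad_id, every token is in
    segment 0, and no memories are used (memory length 0).
    """
    token_ids, segment_ids, memory_lengths = [], [], []
    for ids in batch_token_ids:
        ids = ids[:target_len]                        # truncate long inputs
        padding = [pad_id] * (target_len - len(ids))
        token_ids.append(padding + ids)               # row of length target_len
        segment_ids.append([0] * target_len)          # row of length target_len
        memory_lengths.append([0])                    # row of length 1
    return [token_ids, segment_ids, memory_lengths]


# Hypothetical token IDs; real IDs would come from the tokenizer.
inputs = build_finetune_inputs([[5, 6, 7], [8, 9]], target_len=4)
```

Each nested list would typically be converted to an array (for example with numpy.array) before being passed to the model, giving shapes (batch_size, target_len), (batch_size, target_len) and (batch_size, 1).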
in_train_phase is True
4 inputs:
- IDs of tokens, with shape (batch_size, target_len).
- IDs of segments, with shape (batch_size, target_len).
- Length of memories, with shape (batch_size, 1).
- Masks of tokens, with shape (batch_size, target_len).

1 output:
- The probability of each token in each position, with shape (batch_size, target_len, num_token).
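The fourth input has the same shape as the token IDs. The sketch below only illustrates that shape; build_token_masks, mask_rate and seed are hypothetical, and it assumes that a value of 1 marks a position whose token should be predicted. The masking scheme actually used for pre-training may differ.

```python
import random


def build_token_masks(batch_size, target_len, mask_rate=0.15, seed=0):
    """Return a (batch_size, target_len) nested list of 0/1 masks.

    Assumption: 1 marks a position whose token should be predicted.
    """
    rng = random.Random(seed)
    return [
        [1 if rng.random() < mask_rate else 0 for _ in range(target_len)]
        for _ in range(batch_size)
    ]


masks = build_token_masks(batch_size=2, target_len=8)
```

The result would be appended to the three fine-tuning inputs, in the order listed above, to form the four inputs of the training model.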