mlm-pytorch
An implementation of masked language modeling for Pytorch, made as concise and simple as possible
MLM (Masked Language Modeling) Pytorch
This repository allows you to quickly set up unsupervised (masked language model) training for your transformer on a corpus of sequence data.
Install
$ pip install mlm-pytorch
Usage
First pip install x-transformers, then run the following example to see what one iteration of the unsupervised training looks like
import torch
from torch import nn
from torch.optim import Adam
from mlm_pytorch import MLM
# instantiate the language model
from x_transformers import TransformerWrapper, Encoder
transformer = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)
# plugin the language model into the MLM trainer
trainer = MLM(
    transformer,
    mask_token_id = 2,          # the token id reserved for masking
    pad_token_id = 0,           # the token id for padding
    mask_prob = 0.15,           # masking probability for masked language modeling
    replace_prob = 0.90,        # ~10% chance that a token selected for masking is left unmasked but still included in the loss, as detailed in the paper
    mask_ignore_token_ids = []  # other token ids to exclude from masking, e.g. the [cls] and [sep] tokens
).cuda()
# optimizer
opt = Adam(trainer.parameters(), lr=3e-4)
# one training step (do this for many steps in a for loop, getting new `data` each time)
data = torch.randint(0, 20000, (8, 1024)).cuda()
loss = trainer(data)
loss.backward()
opt.step()
opt.zero_grad()
# after much training, the model should have improved for downstream tasks
torch.save(transformer, f'./pretrained-model.pt')
Do the above for many steps, and your model should improve.
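For reference, here is a minimal sketch of such a training loop, reusing the trainer and opt objects defined above. The get_batch function and the step count are placeholders (not part of this library) standing in for your own data pipeline and training budget; once the loop finishes, save the transformer as shown above.

# minimal training-loop sketch
# `get_batch` is a hypothetical stand-in for a real data pipeline
# that yields (batch, seq_len) tensors of token ids
def get_batch():
    return torch.randint(0, 20000, (8, 1024)).cuda()  # placeholder: random token ids

num_steps = 100_000  # placeholder training budget

for step in range(num_steps):
    loss = trainer(get_batch())  # MLM wrapper returns the masked-token loss
    loss.backward()
    opt.step()
    opt.zero_grad()

    if step % 1000 == 0:
        print(f'step {step} | loss {loss.item():.4f}')

Only the wrapped transformer needs to be saved; the MLM wrapper exists purely for pretraining and can be discarded once you move on to downstream tasks.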
Citation
@misc{devlin2018bert,
    title         = {BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
    author        = {Jacob Devlin and Ming-Wei Chang and Kenton Lee and Kristina Toutanova},
    year          = {2018},
    eprint        = {1810.04805},
    archivePrefix = {arXiv},
    primaryClass  = {cs.CL}
}