
BERT vs RoBERTa

Open · xuewyang opened this issue 4 years ago · 24 comments

Hi Alasdair,

Have you tested with BERT?

xuewyang (Jul 06 '20 22:07)

Yeah I did run some early tests with BERT. I didn't keep the results but the performance was slightly worse than RoBERTa, so I switched to RoBERTa in later experiments.

alasdairtran (Jul 06 '20 23:07)

https://github.com/alasdairtran/transform-and-tell/blob/master/tell/data/token_indexers/roberta_indexer.py

Can you explain a couple of things in the roberta_indexer.py file?

  1. You gave an example (line 122): # e.g. [' Tomas', ' Maier', ',', ' autumn', '/', 'winter', ' 2014', ',', '\n', ' in', 'Milan', '.']. Why is there a space in ' autumn' and ' 2014'? Why not just 'autumn'?
  2. In the function _byte_pair_encode, what is the variable doc used for?

Thanks.

xuewyang (Jul 08 '20 04:07)

Check out section 2.2 in Radford et al. (2019), where they explain how the BPE vocab is generated (the RoBERTa model uses the same vocab). Basically, in this byte-level BPE you have to explicitly encode the spaces in a sentence. Most of the time, a space gets merged with the word immediately after it. You can try this out yourself:

```python
import torch

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta = roberta.eval()

print(roberta.encode(' autumn').tolist())
# You should get [0, 13270, 2], where token ID 13270 corresponds to ' autumn',
# token ID 0 is <s>, and token ID 2 is </s>

print(roberta.encode('autumn').tolist())
# You should get [0, 4255, 24422, 2], where 4255 is "aut" and 24422 is "umn"
```

Basically in a text corpus, every time you see the word "autumn", it is almost always preceded by a space. Hence " autumn" gets its own ID, whereas "autumn" without the preceding space is much rarer, so it needs to be split into more common subparts.
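
A minimal sketch of what "explicitly encoding the spaces" means, assuming the GPT-2 byte-level scheme that RoBERTa reuses: before any merges, every UTF-8 byte is mapped to a printable character, so the space in ' autumn' becomes a visible symbol (Ġ) that BPE merges can glue onto the following word. The function below re-implements the bytes_to_unicode table from GPT-2's encoder for illustration only; it does not apply any merges, so it won't produce the token IDs above.

```python
from typing import Dict


def bytes_to_unicode() -> Dict[int, str]:
    """Map every byte value 0..255 to a printable unicode character.

    Printable Latin-1 bytes map to themselves; the remaining bytes
    (including space and newline) are shifted to code points from 256 up.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(0xA1, 0xAC + 1))
        + list(range(0xAE, 0xFF + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    shift = 0
    for b in range(256):
        if b not in printable:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}


BYTE_ENCODER = bytes_to_unicode()


def to_byte_symbols(text: str) -> str:
    """Render a string the way byte-level BPE sees it before any merges."""
    return "".join(BYTE_ENCODER[b] for b in text.encode("utf-8"))


print(to_byte_symbols(" autumn"))  # Ġautumn
print(to_byte_symbols("autumn"))   # autumn
print(to_byte_symbols("\n"))       # Ċ
```

Because the space is just another symbol, a frequent sequence like "Ġautumn" can be merged into a single vocabulary entry, while "autumn" at the start of a string has no such entry and gets split.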

You can ignore doc. It's the Document object of the text returned by spacy. I used it to mark which tokens in the text are part of a named entity. I was trying to see if giving the model the entity information would improve the performance, but it didn't help, so doc is not actually used in any of the experiments presented in the paper.

alasdairtran (Jul 08 '20 05:07)

Thanks again. When you tokenize the article, it seems that you don't remove the paragraph breaks. For example, with the tokenizer in word_splitter.py, 'Documentary\n\nCURB' could end up as a single token containing \n\n. Do you think this will cause a problem?

xuewyang (Jul 08 '20 21:07)

That was deliberate. I was thinking that the model might benefit from knowing where paragraph boundaries are (just like us, when we read an article, it's easier to read if it's broken down into paragraphs). But I haven't run any formal experiments to show if it actually makes a difference. You can try to test that hypothesis.

alasdairtran (Jul 08 '20 23:07)

Oh, and the just_spaces_keep_newlines tokenizer you mentioned won't split on newlines, but the RoBERTa TokenIndexer does. See here:

```python
raw_tokens = self.bpe.re.findall(self.bpe.pat, text)
```

If you print out self.bpe.pat, you'll see it's a regex that splits on contractions, letters, numbers, punctuation, and whitespace (including newlines). So you will get:

```python
raw_tokens = self.bpe.re.findall(self.bpe.pat, 'Documentary\n\nCURB')
assert raw_tokens == ['Documentary', '\n', '\n', 'CURB']
```
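
If you want to poke at the splitting without loading fairseq, here is a rough approximation of that pattern using only Python's standard re module. The real pattern is compiled with the third-party regex package and uses \p{L} / \p{N}; the character classes below are stand-ins I'm assuming are close enough for text like this, not a drop-in replacement.

```python
import re

# Approximation of GPT-2's pre-tokenization regex (self.bpe.pat).
# [^\W\d_] stands in for \p{L} (letters), \d for \p{N} (numbers), and
# (?:[^\s\w]|_) for "anything that is not whitespace, a letter or a number".
PAT = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d"
    r"| ?[^\W\d_]+"
    r"| ?\d+"
    r"| ?(?:[^\s\w]|_)+"
    r"|\s+(?!\S)"   # whitespace run, minus the last char if a word follows
    r"|\s+"
)

print(PAT.findall("Documentary\n\nCURB"))
print(PAT.findall(" Tomas Maier, autumn/winter 2014,\n in Milan."))
```

The \s+(?!\S) branch is what splits '\n\n' into two separate '\n' pieces: it leaves the last whitespace character before a word to be matched on its own. Each resulting piece is then byte-pair encoded independently.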

alasdairtran (Jul 09 '20 00:07)

Gotcha. I still have some questions. Hope you can help me.

  1. In vocabulary.py, the _load_bert_vocab function seems to load a BERT dictionary. Why do we need this? What is the difference between this function and _add_encoding_to_vocabulary in roberta.py? Why not use the RoBERTa dictionary directly? Where is this function used in other files? And is __setstate__ ever used?
  2. In roberta_indexer.py, the as_padded_tensor function is overridden. Can you tell me where in the project it is used? Is it because you use AllenNLP that you have to override this function?
  3. The max_length of the caption is 40, right? Where is this number set? Thank you very much.

xuewyang (Jul 09 '20 04:07)

  4. Is decoder_flattened.py ever used? What is the difference between it and transformer_flattened.py?

xuewyang (Jul 09 '20 04:07)

A lot of the design choices are a bit weird because I had to work within the AllenNLP framework, which is quite opinionated about certain things:

  1. _load_bert_vocab is a dead function that I forgot to delete. It's not called anywhere. Initially I wanted the Vocabulary to handle the vocab loading process, but that was quite cumbersome because you need a vocab file. So now the loading is done by the indexer the first time we convert a text to token IDs. See here. The method __setstate__ is adapted from the one in the base Vocabulary class (see here) to make it work with our custom _RobertaTokenToIndexDefaultDict. I believe __setstate__ is called when a vocabulary is unpickled from disk. But anyway, the whole reason we have to subclass AllenNLP's default Vocabulary is that in AllenNLP, the padding token is 0 and the unknown token is 1, whereas in RoBERTa, the padding token is 1 and the unknown token is 3.

  2. The method as_padded_tensor is used by the method as_tensor in AllenNLP here. as_tensor, in turn, is used here when we ask an instance to return a Tensor representation, which is used by the AllenNLP training loop.

  3. The max_caption_len is set to 50 here and used here, but only for the GloVe model, to avoid OOM errors. No max length is specified for the RoBERTa models.

  4. The code in transformer_flattened.py is a subclass of Model, which defines the overall logic of the model. The code in decoder_flattened.py is a subclass of Decoder, which is used by a Model. For example, this config uses the model transformer_flattened, which in turn uses the decoder dynamic_conv_decoder_flattened. You can swap the decoder for an LSTM, like in this config. The name transformer_flattened is not the best, but it refers to the variants that use the RoBERTa encoder.
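
To make the padding mismatch in point 1 concrete, here is a toy sketch (hypothetical helper names, not code from this repo or from AllenNLP). If a batch of RoBERTa token IDs were padded with AllenNLP's default index 0, the padding would collide with <s>, and any mask derived from the padding index would also hide the real <s> tokens. Padding every sequence to the longest one in the batch is roughly the job a method like as_padded_tensor has to do.

```python
from typing import List

# Special-token indices as described in this thread.
ROBERTA_PAD = 1           # fairseq/RoBERTa: <s>=0, <pad>=1, </s>=2, <unk>=3
ALLENNLP_DEFAULT_PAD = 0  # AllenNLP default: padding=0, unknown=1


def pad_batch(seqs: List[List[int]], pad_index: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one in the batch."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_index] * (max_len - len(s)) for s in seqs]


def padding_mask(batch: List[List[int]], pad_index: int) -> List[List[bool]]:
    """True for real tokens, False for positions equal to the padding index."""
    return [[tok != pad_index for tok in row] for row in batch]


# Two RoBERTa-encoded sequences: <s> ' autumn' </s>  and  <s> </s>
seqs = [[0, 13270, 2], [0, 2]]

ok = pad_batch(seqs, ROBERTA_PAD)
print(ok, padding_mask(ok, ROBERTA_PAD))
# [[0, 13270, 2], [0, 2, 1]] [[True, True, True], [True, True, False]]

bad = pad_batch(seqs, ALLENNLP_DEFAULT_PAD)
print(bad, padding_mask(bad, ALLENNLP_DEFAULT_PAD))
# [[0, 13270, 2], [0, 2, 0]] [[False, True, True], [False, True, False]]
```

In the second case the <s> at position 0 is wrongly masked out in both rows, which is the kind of silent misalignment the custom Vocabulary subclass avoids.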

alasdairtran (Jul 09 '20 06:07)

Thanks. So:

  1. The _RobertaTokenToIndexDefaultDict and _RobertaIndexToTokenDefaultDict are actually used to solve the misalignment between the two vocabularies? And with the subclass in vocabulary.py, this problem is fixed?
  3. I think that in roberta_indexer.py, the function as_padded_tensor actually returns tensors of dimension 512 (for the article) and 40 (for the caption). That is why I think the sizes are 512 and 40. Can you explain?

xuewyang (Jul 09 '20 13:07)

Also, what is the difference between copy_text_field and list_text_field? When are they used?

xuewyang (Jul 09 '20 13:07)

And does this code support multi-GPU training?

xuewyang (Jul 09 '20 13:07)

  1. Yep the alignment is fixed by using the custom Vocabulary subclass.

  2. RoBERTa and its cousins can only encode text with a maximum of 512 tokens; it's a hard limit imposed by the learned positional embeddings. So we can only select 512 tokens from the article (in the paper, you might remember that I tried two methods: one selects the first 512 tokens, the other selects 512 tokens surrounding the image). As for the caption length, the 40 you see is probably just the length of the longest caption in that particular batch. If you look at a different batch, you should see a different dimension for the caption. The only real constraint is 512.

  3. CopyTextField contains the extra copying info (i.e. proper noun labels), which was used in the model with the copy mechanism. Copying ended up not improving the performance, so I didn't include it in the paper; when you see CopyTextField, just think of it as AllenNLP's standard TextField. ListTextField is used to store the list of all named entities in the text. It's used in the experiment where I give the model the ground-truth named entities as additional inputs. I also ended up not using it.

  4. This code uses AllenNLP v0.9.0. I think at that time multi-GPU wasn't well supported. If you upgrade the code to v1.0.0 (might take some effort since the API has changed quite a bit), you should be able to use multi-GPU.
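
The two article-selection strategies from point 2 can be sketched like this. It's a simplified illustration under assumptions: center stands for a hypothetical token index closest to the image, and the actual selection code in the repo may differ. In practice the 512 budget also has to cover the <s> and </s> tokens.

```python
from typing import List, Sequence


def first_tokens(tokens: Sequence[int], max_len: int = 512) -> List[int]:
    """Strategy 1: keep the first max_len tokens of the article."""
    return list(tokens[:max_len])


def tokens_around(tokens: Sequence[int], center: int, max_len: int = 512) -> List[int]:
    """Strategy 2: keep a max_len window centred on `center` (hypothetical image position),
    shifted inward when the window would run past either end of the article."""
    if len(tokens) <= max_len:
        return list(tokens)
    start = max(0, center - max_len // 2)
    start = min(start, len(tokens) - max_len)
    return list(tokens[start:start + max_len])


article = list(range(1000))  # stand-in for 1000 token IDs
print(len(first_tokens(article)))          # 512
window = tokens_around(article, center=900)
print(window[0], window[-1], len(window))  # 488 999 512
```

Near the end of the article the window can't stay centred, so it is clamped to the last 512 tokens instead of being shortened.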

alasdairtran (Jul 09 '20 23:07)

Gotcha. I am currently using BERT, but the results are not good. The generated captions are long, around 100 tokens. I think the cause might be in the indexer, the dictionary, etc. I am using Hugging Face's BERT. For each article, do you think I should use several [SEP] tokens or just one at the last position? Thank you.

xuewyang (Jul 12 '20 17:07)

I don't think the position of the [SEP] token matters much. If you're using BERT to encode the article, you only need to change two lines of code: one where you load the encoder here, and one where you extract the BERT features here. Everything else can stay the same (you might still want to keep the self.roberta object, because the decoder uses its encoding scheme).

One thing to note: the decoder architecture is independent of the encoders. The current decoder uses the BPE vocab from RoBERTa. Most of the questions you asked (about the indexer, dictionary, etc.) are actually about the decoder's architecture, which has nothing to do with the pretrained BERT/RoBERTa embeddings of the article.

Another thing is that the generator code does cap the generated caption at 100 tokens (it's hard-coded here). The EOS token is also hard-coded here and here. Make sure the EOS token matches whatever encoding scheme you're using.
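
Here's a generic greedy decoding loop (a toy sketch, not the repo's generator) showing why the hard-coded EOS matters: if the ID the loop checks for is never produced by the decoder, generation can only stop at the length cap, so every caption comes out at the maximum length. That is one possible explanation for the ~100-token captions mentioned above. toy_model is a made-up stand-in for the decoder.

```python
from typing import Callable, List


def greedy_decode(next_token: Callable[[List[int]], int],
                  bos: int, eos: int, max_len: int = 100) -> List[int]:
    """Append one token at a time until `eos` is produced or max_len tokens are generated."""
    tokens = [bos]
    for _ in range(max_len):
        tok = next_token(tokens)
        tokens.append(tok)
        if tok == eos:
            break
    return tokens


def toy_model(prefix: List[int]) -> int:
    """Hypothetical decoder: emits three content tokens, then RoBERTa's </s> (ID 2) forever."""
    script = [5, 6, 7]
    step = len(prefix) - 1
    return script[step] if step < len(script) else 2


print(greedy_decode(toy_model, bos=0, eos=2))         # [0, 5, 6, 7, 2]
print(len(greedy_decode(toy_model, bos=0, eos=102)))  # 101: EOS never matches, so it hits the cap
```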

alasdairtran (Jul 13 '20 01:07)

Gotcha. But since the encoder and decoder are independent, why do we still need the EOS tokens of the encoder and decoder to be the same?

xuewyang (Jul 13 '20 01:07)

They don't have to be. If you have not changed the decoder's vocab and encoding scheme, then you don't need to worry about changing the EOS token. I was just unsure what kind of changes you'd made, so I was pointing out the places where I hard-code certain variables.

alasdairtran (Jul 13 '20 01:07)

I just plan to switch the encoder to BERT, so I have to use BERT's encoding, which is different from RoBERTa's. But I'm keeping the same decoder for now.

xuewyang (Jul 13 '20 03:07)

Hi, even though we can keep two encodings, since we are changing the encoder to BERT, I think we still need to change a lot in the /data folder because BERT uses a different encoding method, right? I mean, to call extract_features we need the right tokens, so we also need to change the indexer. In the following code, you first encode the sentence using RoBERTa's BPE encoding, then you convert the BPE tokens to a string, and then you encode that string into token IDs using source_dictionary. I really can't understand why you encode twice. Why not just use the BPE encoding from the first step?

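```python
def encode(self, sentence, doc):
    if self.legacy:
        return self.encode_legacy(sentence)
    # Step 1: GPT-2 byte-pair encoding; returns BPE token IDs plus a per-token copy mask.
    bpe_tokens, copy_masks = self._byte_pair_encode(sentence, doc)
    # Step 2: join the BPE IDs into a space-separated string and split it back into
    # "words" (tokenize_line, from fairseq.tokenizer, splits on whitespace).
    sentence = ' '.join(map(str, bpe_tokens))
    words = tokenize_line(sentence)
    assert len(words) == len(copy_masks)
    # Enforce maximum length constraint, leaving room for <s> and </s>
    words = words[:self._max_len - 2]
    copy_masks = copy_masks[:self._max_len - 2]
    words = ['<s>'] + words + ['</s>']
    copy_masks = [0] + copy_masks + [0]
    # Step 3: look up each word in the fairseq source dictionary to get model token IDs.
    token_ids = []
    for word in words:
        idx = self.source_dictionary.indices[word]
        token_ids.append(idx)
    return token_ids, copy_masks
```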

xuewyang (Jul 13 '20 15:07)

Ah you're right actually. I forgot about the indexer for the article. I think you can just use allennlp's pretrained bert indexer, so you don't need to write your own.

The indexer code I wrote is quite complex because I needed to align each token with the correct entity label (spacy labels whole words, whereas bert/roberta break words into smaller pieces). This alignment takes a bit of effort to get right. If you don't need entity labels (those copy_masks), then you can make the code much simpler.
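
One simple way to do that alignment, sketched under the assumption that the subword tokens concatenate back to the original text exactly (which holds for byte-level BPE tokens like the ones earlier in this thread): walk the tokens to recover character offsets and mark any token that overlaps an entity's character span. spaCy exposes those spans as ent.start_char and ent.end_char. This is not the repo's indexer code, and the spans in the example are hand-written.

```python
from typing import List, Tuple


def entity_mask(tokens: List[str], entity_spans: List[Tuple[int, int]]) -> List[int]:
    """Label each subword token 1 if it overlaps any [start, end) entity character span.

    Assumes "".join(tokens) reproduces the original text exactly. A token's
    leading spaces are not counted, so ' Maier' is judged by where 'Maier' starts.
    """
    mask = []
    offset = 0
    for tok in tokens:
        end = offset + len(tok)
        core_start = offset + (len(tok) - len(tok.lstrip(" ")))
        overlaps = any(core_start < e and s < end for s, e in entity_spans)
        mask.append(1 if overlaps else 0)
        offset = end
    return mask


text = " Tomas Maier, autumn/winter 2014,\n in Milan."
tokens = [" Tomas", " Maier", ",", " autumn", "/", "winter", " 2014", ",", "\n", " in", " Milan", "."]
assert "".join(tokens) == text

# Character spans as an NER tagger might report them (hand-written here).
spans = [
    (text.index("Tomas Maier"), text.index("Tomas Maier") + len("Tomas Maier")),
    (text.index("Milan"), text.index("Milan") + len("Milan")),
]
print(entity_mask(tokens, spans))  # [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
```

Working in character offsets sidesteps the word-vs-subword mismatch: a word split into several BPE pieces simply yields several tokens that all overlap the same entity span.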

I actually wondered about your last question too. My code is based on the fairseq code. They also encode twice, so I needed to do the same to get the same mapping. It might just be a quirk of their system.
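
For what it's worth, my understanding of the fairseq setup is that the second pass is a remapping rather than a second encoding: the BPE step yields GPT-2 BPE IDs, and fairseq's dictionary (loaded from a dict.txt whose symbols are those IDs written as strings, listed after the special tokens) gives each one the index used by the model's embedding table. The toy sketch below mimics that lookup; the dictionary contents are made up, and unlike the quoted code it falls back to <unk> instead of raising on unknown symbols.

```python
from typing import Dict, List

SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]  # indices 0..3, as in this thread


def load_dictionary(dict_lines: List[str]) -> Dict[str, int]:
    """Build a fairseq-style symbol -> index map: specials first, then dict.txt order."""
    indices = {sym: i for i, sym in enumerate(SPECIALS)}
    for line in dict_lines:
        symbol, _count = line.rsplit(" ", 1)
        indices.setdefault(symbol, len(indices))
    return indices


def encode(gpt2_bpe_ids: List[int], indices: Dict[str, int]) -> List[int]:
    """GPT-2 BPE IDs -> strings -> dictionary indices, wrapped in <s> ... </s>."""
    words = ["<s>"] + [str(i) for i in gpt2_bpe_ids] + ["</s>"]
    return [indices.get(w, indices["<unk>"]) for w in words]


# Toy dict.txt: each line is "<GPT-2 BPE ID> <count>" (IDs and counts made up).
toy_dict = ["13 1000", "262 900", "6003 5"]
indices = load_dictionary(toy_dict)
print(encode([262, 6003, 99999], indices))  # [0, 5, 6, 3, 2]
```

So the same subword ends up with two different numbers: its GPT-2 BPE ID and its fairseq dictionary index, and only the latter matches the pretrained embedding rows.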

alasdairtran (Jul 13 '20 15:07)

Why do we need entity labels, like the ones we get from spaCy? And are copy_masks used for named entities, or for masking article/caption tokens when extracting features?

xuewyang (Jul 14 '20 03:07)

Both the entity labels and copy_masks (used to label which tokens are part of an entity) were used in the experiments with the copy mechanism. You can safely ignore these. I ran a lot of different experiments, only some of which were reported in the paper.

alasdairtran (Jul 14 '20 10:07)

I tested with BERT. After training for 5 epochs, BERT has a BLEU score of 0.7, while RoBERTa has 1.6. Maybe BPE encoding is more accurate than WordPiece for this task. I also tested with roberta.eval() and no_grad(): eval() mode reduced the performance by a small margin. It is hard to explain why, but I just wanted to keep you updated.

xuewyang (Aug 14 '20 20:08)

I see. It is surprising that setting RoBERTa to eval mode reduces the performance. Maybe the dropout in RoBERTa acts as an extra regularization.
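
For context on the eval() vs no_grad() comparison: in PyTorch, torch.no_grad() only turns off gradient tracking, while module.eval() switches layers such as dropout to inference behaviour. So features computed under no_grad() with the module still in training mode still go through dropout, which is consistent with the regularization explanation above. The pure-Python toy below (not PyTorch code) mimics that train/eval switch for inverted dropout.

```python
import random
from typing import List


class ToyDropout:
    """Pure-Python sketch of inverted dropout with a train/eval switch.

    In training mode each element is zeroed with probability p and survivors
    are scaled by 1 / (1 - p); in eval mode the input passes through unchanged.
    """

    def __init__(self, p: float, seed: int = 0):
        self.p = p
        self.training = True
        self.rng = random.Random(seed)

    def train(self, mode: bool = True) -> "ToyDropout":
        self.training = mode
        return self

    def eval(self) -> "ToyDropout":
        return self.train(False)

    def __call__(self, xs: List[float]) -> List[float]:
        if not self.training or self.p == 0.0:
            return list(xs)
        keep = 1.0 - self.p
        return [x / keep if self.rng.random() < keep else 0.0 for x in xs]


features = [1.0, 2.0, 3.0, 4.0]
drop = ToyDropout(p=0.5)
print(drop(features))         # random: each entry is either 0.0 or doubled
print(drop.eval()(features))  # [1.0, 2.0, 3.0, 4.0]
```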

alasdairtran (Aug 15 '20 08:08)