Questions regarding hidden states.

Open samhedin opened this issue 3 years ago • 10 comments

Hi, and thank you for the great paper. I have questions regarding https://github.com/krishnap25/mauve/blob/20613eecd7b084281ed4df9bfeee67d66cbfe5ee/src/mauve/utils.py#L120-L121. The activations in the final hidden layer are taken: outs.hidden_states[-1], right? Is my understanding of why sent_length is needed correct?

  • Looking at hidden_state[0] is looking at the embedding of the first word in the sentence
  • When sent_length < len(hidden_state), hidden_state[-1] is padding
  • Therefore, hidden_state[sent_length - 1] is the embedding of the entire sentence.
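A minimal pure-Python sketch of the indexing described in these bullets, assuming a right-padded batch. Nested lists stand in for the final-layer tensor (outs.hidden_states[-1] in utils.py), and the function name last_token_features is hypothetical, not part of the mauve package:

```python
from typing import List, Sequence

Vector = List[float]


def last_token_features(
    hidden_states: Sequence[Sequence[Vector]],
    sent_lengths: Sequence[int],
) -> List[Vector]:
    """Pick the vector of the last non-padding token of each sequence.

    hidden_states[i][t] is the final-layer vector at position t of sequence i
    in a right-padded batch; sent_lengths[i] is its number of real tokens.
    """
    features = []
    for hidden_state, sent_length in zip(hidden_states, sent_lengths):
        if not 1 <= sent_length <= len(hidden_state):
            raise ValueError("sent_length must be in [1, padded length]")
        # Positions sent_length.. are padding, so the last real token sits at
        # index sent_length - 1 rather than at index -1.
        features.append(list(hidden_state[sent_length - 1]))
    return features


# Batch of 2 sequences padded to length 3, hidden size 2.
batch = [
    [[0.1, 0.2], [0.3, 0.4], [0.0, 0.0]],  # 2 real tokens + 1 padding
    [[1.0, 1.1], [1.2, 1.3], [1.4, 1.5]],  # 3 real tokens, no padding
]
print(last_token_features(batch, [2, 3]))  # [[0.3, 0.4], [1.4, 1.5]]
```

Taking hidden_state[-1] instead would return a padding vector for every sequence shorter than the longest one in its batch.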

Second, is there a particular reason you chose to only look at the embeddings in the final hidden layer? Did you consider taking an average of the embeddings in all hidden layers?

samhedin, Feb 08 '22 12:02

Hi @samhedin, you are right.

The activations in the final hidden layer are taken: outs.hidden_states[-1], right?

Correct

Looking at hidden_state[0] is looking at the embedding of the first word in the sentence.

Correct again

When sent_length < len(hidden_state), hidden_state[-1] is padding

Also correct: hidden_state[sent_length:] are all embeddings of padding tokens and are therefore meaningless.

Therefore, hidden_state[sent_length - 1] is the embedding of the entire sentence.

This is the embedding of the last non-padding token, and we take it to be the embedding of the entire sentence.

Second, is there a particular reason you chose to only look at the embeddings in the final hidden layer? Did you consider taking an average of the embeddings in all hidden layers?

In some very preliminary experiments, we found that the first few layers did not work well but the higher layers were fine, so we used the last layer as the default. If you try averaging embeddings across layers, I'd be curious to hear what you find!
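If you want to try the layer-averaging idea, a minimal sketch for one sequence might look like the following. It is hypothetical (the authors only report using the last layer), the helper name is made up, and the first_layer knob is an assumption motivated by the remark that the first few layers worked poorly. In Hugging Face Transformers, hidden_states[0] is the output of the embedding layer, so first_layer=1 skips it.

```python
from typing import List, Sequence

Vector = List[float]


def layer_averaged_last_token(
    layers: Sequence[Sequence[Vector]],
    sent_length: int,
    first_layer: int = 0,
) -> Vector:
    """Average the last-real-token vector of one sequence over layers[first_layer:].

    layers[l][t] is the vector at position t in layer l, analogous to
    outs.hidden_states[l][i][t] for sequence i of a batch.
    """
    chosen = layers[first_layer:]
    if not chosen:
        raise ValueError("no layers selected")
    if not 1 <= sent_length <= len(chosen[0]):
        raise ValueError("sent_length must be in [1, padded length]")
    # Last real token (index sent_length - 1) in each selected layer.
    last = [layer[sent_length - 1] for layer in chosen]
    # Element-wise mean across layers.
    return [sum(values) / len(last) for values in zip(*last)]


# Three layers, two positions, hidden size 2.
layers = [
    [[0.0, 0.0], [1.0, 2.0]],
    [[0.0, 0.0], [3.0, 4.0]],
    [[0.0, 0.0], [5.0, 6.0]],
]
print(layer_averaged_last_token(layers, sent_length=2))                 # [3.0, 4.0]
print(layer_averaged_last_token(layers, sent_length=2, first_layer=1))  # [4.0, 5.0]
```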

krishnap25, Feb 09 '22 03:02

Thank you!

samhedin, Feb 09 '22 08:02

Running the code produces the following warning:

WARNING clustering a points to b centroids: please provide at least c training points

The faiss FAQ (https://github.com/facebookresearch/faiss/wiki/FAQ#can-i-ignore-warning-clustering-xxx-points-to-yyy-centroids) explains:

n < min_points_per_centroid * k: this produces the warning above. 
It means that usually there are too few points to reliably estimate the centroids. 
This may still be ok if the dataset to index is as small as the training set.
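The FAQ condition is simple arithmetic. This sketch assumes faiss's default min_points_per_centroid of 39 and, as an unverified assumption about MAUVE's num_buckets='auto' setting, roughly one cluster per ten samples with p and q clustered together. Under those assumptions the warning fires even for large sample sizes, which is consistent with the advice to ignore it:

```python
def faiss_warning_threshold(k: int, min_points_per_centroid: int = 39) -> int:
    """Training points needed to avoid the warning for k centroids."""
    return k * min_points_per_centroid


def triggers_warning(n: int, k: int, min_points_per_centroid: int = 39) -> bool:
    """FAQ condition: n < min_points_per_centroid * k."""
    return n < faiss_warning_threshold(k, min_points_per_centroid)


# Assumed MAUVE-like setting: 5000 texts per side.
n_per_side = 5000
k = max(2, round(n_per_side / 10))  # assumption: ~1 cluster per 10 samples
n_total = 2 * n_per_side            # assumption: p and q clustered jointly
print(k, faiss_warning_threshold(k), triggers_warning(n_total, k))
# 500 19500 True
```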

Was this seen during development? Do you think it's a potential issue?

samhedin, Feb 14 '22 13:02

I would ignore this warning (note that we do not care about estimating the centroids).

krishnap25, Feb 16 '22 20:02

Thank you for your great answers so far! How do you calculate MAUVE with encoder-only models such as RoBERTa/BERT? These provide a contextual embedding of each token in a sentence, i.e. one vector per token rather than the single vector per text that MAUVE expects (naively replacing gpt2-large with a RoBERTa variant fails because of this). There are BERT variants that provide sentence embeddings, such as SBERT, but the paper specifically mentions RoBERTa. Do you take an average, the [CLS] token, or something else?

samhedin, Feb 21 '22 09:02

We simply take the encoding of the last token (this is true for both RoBERTa and GPT-2).

krishnap25, Feb 22 '22 03:02

What is the motivation for this? From my understanding, this is taking the embedding of the last token in the sequence. I understand it for GPT-2, but for BERT, I don't understand why this is more informative than taking any other token. I have modified your code by adding

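    import torch

    if bert:
        # q, p: per-token features, shape (num_samples, seq_len, hidden_dim).
        # Keep one vector per sample, shape (num_samples, hidden_dim).
        # Index -1 is the last real token only for unpadded sequences; with
        # padding, select position sent_length - 1 per sample instead.
        q = torch.as_tensor(q)[:, -1, :]
        p = torch.as_tensor(p)[:, -1, :]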

to compute_mauve.cluster_feats. Is this similar to your implementation?

I hope I'm not taking up too much of your time. I find this work very interesting and am looking into it as part of my MSc thesis, so I want to make sure I don't end up stating things that are factually incorrect or reporting results that are simply due to an incorrect implementation.

samhedin, Feb 23 '22 12:02

Hi @samhedin,

We did most of our experimentation with GPT-2, so we chose the last token, and we stuck with the same choice for RoBERTa to avoid unnecessary confounding factors. Across all our experiments, we found that the ranking induced by MAUVE is quite robust to variations in the embeddings.

If you find that an average of the embeddings or the embedding of the CLS token is more useful for BERT/RoBERTa, I would be all ears. :)
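To compare the pooling choices mentioned here, a small pure-Python sketch: "last" is what the authors describe using, while "mean" and "first" ([CLS] in BERT, <s> in RoBERTa) are the alternatives raised in this thread. The function pool is hypothetical and operates on the per-token vectors of a single sequence plus its attention mask:

```python
from typing import List, Sequence

Vector = List[float]


def pool(token_vecs: Sequence[Vector], mask: Sequence[int], how: str) -> Vector:
    """Reduce the per-token vectors of one sequence to a single vector.

    mask[t] is 1 for real tokens and 0 for padding.
    how: "last" (last real token), "mean" (average of real tokens),
         or "first" (first real token, i.e. [CLS] in BERT, <s> in RoBERTa).
    """
    real = [vec for vec, m in zip(token_vecs, mask) if m]
    if not real:
        raise ValueError("sequence has no real tokens")
    if how == "last":
        return list(real[-1])
    if how == "first":
        return list(real[0])
    if how == "mean":
        # Padding is excluded, so it cannot dilute the average.
        return [sum(col) / len(real) for col in zip(*real)]
    raise ValueError(f"unknown pooling: {how!r}")


vecs = [[1.0, 0.0], [3.0, 2.0], [9.0, 9.0]]  # last row is padding
mask = [1, 1, 0]
for how in ("last", "mean", "first"):
    print(how, pool(vecs, mask, how))
# last [3.0, 2.0]
# mean [2.0, 1.0]
# first [1.0, 0.0]
```

One caveat: if the tokenizer appends a closing special token ([SEP] for BERT, </s> for RoBERTa), "last" selects that token's vector.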

To use BERT/RoBERTa with the mauve package, it should (in principle) be possible to pass the model name via the featurize_model_name argument of compute_mauve. There is an unnecessary if condition here that can be removed to allow this behavior. If you find that feature useful, we'd appreciate a pull request with this update.

Best, Krishna

krishnap25, Mar 02 '22 19:03

Thank you again! Have you considered using MAUVE to evaluate other types of data? It was fairly straightforward to get it working with images by taking some code from https://github.com/mseitzer/pytorch-fid. I cannot see anything in MAUVE that requires sequential data; you just need an appropriate embedding model. It seems to me that it could be an attractive alternative to FID and IS.

samhedin, Mar 21 '22 13:03

Good point. In principle, MAUVE simply measures the gap between two distributions via embeddings of their samples. You could use MAUVE for other modalities and, more generally, for any two distributions for which you can obtain high-quality embeddings.
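To illustrate why the final computation is modality-agnostic: once samples (text, images, ...) are embedded and quantized into clusters, the score depends only on the two cluster histograms. Below is a minimal pure-Python sketch of the divergence-curve area as described in the MAUVE paper, not the mauve package's code. It assumes the histograms are already computed, uses a scaling constant c = 5 (we believe this mirrors the package's default mauve_scaling_factor), and picks an arbitrary grid of 25 mixture weights; all function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def kl(p: Sequence[float], r: Sequence[float]) -> float:
    """KL(p || r) for histograms; terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / ri) for pi, ri in zip(p, r) if pi > 0)


def divergence_curve(
    p: Sequence[float], q: Sequence[float], c: float = 5.0, grid: int = 25
) -> List[Tuple[float, float]]:
    """Points (exp(-c KL(P||R)), exp(-c KL(Q||R))) for R = lam*P + (1-lam)*Q."""
    points = [(0.0, 1.0)]  # endpoint closing the curve onto the y-axis
    for j in range(1, grid + 1):
        lam = j / (grid + 1)  # lam strictly inside (0, 1), so R > 0 on the support
        r = [lam * pi + (1 - lam) * qi for pi, qi in zip(p, q)]
        points.append((math.exp(-c * kl(p, r)), math.exp(-c * kl(q, r))))
    points.append((1.0, 0.0))  # endpoint closing the curve onto the x-axis
    return points


def mauve_from_histograms(
    p: Sequence[float], q: Sequence[float], c: float = 5.0, grid: int = 25
) -> float:
    """Area under the divergence curve, by the trapezoid rule along increasing lam."""
    pts = divergence_curve(p, q, c, grid)
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(pts, pts[1:]))


same = [0.5, 0.5]
print(round(mauve_from_histograms(same, same), 6))  # 1.0
print(mauve_from_histograms([1.0, 0.0], [0.0, 1.0]) < 0.05)  # True
```

Identical histograms give an area of 1, and histograms with disjoint support give an area close to 0. Nothing in these functions refers to text, which is the sense in which only the embedding model is modality-specific.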

In fact, MAUVE was inspired by the computer vision literature (FID, precision-recall, etc.). We focused on open-ended text generation because some unique concerns come up in this particular task. For instance, FID works great for images but not so much for text (see Figure 4 in the paper and the related text).

krishnap25, Mar 29 '22 16:03