galai

Hidden state of the EOS token?

Open · katzurik opened this issue 1 year ago · 3 comments

Does the Galactica model output the hidden state of the EOS token? Would it be possible to get it somehow using Hugging Face's codebase or the original implementation, in a similar manner to OPT when doing sequence classification?
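For comparison, the OPT sequence-classification head in Hugging Face (`OPTForSequenceClassification`) roughly takes the last-layer hidden state at the last non-padding token of each sequence. Below is a minimal NumPy sketch of that pooling step only. It assumes right-padded batches and uses a 0/1 attention mask rather than comparing against the pad token id; `last_token_pool` is a hypothetical helper, not a transformers API.

```python
import numpy as np


def last_token_pool(hidden_states, attention_mask):
    """Hypothetical helper: hidden state of the last real token of each sequence.

    hidden_states:  (batch, seq_len, dim) last-layer activations
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding;
                    assumes right padding (real tokens come first)
    """
    hidden_states = np.asarray(hidden_states)
    attention_mask = np.asarray(attention_mask)
    # Position of the last real token = number of real tokens - 1
    last_idx = attention_mask.sum(axis=1) - 1
    batch_idx = np.arange(hidden_states.shape[0])
    # Advanced indexing picks one (dim,) vector per sequence -> (batch, dim)
    return hidden_states[batch_idx, last_idx]


# Toy batch: 2 sequences, 3 positions, 2-dim hidden states
hidden = np.arange(12).reshape(2, 3, 2)
mask = np.array([[1, 1, 1], [1, 1, 0]])  # second sequence has one padding token
pooled = last_token_pool(hidden, mask)
print(pooled)  # rows [4, 5] and [8, 9]
```

For a single unpadded sequence that ends in `</s>`, this reduces to taking position `-1` of the last layer.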

katzurik · Nov 21 '22 20:11

I'm using this approach. You can also take the mean of the last hidden state over all tokens, but don't forget to apply L2 normalization afterwards; it might work better than the EOS embedding for some use cases.

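import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
# device_map="auto" needs the `accelerate` package installed
model = OPTForCausalLM.from_pretrained("facebook/galactica-1.3b", device_map="auto")
model.eval()

def get_embedding(s):
    # Append the EOS token so that the last position is </s>
    input_text = s + "</s>"
    # Send inputs to wherever device_map placed the model instead of hardcoding "cuda"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    # A plain forward pass is enough; generate(max_new_tokens=0) may be rejected by newer transformers versions
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    # hidden_states[-1] is the last layer, shape (batch, seq_len, hidden_dim);
    # [0, -1] picks the only sequence in the batch and its final (EOS) position
    sentence_representation = outputs.hidden_states[-1][0, -1].float().cpu().numpy()
    return sentence_representation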
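The mean-pooling alternative mentioned above could look like this: average the last-layer hidden states over the real (unmasked) tokens, then L2-normalize so that dot products equal cosine similarities. This is a NumPy sketch, not part of the original approach; `mean_pool_l2` is a hypothetical helper, and its inputs are assumed to be `outputs.hidden_states[-1]` and the tokenizer's `attention_mask`, converted to arrays with `.float().cpu().numpy()`.

```python
import numpy as np


def mean_pool_l2(hidden_states, attention_mask, eps=1e-12):
    """Hypothetical helper: masked mean over tokens, then L2 normalization.

    hidden_states:  (batch, seq_len, dim)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    """
    hidden_states = np.asarray(hidden_states, dtype=np.float64)
    mask = np.asarray(attention_mask, dtype=np.float64)[..., None]  # (batch, seq_len, 1)
    # Sum only over real tokens, then divide by how many there are
    summed = (hidden_states * mask).sum(axis=1)
    counts = np.maximum(mask.sum(axis=1), eps)
    mean = summed / counts
    # L2-normalize each row so that dot products become cosine similarities
    norms = np.linalg.norm(mean, axis=1, keepdims=True)
    return mean / np.maximum(norms, eps)


hidden = np.array([[[3.0, 0.0], [0.0, 4.0], [100.0, 100.0]]])  # last position is padding
mask = np.array([[1, 1, 0]])
emb = mean_pool_l2(hidden, mask)
print(emb)  # mean is [1.5, 2.0]; after normalization [0.6, 0.8]
```

Masking matters once sequences are batched with padding; for a single unpadded sentence the mask is all ones and this is a plain mean.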

Puzer · Nov 24 '22 10:11

@Puzer Thanks! What's your take on the quality of sentence representations from this method? I'm not sure the model does this very well.

katzurik · Nov 27 '22 15:11

I tried embedding arXiv papers by their titles and then training a linear model to classify their tags:

| Model | F1 macro | AUC-ROC mean |
|---|---|---|
| galactica-6.7b | 0.749 | 0.806 |
| all-mpnet-base-v1 | 0.744 | 0.799 |
| all-roberta-large-v1 | 0.738 | 0.8 |
| galactica-1.3b | 0.722 | 0.796 |
| tf-idf | 0.697 | 0.763 |
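For reading the table: macro F1 is the unweighted mean of per-tag F1 scores, and "AUC-ROC mean" is presumably the per-tag ROC AUC averaged over tags. The thread doesn't say exactly how these were computed, so this NumPy sketch assumes a multi-label setup with one binary column per tag. In practice `sklearn.metrics.f1_score(..., average="macro")` and `sklearn.metrics.roc_auc_score(..., average="macro")` do the same job (sklearn raises an error for a tag with only one class instead of skipping it). `macro_f1` and `mean_roc_auc` are hypothetical helpers.

```python
import numpy as np


def macro_f1(y_true, y_pred):
    """Hypothetical helper: unweighted mean of per-tag F1 for 0/1 arrays of shape (n_samples, n_tags)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = (y_true & y_pred).sum(axis=0)
    fp = (~y_true & y_pred).sum(axis=0)
    fn = (y_true & ~y_pred).sum(axis=0)
    # F1 = 2TP / (2TP + FP + FN); a zero denominator implies TP == 0, so such tags score 0
    f1 = 2 * tp / np.maximum(2 * tp + fp + fn, 1)
    return float(f1.mean())


def mean_roc_auc(y_true, scores):
    """Hypothetical helper: per-tag ROC AUC averaged over tags.

    AUC is computed as the probability that a random positive example is scored
    higher than a random negative one (ties count as 1/2).
    """
    y_true = np.asarray(y_true, dtype=bool)
    scores = np.asarray(scores, dtype=np.float64)
    aucs = []
    for j in range(y_true.shape[1]):
        pos = scores[y_true[:, j], j]
        neg = scores[~y_true[:, j], j]
        if len(pos) == 0 or len(neg) == 0:
            continue  # AUC is undefined when a tag has only one class; skipped here
        diff = pos[:, None] - neg[None, :]  # all positive/negative pairs
        aucs.append(np.mean((diff > 0) + 0.5 * (diff == 0)))
    return float(np.mean(aucs))


y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_pred = np.array([[1, 0], [0, 0], [1, 1], [1, 0]])
scores = np.array([[0.9, 0.7], [0.2, 0.8], [0.7, 0.6], [0.3, 0.2]])
print(macro_f1(y_true, y_pred))      # (0.8 + 2/3) / 2, about 0.733
print(mean_roc_auc(y_true, scores))  # (1.0 + 0.75) / 2 = 0.875
```

Note that F1 depends on the decision threshold used to turn scores into 0/1 predictions, while AUC does not; that is one reason the two columns can rank models slightly differently.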

I also tried using the embeddings for semantic search of similar papers via cosine similarity. I don't have any metrics for that, but subjectively the results from all-mpnet-base-v1 look more reasonable to me.
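Semantic search with cosine similarity amounts to normalizing every embedding to unit length and ranking by dot product. A NumPy sketch, assuming the paper embeddings are stacked row-wise into one array; `top_k_similar` is a hypothetical helper:

```python
import numpy as np


def top_k_similar(query, corpus, k=3):
    """Hypothetical helper: indices and cosine similarities of the k corpus rows closest to query.

    query:  (dim,) embedding of the query paper
    corpus: (n_papers, dim) embeddings, one row per paper
    """
    query = np.asarray(query, dtype=np.float64)
    corpus = np.asarray(corpus, dtype=np.float64)
    # Normalize to unit length; assumes no all-zero vectors
    q = query / np.linalg.norm(query)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    sims = c @ q  # cosine similarity of every corpus row with the query
    order = np.argsort(-sims)[:k]  # highest similarity first
    return order, sims[order]


corpus = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
idx, sims = top_k_similar([1.0, 0.1], corpus, k=2)
print(idx)  # [0 2]
```

If the embeddings were already L2-normalized (as in the mean-pooling variant), the two normalization lines are redundant and the ranking is just `corpus @ query`.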

Puzer · Nov 28 '22 11:11