LM-BFF
Average the logits of 16 demonstrations
Hi! Thanks a lot for your work. I have a question about prediction at inference time with demonstrations. As the paper mentions, the final prediction logits are obtained by averaging the results over 16 demonstrations. But I cannot find such an averaging operation in the code; I only see that the dataset is augmented to 16 times its size, meaning every input is paired with 16 demonstrations. So there should be an operation that gathers these 16 concatenated texts (input + demo) to produce the final prediction logits. The pseudocode is like below:
Goal: the final prediction logits for an input text_a
Input: 16 concatenated inputs (text_a, demo1) ... (text_a, demo16)
predictions = model(inputs)  # shape [16, num_classes]
Output: torch.mean(predictions, dim=0)  # shape [num_classes]
Can you help me find this operation in the code? Or maybe I have some misunderstanding... Thanks!
Hi,
The average operation is here: https://github.com/princeton-nlp/LM-BFF/blob/1bbdc42455502a152541c48974cdc71027569a2e/run.py#L495
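For reference, a minimal sketch of the averaging idea (not the repo's actual code; the function name and toy shapes here are illustrative assumptions):

```python
import torch

def average_demo_logits(logits: torch.Tensor) -> torch.Tensor:
    """Average per-demonstration logits into one prediction.

    logits: tensor of shape [num_demos, num_classes], one row per
    (input, demonstration) pair for the same input text_a.
    Returns a tensor of shape [num_classes].
    """
    return logits.mean(dim=0)

# Toy example: 16 demonstrations, 3 classes.
logits = torch.randn(16, 3)
final = average_demo_logits(logits)
predicted_class = final.argmax().item()
```

The predicted label is then the argmax over the averaged logits, as in the pseudocode above.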