plotmachines

Here is the code for generating .pkl files

Open · audreycs opened this issue 2 years ago · 1 comment

I wrote the following code for generating the .pkl files based on my understanding. The generated files can be read successfully.

import torch
from transformers import GPT2Model, GPT2Tokenizer
import pickle
import tqdm

device = 0  # CUDA device index

def tfmclassifier(textlines, model, tokenizer, gen_len):
    '''Create an encoding of the previous paragraph (textlines) using the model and tokenizer.'''
    nb = len(textlines)
    # Token ids and attention mask for each sentence, truncated to gen_len and zero-padded.
    wds = torch.zeros(nb, gen_len, dtype=torch.long).to("cuda:"+str(device))
    mask = torch.zeros(nb, gen_len, dtype=torch.long).to("cuda:"+str(device))
    for j in range(nb):
        temp = torch.tensor(tokenizer.encode(textlines[j], add_special_tokens=False)[:gen_len])
        wds[j, :len(temp)] = temp.to("cuda:"+str(device))
        mask[j, :len(temp)] = torch.ones(len(temp), dtype=torch.long).to("cuda:"+str(device))
    model.eval()
    outputs = model(wds)
    # Masked mean of the hidden states over the sequence dimension, then mean over sentences.
    total = (mask.unsqueeze(2).type_as(outputs[0]) * outputs[0]).sum(dim=1) / mask.type_as(outputs[0]).sum(
        dim=1).unsqueeze(1)
    return torch.mean(total, dim=0, keepdim=True)

gptmodel = GPT2Model.from_pretrained('gpt2-medium').to("cuda:"+str(device))  # embedding dim: 1024
gpttok = GPT2Tokenizer.from_pretrained('gpt2-medium')

file = '../../data_dir/train_encoded.csv'
f = open(file, 'r', encoding='utf-8')
f_out = open('../../data_dir/train_encoded_gpt2.pkl', 'wb')
lines = f.readlines()
for idx, line in enumerate(tqdm.tqdm(lines)):
    if idx == 0:  # Header row: write a placeholder record so row indices stay aligned.
        row = (0, "empty", [])
        pickle.dump(row, f_out)
    else:
        # The context (preceding paragraph) is the last tab-separated column.
        context = line.strip().split('\t')[-1]
        if context == 'NA':
            # Wrap in a list so tfmclassifier receives a list of strings, not characters.
            textlines = [context]
        else:
            # Split the context into sentences, dropping empty pieces left by the split.
            textlines = [t for t in context.split('.') if t.strip()]
        encoding = tfmclassifier(textlines, gptmodel, gpttok, gen_len=100)

        # One record per row: (row index, raw context, paragraph encoding).
        row = (idx, context, encoding.cpu().detach().numpy())
        pickle.dump(row, f_out)

f.close()
f_out.close()

The tfmclassifier function is the same as the author's implementation.
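
For anyone who wants to check the output, here is a minimal sketch of how the generated file can be read back. It simply loads the sequentially dumped (idx, context, encoding) tuples until EOF; the path and variable names are only for illustration.

import pickle

records = []
with open('../../data_dir/train_encoded_gpt2.pkl', 'rb') as f_in:
    while True:
        try:
            # Each record is one (idx, context, encoding) tuple dumped by the script above.
            records.append(pickle.load(f_in))
        except EOFError:
            break

# The first record is the header placeholder; the rest carry a (1, 1024) numpy encoding.
print(len(records), records[1][2].shape)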

audreycs · Oct 26, 2022

I'm so glad to see this code!!

Kyeongman-header · Mar 31, 2023