plotmachines
plotmachines copied to clipboard
Here is the code for generating .pkl files
I wrote the following code to generate the .pkl files, based on my understanding.
The generated files can be read back successfully.
import torch
from transformers.models.gpt2.modeling_gpt2 import *
from transformers import GPT2Tokenizer
import numpy as np
import pickle
import tqdm
# Index of the CUDA device to use; the code below turns it into a device
# string via "cuda:" + str(device).
device = 0
def tfmclassifier(textlines, model, tokenizer, gen_len):
    """Create an encoding of the previous paragraph from its sentences.

    Every entry of ``textlines`` is tokenized, truncated to ``gen_len`` tokens
    and right-padded with zeros. The batch is run through ``model`` and each
    sentence is mean-pooled over its real (non-padding) positions. The
    per-sentence vectors are then averaged into one paragraph vector.

    Args:
        textlines: Sequence of sentence strings. Note that a bare ``str`` is
            indexed character by character, so wrap a single sentence in a
            list.
        model: Module that is called with a ``(n_sentences, gen_len)`` tensor
            of token ids and returns a tuple-like whose first element holds
            the hidden states, ``(n_sentences, gen_len, hidden)``. It is
            switched to eval mode as a side effect.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.
        gen_len: Maximum number of tokens kept per sentence.

    Returns:
        A ``(1, hidden)`` tensor on the model's device, detached from
        autograd. Sentences that tokenize to nothing are left out of the
        average; if no sentence has any token the result is all zeros.
    """
    # Put the inputs wherever the model lives instead of assuming a fixed GPU.
    try:
        target = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: fall back to the module-level CUDA index.
        target = torch.device("cuda:" + str(device))

    nb = len(textlines)
    # Build the batch on the CPU and transfer it once, not once per row.
    wds = torch.zeros(nb, gen_len, dtype=torch.long)
    mask = torch.zeros(nb, gen_len, dtype=torch.long)
    for j in range(nb):
        ids = tokenizer.encode(textlines[j], add_special_tokens=False)[:gen_len]
        if ids:
            wds[j, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            mask[j, :len(ids)] = 1
    wds = wds.to(target)
    mask = mask.to(target)

    model.eval()
    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        hidden = model(wds)[0]

    weights = mask.unsqueeze(2).type_as(hidden)      # (nb, gen_len, 1)
    token_counts = weights.sum(dim=1)                # (nb, 1)
    # clamp() keeps token-less rows from turning into 0/0 = NaN.
    sentence_vecs = (weights * hidden).sum(dim=1) / token_counts.clamp(min=1)

    # Average only over sentences that actually contributed tokens, so an
    # empty fragment (e.g. the '' after a trailing '.') cannot skew or poison
    # the paragraph vector.
    has_tokens = (token_counts > 0).type_as(hidden)  # (nb, 1)
    n_valid = has_tokens.sum().clamp(min=1)
    return (sentence_vecs * has_tokens).sum(dim=0, keepdim=True) / n_valid
# Load the pretrained GPT-2 medium encoder and its tokenizer.
gptmodel = GPT2Model.from_pretrained('gpt2-medium').to("cuda:"+str(device))  # embedding dim: 1024
gpttok = GPT2Tokenizer.from_pretrained('gpt2-medium')

file = '../../data_dir/train_encoded.csv'
out_file = '../../data_dir/train_encoded_gpt2.pkl'

# Context managers close both files even if encoding fails part-way through.
with open(file, 'r', encoding='utf-8') as f, open(out_file, 'wb') as f_out:
    lines = f.readlines()
    for idx, line in enumerate(tqdm.tqdm(lines)):
        if idx == 0:  # header row.
            row = (0, "empty", [])
            pickle.dump(row, f_out)
            continue

        # The previous-paragraph text is the last tab-separated column.
        context = line.strip().split('\t')[-1]
        if context == 'NA':
            # Wrap in a list: tfmclassifier indexes its argument, so a bare
            # string would be encoded one character at a time ('N', 'A').
            textlines = [context]
        else:
            # Drop empty fragments (e.g. the '' after a trailing '.'), which
            # tokenize to nothing; fall back to the whole context if no
            # fragment survives so the batch is never empty.
            textlines = [s for s in context.split('.') if s] or [context]

        # Inference only: no autograd graph is needed for the encoding.
        with torch.no_grad():
            encoding = tfmclassifier(textlines, gptmodel, gpttok, gen_len=100)
        # One pickle record per input row: (row index, raw context, vector).
        row = (idx, context, encoding.cpu().detach().numpy())
        pickle.dump(row, f_out)
The tfmclassifier
function is the same as the author's implementation.
I'm so glad to see this code!!