hyena-dna
Performance question regarding next token prediction task
I tried to run a next-token prediction task with the pretrained model hyenadna-small-32k-seqlen-hf, and the results don't look right to me. Here's the code I tried:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch

# instantiate the pretrained model and tokenizer
checkpoint = 'hyenadna-small-32k-seqlen-hf'
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, config=config)

# tokenize a short DNA sequence
seq = 'AGCTACATTGGCC'
tok_seq = tokenizer(seq)['input_ids']
print(tok_seq)

# add a batch dimension and check the round trip through the tokenizer
tok_seq = torch.LongTensor(tok_seq).unsqueeze(0)
print(tokenizer.batch_decode(tok_seq))

# run the model and decode the argmax prediction at every position
out = model(tok_seq)
print(tokenizer.batch_decode(out['logits'].argmax(-1)))
and I get:
[7, 9, 8, 10, 7, 8, 7, 10, 10, 9, 9, 8, 8, 1]
['AGCTACATTGGCC[SEP]']
['AAATAAATTGTAAC']
In my understanding, I've set this model up for next-token prediction, so if I input the sequence 'AGCTACATTGGCC' the model should return something like 'AGCTACATTGGCC' plus one newly predicted token (i.e. keep most of the previous bases unchanged). However, the sequence I get back differs a lot from my input. I wonder whether there is something wrong with my understanding or with my code.
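For context, this is roughly what I expected "next token prediction" to look like. It is only a sketch based on my assumptions about this checkpoint: the index -2 is my guess, because the tokenizer seems to append a [SEP] token after the last base, so the logits at the last base position should predict the base that follows the input.

# minimal sketch of my expectation (uses the model, tokenizer, seq and tok_seq from above);
# index -2 is an assumption: the tokenizer appears to append [SEP] after the last base,
# so the logits at the last base position should predict the following base
with torch.no_grad():
    logits = model(tok_seq)['logits']        # shape: (batch, seq_len, vocab_size)
next_id = int(logits[0, -2].argmax(-1))      # predicted token after the last input base
print(seq + tokenizer.decode([next_id]))     # e.g. 'AGCTACATTGGCC' + one new base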