PaddleNLP icon indicating copy to clipboard operation
PaddleNLP copied to clipboard

Make BERT support past_key_values.

Open guoshengCS opened this issue 3 years ago • 0 comments

PR types

New features

PR changes

APIs

Description

Make BERT support `past_key_values` (per-layer cached key/value tensors passed to the model's forward call, as in the example below).

import numpy as np
import paddle
import paddlenlp
from paddlenlp.transformers import AutoModel, AutoTokenizer, AutoModelForTokenClassification, AutoModelForPretraining, BertForTokenClassification, BertModel

# Smoke test: run a PaddleNLP model forward pass with randomly initialised
# `past_key_values` and an attention mask covering cached + current tokens.

# Fix both RNGs so the random cache below is reproducible across runs.
paddle.seed(123)
np.random.seed(123)

# NOTE(review): AutoModel resolves this checkpoint to whatever class PaddleNLP
# registers for it (presumably ErnieModel rather than BertModel) -- confirm it
# exercises the code path this change targets.
model_name = "ernie-3.0-base-zh"
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(["欢迎使用百度飞桨!"])
print(inputs)
to_tensor = paddle.to_tensor
# Convert every tokenizer field to a paddle Tensor.
inputs = {k: to_tensor(v) for (k, v) in inputs.items()}
print(inputs)
model = AutoModel.from_pretrained(model_name)
model.eval()

# NOTE(review): these must match the loaded model's config (number of layers,
# attention heads, hidden_size // heads); hard-coded here for
# ernie-3.0-base-zh -- update them if model_name changes.
n_layer = 12
n_head = 12
size_per_head = 64
cache_len = 2

# Derive batch size and current sequence length from the encoded inputs so
# the cache and the mask stay consistent if more sentences are tokenized.
batch_size, seq_len = inputs['input_ids'].shape

cache = np.random.rand(batch_size, n_head, cache_len, size_per_head).astype('float32')
# Build a separate [key, value] pair per layer. The shorter
# `[[t] * 2] * n_layer` would alias one list (and one tensor) across every
# layer and both slots, so an in-place change to one entry would hit all.
caches = [[to_tensor(cache), to_tensor(cache)] for _ in range(n_layer)]

# The mask spans the cached positions plus the current tokens.
# NOTE(review): an all-ones mask is uniform over positions, so it should mask
# nothing whether the model reads it as a keep-mask or as an additive bias --
# confirm against the model's attention_mask documentation.
attention_mask = to_tensor(np.ones((batch_size, seq_len + cache_len)).astype('float32'))
print(attention_mask.shape)
print(tokenizer.pad_token_id)
inputs['attention_mask'] = attention_mask

# use_cache=False: the supplied cache is consumed, but no updated cache is
# requested back from the model.
outputs = model(
    **inputs,
    past_key_values=caches,
    return_dict=True,
    output_hidden_states=False,
    output_attentions=False,
    use_cache=False,
)
print(outputs[0])

guoshengCS avatar Jul 14 '22 07:07 guoshengCS