PyTrial
Bug report: AttributeError when loading pretrained PromptEHR
When I tried to load PromptEHR from a pretrained checkpoint, the following error occurred:
AttributeError Traceback (most recent call last)
Input In [9], in <cell line: 5>()
3 vocs = data['voc']
4 model = PromptEHR()
----> 5 model.from_pretrained()
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/pytrial/tasks/trial_simulation/sequence/promptehr.py:222, in PromptEHR.from_pretrained(self, input_dir)
211 def from_pretrained(self, input_dir='./simulation/pretrained_promptEHR'):
212 '''
213 Load pretrained PromptEHR model and make patient EHRs generation.
214 Pretrained model was learned from MIMIC-III patient sequence data.
(...)
220 to this folder.
221 '''
--> 222 self.model.from_pretrained(input_dir=input_dir)
223 self.config.update(self.model.config)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:359, in PromptEHR.from_pretrained(self, input_dir)
356 print(f'Download pretrained PromptEHR model, save to {input_dir}.')
358 print('Load pretrained PromptEHR model from', input_dir)
--> 359 self.load_model(input_dir)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/promptehr.py:298, in PromptEHR.load_model(self, checkpoint)
295 self._load_tokenizer(data_tokenizer_file, model_tokenizer_file)
297 # load configuration
--> 298 self.configuration = EHRBartConfig(self.data_tokenizer, self.model_tokenizer, n_num_feature=self.config['n_num_feature'], cat_cardinalities=self.config['cat_cardinalities'])
299 self.configuration.from_pretrained(checkpoint)
301 # build model
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/promptehr/modeling_config.py:24, in EHRBartConfig(data_tokenizer, model_tokenizer, **kwargs)
22 bart_config = BartConfig.from_pretrained('facebook/bart-base')
23 kwargs.update(model_tokenizer.get_num_tokens)
---> 24 kwargs['data_tokenizer_num_vocab'] = len(data_tokenizer)
25 if 'd_prompt_hidden' not in kwargs:
26 kwargs['d_prompt_hidden'] = 128
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:431, in PreTrainedTokenizer.__len__(self)
426 def __len__(self):
427 """
428 Size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because otherwise if
429 there is a hole in the vocab, we will add tokenizers at a wrong index.
430 """
--> 431 return len(set(self.get_vocab().keys()))
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/models/bart/tokenization_bart.py:243, in BartTokenizer.get_vocab(self)
242 def get_vocab(self):
--> 243 return dict(self.encoder, **self.added_tokens_encoder)
File ~/miniconda3/envs/trial/lib/python3.9/site-packages/transformers/tokenization_utils.py:391, in PreTrainedTokenizer.added_tokens_encoder(self)
385 @property
386 def added_tokens_encoder(self) -> Dict[str, int]:
387 """
388 Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
389 optimisation in `self._added_tokens_encoder` for the slow tokenizers.
390 """
--> 391 return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
AttributeError: 'DataTokenizer' object has no attribute '_added_tokens_decoder'
---

My code is:
from pytrial.tasks.trial_simulation.data import SequencePatient
from pytrial.data.demo_data import load_synthetic_ehr_sequence
data = load_synthetic_ehr_sequence()
train_data = SequencePatient(
data={
'v': data['visit'],
'y': data['y'],
'x': data['feature'],
},
metadata={
'visit': {'mode': 'dense'},
'label': {'mode': 'tensor'},
'voc': data['voc'],
'max_visit': 20,
'n_num_feature': data['n_num_feature'],
'cat_cardinalities': data['cat_cardinalities'],
}
)
from pytrial.tasks.trial_simulation.sequence import PromptEHR
vocs = data['voc']
model = PromptEHR()
model.from_pretrained()
I can directly load BartTokenizer successfully:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
tokenizer
BartTokenizer(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={
0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True),
}
tokenizer.added_tokens_decoder
{0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
50264: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=True)}
Could you please help me fix this bug?