KRED
Multi-task training fails due to behavior of build_pop_data in utils/util.py
I tried multi-task training on MIND large. The training fails with the following error:
KeyError                                  Traceback (most recent call last)
in
      2     single_task_training(config, data)
      3 else:
----> 4     multi_task_training(config, data)

~\path\to\KRED\train_test.py in multi_task_training(config, data)
     34
     35     train_data_pop = NewsDataset(pop_train)
---> 36     train_sampler_pop = RandomSampler(train_data_pop)
     37     train_dataloader_pop = DataLoader(train_data_pop, sampler=train_sampler_pop, batch_size=config['data_loader']['batch_size'],
     38                                       pin_memory=False)

~.conda\envs\kred36\lib\site-packages\torch\utils\data\sampler.py in __init__(self, data_source, replacement, num_samples, generator)
     99                              "since a random permute will be performed.")
    100
--> 101         if not isinstance(self.num_samples, int) or self.num_samples <= 0:
    102             raise ValueError("num_samples should be a positive integer "
    103                              "value, but got num_samples={}".format(self.num_samples))

~.conda\envs\kred36\lib\site-packages\torch\utils\data\sampler.py in num_samples(self)
    107         # dataset size might change at runtime
    108         if self._num_samples is None:
--> 109             return len(self.data_source)
    110         return self._num_samples
    111

~\path\to\KRED\train_test.py in __len__(self)
     14         self.transform = transform
     15     def __len__(self):
---> 16         return len(self.dic_data['label'])
     17     def __getitem__(self, idx):
     18         if torch.is_tensor(idx):

KeyError: 'label'
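The last frame shows that `__len__` looks up `self.dic_data['label']`, so any dict without a `'label'` key, including an empty one, raises this error as soon as `RandomSampler` asks for the dataset length. Here is a minimal reproduction that does not need PyTorch. `MiniDataset` and `try_len` are hypothetical stand-ins I wrote for illustration; only the body of `__len__` is taken from the traceback.

```python
class MiniDataset:
    """Hypothetical stand-in that mirrors NewsDataset.__len__ from the traceback."""

    def __init__(self, dic_data):
        self.dic_data = dic_data

    def __len__(self):
        # Same lookup as train_test.py line 16 in the traceback.
        return len(self.dic_data['label'])


def try_len(dic_data):
    """Return the dataset length, or the error text if the lookup fails."""
    try:
        return len(MiniDataset(dic_data))
    except KeyError as err:
        return "KeyError: {}".format(err)


print(try_len({}))                    # empty dict, as returned for pop_train
print(try_len({'label': [0, 1, 1]}))  # dict that has a 'label' list
```

The first call fails the same way the training run does, and the second one succeeds, which suggests the problem is the content of `pop_train` rather than the sampler.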
The build_pop_data function in utils/util.py returns empty dicts: pop_train and pop_test are initialized but never populated. The loop only fills news_imp_dict, which is then discarded:
def build_pop_data(config):
    print('building pop data ...')
    fp_train = open(config['data']['train_behavior'], 'r', encoding='utf-8')
    news_imp_dict = {}
    pop_train = {}
    pop_test = {}
    for line in fp_train:
        index, userid, imp_time, history, behavior = line.strip().split('\t')
        behavior = behavior.split(' ')
        for news in behavior:
            newsid, news_label = news.split('-')
            if news_label == "1":
                if newsid not in news_imp_dict:
                    news_imp_dict[newsid] = [1,1]
                else:
                    news_imp_dict[newsid][0] = news_imp_dict[newsid][0] + 1
                    news_imp_dict[newsid][1] = news_imp_dict[newsid][1] + 1
            else:
                if newsid not in news_imp_dict:
                    news_imp_dict[newsid] = [0,1]
                else:
                    news_imp_dict[newsid][1] = news_imp_dict[newsid][1] + 1
    return pop_train, pop_test
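To illustrate the missing step, here is a rough sketch of how the `[clicks, impressions]` counts in news_imp_dict could be turned into a dict that `NewsDataset.__len__` accepts. This is only a guess at the intent, not the authors' method: the helper `counts_to_pop_dict`, the `'item1'` key, and the use of the click ratio as the label are all assumptions on my side. Only the `'label'` key is confirmed by the traceback, and the train/test split is left out entirely.

```python
def counts_to_pop_dict(news_imp_dict):
    """Turn {newsid: [clicks, impressions]} into a dict of parallel lists.

    ASSUMPTION: the 'item1' key and the click-ratio label are illustrative
    guesses; only the 'label' key is known from the traceback.
    """
    pop = {'item1': [], 'label': []}
    for newsid in sorted(news_imp_dict):
        clicks, impressions = news_imp_dict[newsid]
        pop['item1'].append(newsid)
        # impressions is at least 1, because every entry starts at [0,1] or [1,1].
        pop['label'].append(clicks / impressions)
    return pop


# Tiny made-up counts in the same shape that build_pop_data accumulates.
example_counts = {'N1': [1, 4], 'N2': [0, 2]}
pop = counts_to_pop_dict(example_counts)
print(len(pop['label']))
```

Whether the popularity label should be a ratio, a bucketed class, or something else is a question for the maintainers.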
I don't think returning empty dicts is the intended behavior. Please let me know. Thanks in advance!
I ran into the same problem. If there is any update, please let me know.