pytorch-relation-extraction

Help needed

Open justnoone opened this issue 6 years ago • 2 comments

loading train data
loading finish
loading test data
loading finish
2018-11-15 17:55:06 train data: 65726; test data: 93574
2018-11-15 17:55:06 Epoch 1/100: train loss: 0;
2018-11-15 17:55:06 Epoch 2/100: train loss: 0;
2018-11-15 17:55:06 Epoch 3/100: train loss: 0;
Traceback (most recent call last):
  File "main_att.py", line 178, in <module>
    fire.Fire()
  File "D:\Anaconda3\lib\site-packages\fire\core.py", line 127, in Fire
    component_trace = _Fire(component, args, context, name)
  File "D:\Anaconda3\lib\site-packages\fire\core.py", line 366, in _Fire
    component, remaining_args)
  File "D:\Anaconda3\lib\site-packages\fire\core.py", line 542, in _CallCallable
    result = fn(*varargs, **kwargs)
  File "main_att.py", line 78, in train
    all_pre, all_rec = eval_metric_var(pred_res, p_num)
  File "main_att.py", line 131, in eval_metric_var
    true_y = pred_res_sort[i][0]
IndexError: list index out of range

justnoone • Nov 15 '18 10:11

I ran into this problem too. Have you managed to solve it? Any advice would be much appreciated.

YinggangZhang • Oct 21 '19 12:10

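I hit this problem as well, and here is how I fixed it. In my case it started with an earlier error when running the code: the error said that a zip object has no len(). I traced it to lines 25 and 28 of filternyt.py. The cause is the move from Python 2 to Python 3: in Python 2, zip() returns a list, whereas in Python 3 it returns a lazy iterator. I first changed the code to: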

def __init__(self, root_path, train=True):
    if train:
        path = os.path.join(root_path, 'train/')
        print('loading train data')
    else:
        path = os.path.join(root_path, 'test/')
        print('loading test data')

    self.labels = np.load(path + 'labels.npy')
    self.x = np.load(path + 'bags_feature.npy')
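    # Python 3: zip() returns a one-shot iterator, not a list
    self.x = zip(self.x, self.labels)

    print('loading finish')

def __getitem__(self, idx):
    # list(self.x) exhausts the iterator, so later calls see nothing
    assert idx < len(list(self.x))
    # zip objects do not support indexing either
    return self.x[idx]

def __len__(self):
    # only the first call returns the real size; later calls return 0
    return len(list(self.x))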

After that, the following error appeared:

IndexError: list index out of range

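The cause: zip() returns an iterator, and once an iterator has been consumed it is exhausted. In the code above, the first call to len(list(self.x)) uses up all the pairs, so every later call sees an empty sequence. The fix is to convert the zip to a list once, in __init__: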

def __init__(self, root_path, train=True):
    if train:
        path = os.path.join(root_path, 'train/')
        print('loading train data')
    else:
        path = os.path.join(root_path, 'test/')
        print('loading test data')

    self.labels = np.load(path + 'labels.npy')
    self.x = np.load(path + 'bags_feature.npy')
    # materialize the pairs once so len() and indexing work on every call
    self.x = list(zip(self.x, self.labels))

    print('loading finish')

def __getitem__(self, idx):
    assert idx < len(self.x)
    return self.x[idx]

def __len__(self):
    return len(self.x)
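The difference between the two versions can be reproduced without the NYT data. The sketch below is a minimal, stdlib-only illustration: the short lists are made-up placeholders standing in for `labels.npy` and `bags_feature.npy`, and `PairDataset` is a hypothetical name, not the actual class in filternyt.py. It shows that a zip object can be consumed only once and cannot be indexed, that wrapping it in `list()` fixes both, and one alternative that indexes the two arrays directly instead of zipping them.

```python
# Stdlib-only illustration; the lists below are made-up placeholders
# standing in for labels.npy and bags_feature.npy.
features = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
labels = [1, 0, 2]

# First version: keep the raw zip object (a one-shot iterator in Python 3).
pairs_iter = zip(features, labels)
print(len(list(pairs_iter)))  # 3: list() consumes every pair
print(len(list(pairs_iter)))  # 0: the iterator is already exhausted

# A zip object cannot be indexed either.
try:
    zip(features, labels)[0]
except TypeError as err:
    print(err)  # 'zip' object is not subscriptable

# Fixed version: materialize the pairs once.
pairs_list = list(zip(features, labels))
print(len(pairs_list), len(pairs_list))  # 3 3
print(pairs_list[1])  # ([0.3, 0.4], 0)


class PairDataset:
    """Hypothetical stand-in for the dataset class in filternyt.py.

    Instead of zipping, it keeps both arrays and indexes them together.
    """

    def __init__(self, features, labels):
        assert len(features) == len(labels), "features and labels must align"
        self.features = features
        self.labels = labels

    def __getitem__(self, idx):
        # Return the (feature, label) pair at position idx.
        return self.features[idx], self.labels[idx]

    def __len__(self):
        return len(self.labels)


ds = PairDataset(features, labels)
print(len(ds), ds[2])  # 3 ([0.5, 0.6], 2)
```

Indexing the two arrays separately in `__getitem__` avoids building an extra list of tuples, and `__getitem__` plus `__len__` is all a PyTorch map-style dataset needs. With the real data, the two constructor arguments would be the arrays returned by `np.load`.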

YinggangZhang • Oct 22 '19 01:10