HSCRF-pytorch
Question about scrf_to_crf in utils.py
Hi. I have read your code and have a question about the function scrf_to_crf in utils.py, specifically this part:
```python
for i_l in decoded_scrf:
    sent_labels = [l_map['<start>']]
    for label in i_l:
        if label != l_map['<pad>']:
            sent_labels.append(label)
        else:
            break
    crf_labels.append(sent_labels)
crfdata = []
masks = []
maxl_1 = max([len(i) for i in crf_labels])
for i_l in crf_labels:
    cur_len_1 = len(i_l)
    cur_len = cur_len_1 - 1
    i_l_pad = [i_l[ind] * label_size + i_l[ind + 1] for ind in range(0, cur_len)] + [i_l[cur_len] * label_size + pad_label] + [
        pad_label * label_size + pad_label] * (maxl_1 - cur_len_1)
    mask = [1] * cur_len_1 + [0] * (maxl_1 - cur_len_1)
    crfdata.append(i_l_pad)
    masks.append(mask)
```
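To make the behaviour of this loop concrete, here is a minimal standalone sketch that reproduces the quoted logic on toy data. It assumes `l_map` is a plain dict, `label_size` is the number of CRF labels and `pad_label` is the pad id; the function name and the toy ids are hypothetical, not the repository's real label map. Each CRF target packs a (previous, current) label pair into a single index `prev * label_size + cur`, and the `break` means every label after the first `<pad>` is dropped.

```python
def scrf_to_crf_sketch(decoded_scrf, l_map, label_size, pad_label):
    """Hypothetical standalone version of the quoted scrf_to_crf loop."""
    crf_labels = []
    for i_l in decoded_scrf:
        # Prepend <start>, then copy labels up to (not including) the first <pad>.
        sent_labels = [l_map['<start>']]
        for label in i_l:
            if label != l_map['<pad>']:
                sent_labels.append(label)
            else:
                break  # anything after the first <pad> is discarded
        crf_labels.append(sent_labels)

    crfdata, masks = [], []
    maxl_1 = max(len(i) for i in crf_labels)
    for i_l in crf_labels:
        cur_len_1 = len(i_l)
        cur_len = cur_len_1 - 1
        # Encode each transition (prev, cur) as prev * label_size + cur,
        # close the sentence with a transition into pad_label, then pad to maxl_1.
        i_l_pad = ([i_l[ind] * label_size + i_l[ind + 1] for ind in range(cur_len)]
                   + [i_l[cur_len] * label_size + pad_label]
                   + [pad_label * label_size + pad_label] * (maxl_1 - cur_len_1))
        mask = [1] * cur_len_1 + [0] * (maxl_1 - cur_len_1)
        crfdata.append(i_l_pad)
        masks.append(mask)
    return crfdata, masks


if __name__ == "__main__":
    toy_map = {'<start>': 3, '<pad>': 4}  # hypothetical ids
    data, m = scrf_to_crf_sketch([[0, 1, 4, 2], [2]], toy_map, label_size=5, pad_label=4)
    print(data)  # [[15, 1, 9], [17, 14, 24]]
    print(m)     # [[1, 1, 1], [1, 1, 0]]
```

Note that in the first toy sentence the label `2` after the pad never reaches `crfdata`, which is exactly the truncation the question below is about.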
Why would it break if label == l_map['
Actually, when I use PyTorch 0.4, the label 'S-Stop' (id 5 in SCRF_l_map) can appear in the middle of decoded_scrf instead of at the end. See below.
(I printed the spans in hscrf_layer.py.)
This may be because the model is not well trained. I do not know how to fix it. Could you help me out?
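To measure how often a stop label shows up too early, a small check over the decoded sequences might help. This is a hedged sketch, assuming each decoded sequence has already been converted to a plain Python list of ids (e.g. via `.tolist()`), that 5 is the 'S-Stop' id reported above, and that the pad id is known; the helper name is hypothetical. It mirrors scrf_to_crf by only looking at the prefix before the first pad.

```python
from typing import List, Tuple


def find_mid_sequence_stops(decoded: List[List[int]], stop_id: int, pad_id: int) -> List[Tuple[int, int]]:
    """Return (sentence_index, position) pairs where stop_id occurs before the end of the unpadded prefix."""
    hits = []
    for s_idx, seq in enumerate(decoded):
        # Like scrf_to_crf, only consider labels before the first pad.
        prefix = seq[:seq.index(pad_id)] if pad_id in seq else list(seq)
        # A stop is expected only at the last position of the prefix.
        for pos, label in enumerate(prefix[:-1]):
            if label == stop_id:
                hits.append((s_idx, pos))
    return hits


if __name__ == "__main__":
    # Hypothetical decoded batch: stop id 5, pad id 4.
    print(find_mid_sequence_stops([[0, 5, 1, 5, 4], [2, 5, 4, 4]], stop_id=5, pad_id=4))  # [(0, 1)]
```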