ABSA-PyTorch
ABSA-PyTorch copied to clipboard
可能导致的长度不一致的一点小问题
在mgan.py
, class LocationEncoding
, weight_matrix()
函数中:
for i in range(batch_size):
for j in range(min(pos_inx[i][0], seq_len)):
relative_pos = pos_inx[i][0] - j
aspect_len = pos_inx[i][1] - pos_inx[i][0] + 1
sentence_len = seq_len - aspect_len
weight[i].append(1 - relative_pos / sentence_len)
for j in range(pos_inx[i][0], min(pos_inx[i][1] + 1, seq_len)):
weight[i].append(0)
for j in range(pos_inx[i][1] + 1, seq_len):
relative_pos = j - pos_inx[i][1]
aspect_len = pos_inx[i][1] - pos_inx[i][0] + 1
sentence_len = seq_len - aspect_len
weight[i].append(1 - relative_pos / sentence_len)
以上似乎更优, 因为可能pos_inx[i][1] + 1 > seq_len
(aspect在末尾), 需要min
限制一下. 否则weight行长度不一致, 接下来将会报错.
谢谢您的贡献, 非常感谢! 学到了很多, 或许以上是一个小BUG.
@GeneZC 麻烦看一下
@ZhangYikaii 按理说只要seq_len是所有数据的最大长度应该就没问题?
我发现这个也提了类似的错误 #114
是的, 就是这个报错, 您按照我上面的代码修改就没问题了
您看: pos_inx[i][1] == seq_len
的时候(aspect在末尾), 中间的for
就会走到比seq_len
还多的地方, weight.append
之后就多了, 长度就不一致了
@GeneZC
我个人分析应该是这样的: 所有文本的最长长度为83(举个例子) 而本仓库设置的文本最长长度为80,所以会有原文本被阶段的情况(当aspect在末尾是,aspect也会被部分截断),损失了文本信息。 在处理时,我计算pos_inx的方式是利用aspect左文本长度和aspect本身长度,此时pos_idx[i][1]可能会大于seq_len。 综上,我个人觉得更为合理的解决方式是,将max_len设置为所有文本的最长长度(亦即本仓库中的seq_len),这样不但能解决上述超过界限的问题,还能够不损失文本信息。
@GeneZC 或许是这样?
调节max_len
诚然可以保留更多的文本信息, 但是以上报错的根本成因是 https://github.com/songyouwei/ABSA-PyTorch/blob/master/models/mgan.py#L31
pos_inx[i][1] + 1 > seq_len
(当aspect在末尾, 即便max_len = 所有文本最长长度
), weight[i].append()被执行超过seq_len
次, 导致weight矩阵中某些行长度超出seq_len.
不过您说的有道理, 我会关注一下aspect被截掉的情况, 谢谢
我个人的理解是,其实只要max_len=所有文本最长长度,pos_inx[i][1]+1就不会大于seq_len
或许是, 是我没有理解整个过程, 抱歉
@ZhangYikaii 没必要抱歉啊,互相交流而已 :)