tianchi_nl2sql icon indicating copy to clipboard operation
tianchi_nl2sql copied to clipboard

model1里的Multiply()和Masking()

Open LHT-Curry opened this issue 4 years ago • 3 comments

请问一下,model1里的Multiply()和Masking()的作用是什么?

LHT-Curry avatar May 13 '20 03:05 LHT-Curry

因为一个 batch 里面,每条数据对应的 table 的 header 数量是不一样的。但同一个 batch,模型的输出的大小又是相同的。

举个例子,样本 1 的 header 有 [ 商品名称,价格 ] , 样本 2 的 header 有 [ 电影名称,放映时间,评分]。当样本 1 和样本 2 在同一个 batch 中时,会把 header shape 成相同的大小。

此时模型输入样本1之后,可能给出 [0, 1, 0],输入样本 2之后,可能会给出 [1, 1, 1]。1代表这一列要被选中,0代表不选。

那么可以发现,样本1中的第3个输出,不应该参与 loss 的计算,因为样本1 实际上只有两列。加入 Masking() 层之后,如果某个 timesteps 的输出,全是0,那么,这一个 timesteps 将在后续的 loss 计算中被忽略。

beader avatar May 13 '20 03:05 beader

多谢解答

LHT-Curry avatar May 13 '20 03:05 LHT-Curry

非常感谢

LHT-Curry avatar May 13 '20 03:05 LHT-Curry