tianchi_nl2sql
tianchi_nl2sql copied to clipboard
model1里的Multiply()和Masking()
请问一下,model1里的Multiply()和Masking()的作用是什么?
因为一个 batch 里面,每条数据对应的 table 的 header 数量是不一样的。但同一个 batch,模型的输出的大小又是相同的。
举个例子,样本 1 的 header 有 [ 商品名称,价格 ] , 样本 2 的 header 有 [ 电影名称,放映时间,评分]。当样本 1 和样本 2 在同一个 batch 中时,会把 header shape 成相同的大小。
此时模型输入样本1之后,可能给出 [0, 1, 0],输入样本 2之后,可能会给出 [1, 1, 1]。1代表这一列要被选中,0代表不选。
那么可以发现,样本1中的第3个输出,不应该参与 loss 的计算,因为样本1 实际上只有两列。加入 Masking() 层之后,如果某个 timesteps 的输出,全是0,那么,这一个 timesteps 将在后续的 loss 计算中被忽略。
多谢解答
非常感谢