graphtrans
Another Bug report on MaskedOnlyTransformerEncoder
In masked_transformer_encoder.py, line 47 should be changed to:
att = att.masked_fill(valid_input_mask.unsqueeze(1).unsqueeze(2) != 0, mask_value)
The original code replaces every position where valid_input_mask == False with mask_value. The line above instead fills the positions where valid_input_mask is nonzero (True).
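A minimal PyTorch sketch of how the two conditions differ. It assumes attention scores of shape (batch, heads, queries, keys) and a boolean mask of shape (batch, keys). The shapes and mask values are hypothetical, and the "original" line is an assumed equivalent of "fill where the mask is False", not the repository's exact code. unsqueeze(1).unsqueeze(2) reshapes the mask to (batch, 1, 1, keys), so each masked key column is filled for every head and every query.

```python
import torch

# Hypothetical sizes: batch B=2, heads H=1, sequence length S=3.
B, H, S = 2, 1, 3
att = torch.zeros(B, H, S, S)  # attention scores, shape (B, H, S, S)

# Hypothetical mask values, shape (B, S). Under the proposed fix,
# True entries are the key positions replaced with mask_value.
valid_input_mask = torch.tensor([[False, False, True],
                                 [False, True, True]])
mask_value = float("-inf")

# (B, S) -> (B, 1, 1, S): broadcasts over the head and query dimensions.
expanded = valid_input_mask.unsqueeze(1).unsqueeze(2)

# Proposed fix: fill where the mask is nonzero (True).
fixed = att.masked_fill(expanded != 0, mask_value)

# Behaviour described in the report (assumed equivalent): fill where the mask is False.
original = att.masked_fill(expanded == 0, mask_value)

# Softmax turns -inf scores into zero attention weight on those keys.
w_fixed = torch.softmax(fixed, dim=-1)
w_original = torch.softmax(original, dim=-1)

print(expanded.shape)  # torch.Size([2, 1, 1, 3])
print(w_fixed[0, 0, 0])
print(w_original[0, 0, 0])
```

With the proposed condition, the first sequence attends only to keys 0 and 1. With the original condition, the same row attends only to key 2, which is the inverse selection.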