transformer
Correct key_masks shape
The mask function defined in modules.py expects key_masks to be a 2-D tensor of shape (N, T_k), but the comments describe it as a 3-D tensor of shape (N, 1, T_k).
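
To illustrate why the shape matters, here is a minimal NumPy sketch of key masking. It is not the code from modules.py: the function name `apply_key_mask`, the `padding_num` default, and the assumption that the scores have shape (h*N, T_q, T_k) are all hypothetical, chosen only to show how a 2-D (N, T_k) mask is tiled and then expanded to 3-D inside the function.

```python
import numpy as np

def apply_key_mask(scores, key_masks, padding_num=-1e9):
    """Hypothetical sketch, not the modules.py implementation.

    scores:    (h*N, T_q, T_k) attention logits (assumed layout)
    key_masks: (N, T_k), 1.0 where the key position is padding
    """
    if key_masks.ndim != 2:
        raise ValueError(f"key_masks must be 2-D (N, T_k), got shape {key_masks.shape}")
    # Repeat the mask once per head so the leading axis matches scores: (h*N, T_k)
    reps = scores.shape[0] // key_masks.shape[0]
    tiled = np.tile(key_masks.astype(scores.dtype), (reps, 1))
    # Add the singleton query axis here, inside the function: (h*N, 1, T_k)
    expanded = tiled[:, None, :]
    # Broadcast over T_q; padded keys receive a large negative logit
    return scores + expanded * padding_num

N, h, T_q, T_k = 2, 4, 3, 5
scores = np.zeros((h * N, T_q, T_k), dtype=np.float32)
key_masks = np.array([[0, 0, 0, 1, 1],
                      [0, 0, 0, 0, 1]], dtype=np.float32)  # (N, T_k)
masked = apply_key_mask(scores, key_masks)
```

Under these assumptions, a caller that passes a mask already shaped (N, 1, T_k), as the old comment suggested, would hit the 2-D tiling step with the wrong rank, which is why the comment should state (N, T_k).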