attention-is-all-you-need-pytorch

Replace NumPy with PyTorch in PositionalEncoding

Open ZYM66 opened this issue 1 year ago • 0 comments

Description:

This PR replaces the usage of NumPy with PyTorch in the PositionalEncoding class, specifically in the _get_sinusoid_encoding_table method. The goal is to make the function compatible with PyTorch and take advantage of GPU acceleration when available.

Changes:

- Build the sinusoid table as a PyTorch tensor instead of with np.array.
- Replace np.power with torch.pow when computing the position angle vector.
- Change the data type from a NumPy float array to a PyTorch float tensor (torch.float32).
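One detail of the np.power to torch.pow swap is worth spelling out: np.power accepts two plain Python numbers, while torch.pow, as far as I know, needs at least one of its arguments to be a tensor. The sketch below computes the denominators 10000^(2 * (j // 2) / d_hid) both ways and checks that they agree. The value of `d_hid` and the variable names are chosen only for illustration and are not taken from the repository.

```python
import numpy as np
import torch

d_hid = 8  # hypothetical hidden size, used only for this illustration

# NumPy style: one scalar np.power call per hidden index j
np_denominators = np.array(
    [np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]
)

# PyTorch style: build the exponents as a tensor, then one torch.pow call
hid_j = torch.arange(d_hid)
exponent = 2 * torch.div(hid_j, 2, rounding_mode="floor").float() / d_hid
torch_denominators = torch.pow(10000.0, exponent)  # scalar base, tensor exponent

# Both approaches should agree up to float32 rounding
same = torch.allclose(
    torch_denominators.double(), torch.from_numpy(np_denominators), rtol=1e-5
)
print(same)
```

Passing the whole exponent tensor at once also removes the per-element Python loop.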

Here's the modified _get_sinusoid_encoding_table method:

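import torch

def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table '''

    # Positions as a column vector, shape (n_position, 1)
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)

    # Exponent 2 * (hid_j // 2) / d_hid for every hidden index, shape (d_hid,)
    hid_j = torch.arange(d_hid)
    exponent = 2 * torch.div(hid_j, 2, rounding_mode='floor').float() / d_hid

    # torch.pow needs a tensor for at least one argument, so the exponent is a tensor;
    # broadcasting yields the (n_position, d_hid) table of position angles
    sinusoid_table = position / torch.pow(10000.0, exponent)
    sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return sinusoid_table.unsqueeze(0)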
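
A quick way to gain confidence in a change like this is to compare the PyTorch table against a NumPy reference. The sketch below is self-contained. `sinusoid_table_numpy` is my reconstruction of the pre-change approach, based only on the np.array / np.power calls named in the description. `sinusoid_table_torch` and its `device` parameter are hypothetical names added here to illustrate building the table directly on a GPU when one is available. Neither function is taken verbatim from the repository.

```python
import numpy as np
import torch


def sinusoid_table_numpy(n_position, d_hid):
    """Assumed NumPy reference, reconstructed from the description above."""
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid)
                for hid_j in range(d_hid)]

    table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    return torch.from_numpy(table).float().unsqueeze(0)


def sinusoid_table_torch(n_position, d_hid, device=None):
    """Vectorized PyTorch version; `device` is a hypothetical extra argument."""
    position = torch.arange(n_position, dtype=torch.float32, device=device).unsqueeze(1)
    hid_j = torch.arange(d_hid, device=device)
    exponent = 2 * torch.div(hid_j, 2, rounding_mode="floor").float() / d_hid
    table = position / torch.pow(10000.0, exponent)
    table[:, 0::2] = torch.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = torch.cos(table[:, 1::2])  # dim 2i+1
    return table.unsqueeze(0)


# Use the GPU if one is present, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

reference = sinusoid_table_numpy(200, 512)
candidate = sinusoid_table_torch(200, 512, device=device)

# The two tables should match up to float32 rounding
max_abs_diff = (reference - candidate.cpu()).abs().max().item()
print(candidate.shape, max_abs_diff)
```

The NumPy version computes in float64 and casts at the end, while the PyTorch version works in float32 throughout, so small differences at large positions are expected.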

ZYM66, Apr 28 '23 00:04