attention-is-all-you-need-pytorch
Replace NumPy with PyTorch in PositionalEncoding
Description:
This PR replaces the use of NumPy with PyTorch in the PositionalEncoding class, specifically in the _get_sinusoid_encoding_table method. The goal is to keep the computation entirely in PyTorch so the encoding table can take advantage of GPU acceleration when available.
Changes:
- Replace np.array with torch.tensor for creating the sinusoid table.
- Replace np.power with torch.pow for calculating the position angle vector.
- Change the data type from a NumPy float array to a PyTorch float tensor.
Here's the modified _get_sinusoid_encoding_table method:
def _get_sinusoid_encoding_table(self, n_position, d_hid):
    ''' Sinusoid position encoding table '''

    def get_position_angle_vec(position):
        # torch.pow requires a tensor operand, so the base 10000 is wrapped in a float tensor;
        # .item() turns the 0-dim result back into a Python float for torch.tensor below
        return [position / torch.pow(torch.tensor(10000.0), 2 * (hid_j // 2) / d_hid).item()
                for hid_j in range(d_hid)]

    sinusoid_table = torch.tensor([get_position_angle_vec(pos_i) for pos_i in range(n_position)],
                                  dtype=torch.float32)
    sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return sinusoid_table.unsqueeze(0)
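For comparison, a fully vectorized variant (a sketch, not part of this PR; the function name and the optional device argument are illustrative) could build the same table with broadcasting instead of Python loops:

import torch

def _get_sinusoid_encoding_table_vectorized(n_position, d_hid, device=None):
    ''' Sketch: same sinusoid table built with broadcasting, no Python-level loops '''
    position = torch.arange(n_position, dtype=torch.float32, device=device).unsqueeze(1)  # (n_position, 1)
    hid_idx = torch.arange(d_hid, dtype=torch.float32, device=device)                     # (d_hid,)
    # Transformer angle formula: pos / 10000^(2 * (i // 2) / d_hid)
    table = position / torch.pow(10000.0, 2 * torch.div(hid_idx, 2, rounding_mode='floor') / d_hid)
    table[:, 0::2] = torch.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = torch.cos(table[:, 1::2])  # dim 2i+1
    return table.unsqueeze(0)                   # (1, n_position, d_hid)

Since every step here is a tensor operation, the table could be computed directly on whatever device the model lives on, instead of being built element by element on the CPU and moved afterwards.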