Liger-Kernel
Add on-paper form of RoPE kernel
Summary
Implement the on-paper form of the RoPE kernel from RoFormer. This implementation does not support optional value input, unlike the HuggingFace RoFormer RoPE implementation.
Details
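The code is adapted from Liger Kernel's RoPE implementation. In Liger Kernel's current RoPE implementation, the head dimension is split into a left half $x_1$ and a right half $x_2$:

$$y = [x_1, x_2] * [\cos, \cos] + [-x_2, x_1] * [\sin, \sin]$$

$$dy = [dx_1, dx_2] * [\cos, \cos] + [-dx_2, dx_1] * [-\sin, -\sin]$$

In the backward formula, $dx$ is the gradient flowing in from the output and $dy$ is the gradient with respect to the input, so the backward pass is the forward rotation with $\sin$ negated.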
For a query head $q$ at position $m$, this corresponds to the following element-wise multiply-add form, where $\otimes$ denotes element-wise multiplication:
$$
\begin{pmatrix}
q_0\\
q_1\\
q_2\\
\vdots\\
q_{d/2-1}\\
q_{d/2}\\
q_{d/2+1}\\
q_{d/2+2}\\
\vdots\\
q_{d-1}
\end{pmatrix}
\otimes
\begin{pmatrix}
\cos m\theta_0 \\
\cos m\theta_1 \\
\cos m\theta_2 \\
\vdots\\
\cos m\theta_{d/2-1} \\
\cos m\theta_0 \\
\cos m\theta_1 \\
\cos m\theta_2 \\
\vdots\\
\cos m\theta_{d/2-1}
\end{pmatrix}
+
\begin{pmatrix}
-q_{d/2}\\
-q_{d/2+1}\\
-q_{d/2+2}\\
\vdots\\
-q_{d-1}\\
q_0\\
q_1\\
q_2\\
\vdots\\
q_{d/2-1}
\end{pmatrix}
\otimes
\begin{pmatrix}
\sin m\theta_0 \\
\sin m\theta_1 \\
\sin m\theta_2 \\
\vdots\\
\sin m\theta_{d/2-1} \\
\sin m\theta_0 \\
\sin m\theta_1 \\
\sin m\theta_2 \\
\vdots\\
\sin m\theta_{d/2-1}
\end{pmatrix}
$$
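A minimal pure-Python sketch of the half-split form above may help when cross-checking a kernel against it. It operates on a single head vector stored as a list and is not the Liger Kernel (Triton) code: the helper names `rope_cos_sin` and `rope_half_split` are hypothetical, and the frequencies assume RoFormer's default $\theta_i = 10000^{-2i/d}$.

```python
import math


def rope_cos_sin(d, m, base=10000.0):
    """cos/sin of m * theta_i for i = 0..d/2-1, with theta_i = base**(-2i/d).

    base=10000 is RoFormer's default and an assumption here.
    """
    angles = [m * base ** (-2 * i / d) for i in range(d // 2)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rope_half_split(x, m, backward=False):
    """Half-split RoPE on one head vector x (length d, d even), as plain lists.

    forward:  y  = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    backward: dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
              (pass the incoming gradient as x and set backward=True)
    """
    h = len(x) // 2
    cos, sin = rope_cos_sin(len(x), m)
    if backward:
        sin = [-s for s in sin]  # the transpose of a rotation negates sin
    cos_full = cos + cos  # [cos m*theta_0 .. cos m*theta_{d/2-1}] repeated twice
    sin_full = sin + sin
    rotated = [-v for v in x[h:]] + x[:h]  # [-x2, x1]
    return [a * c + b * s for a, b, c, s in zip(x, rotated, cos_full, sin_full)]


# Usage: backward is the adjoint of forward, so <R x, g> == <x, R^T g>.
x = [0.5, -1.0, 2.0, 0.25]
g = [1.0, 0.3, -0.7, 0.2]  # stand-in for an upstream gradient
y = rope_half_split(x, m=5)
gx = rope_half_split(g, m=5, backward=True)
# the two printed values agree up to floating-point rounding
print(sum(a * b for a, b in zip(y, g)), sum(a * b for a, b in zip(x, gx)))
```

The inner-product check is a cheap way to confirm that a backward implementation really is the transpose of the forward one without writing out the full matrix.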
To obtain the on-paper form of RoPE, the head is instead split into even-indexed and odd-indexed parts, so that each adjacent pair $(x_{2i}, x_{2i+1})$ is rotated together:

$$y_{even} = x_{even} * \cos - x_{odd} * \sin, \qquad y_{odd} = x_{odd} * \cos + x_{even} * \sin$$

$$dy_{even} = dx_{even} * \cos + dx_{odd} * \sin, \qquad dy_{odd} = dx_{odd} * \cos - dx_{even} * \sin$$
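The backward formulas follow from viewing on-paper RoPE as the block-diagonal rotation matrix from the RoFormer paper, written here 0-indexed and assuming the paper's default $\theta_i = 10000^{-2i/d}$:

$$
R_{\Theta,m}^{d} = \operatorname{diag}\big(R(m\theta_0), R(m\theta_1), \dots, R(m\theta_{d/2-1})\big),
\qquad
R(\alpha) = \begin{pmatrix} \cos\alpha & -\sin\alpha \\ \sin\alpha & \cos\alpha \end{pmatrix}
$$

$$
y = R_{\Theta,m}^{d}\, x
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial x} = \big(R_{\Theta,m}^{d}\big)^{\top} \frac{\partial \mathcal{L}}{\partial y},
\qquad
R(\alpha)^{\top} = R(-\alpha)
$$

Because $R(\alpha)^{\top} = R(-\alpha)$, the backward pass is the forward pass with $\sin$ negated: applying $R(m\theta_i)$ and $R(-m\theta_i)$ to each pair $(x_{2i}, x_{2i+1})$ reproduces the forward and backward formulas above.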
The corresponding element-wise multiply-add form is:
$$
\begin{pmatrix}
q_0\\
q_1\\
q_2\\
q_3\\
\vdots\\
q_{d-2}\\
q_{d-1}
\end{pmatrix}
\otimes
\begin{pmatrix}
\cos m\theta_0 \\
\cos m\theta_0 \\
\cos m\theta_1 \\
\cos m\theta_1 \\
\vdots\\
\cos m\theta_{d/2-1} \\
\cos m\theta_{d/2-1}
\end{pmatrix}
+
\begin{pmatrix}
-q_1\\
q_0\\
-q_3\\
q_2\\
\vdots\\
-q_{d-1}\\
q_{d-2}
\end{pmatrix}
\otimes
\begin{pmatrix}
\sin m\theta_0 \\
\sin m\theta_0 \\
\sin m\theta_1 \\
\sin m\theta_1 \\
\vdots\\
\sin m\theta_{d/2-1} \\
\sin m\theta_{d/2-1}
\end{pmatrix}
$$
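The interleaved layout can be sketched the same way in pure Python. Again this is an illustrative reference, not the Triton kernel from this PR; `rope_cos_sin` and `rope_paper` are hypothetical names, and $\theta_i = 10000^{-2i/d}$ is assumed. Relative to the half-split form, only the broadcasting of `cos`/`sin` (each value repeated for two adjacent channels) and the choice of paired channels change.

```python
import math


def rope_cos_sin(d, m, base=10000.0):
    """cos/sin of m * theta_i for i = 0..d/2-1, with theta_i = base**(-2i/d) (base assumed)."""
    angles = [m * base ** (-2 * i / d) for i in range(d // 2)]
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rope_paper(x, m, backward=False):
    """On-paper (interleaved) RoPE on one head vector x (length d, d even), as plain lists.

    forward:  y_even  = x_even * cos - x_odd * sin,   y_odd  = x_odd * cos + x_even * sin
    backward: dy_even = dx_even * cos + dx_odd * sin, dy_odd = dx_odd * cos - dx_even * sin
              (pass the incoming gradient as x and set backward=True)
    """
    cos, sin = rope_cos_sin(len(x), m)
    if backward:
        sin = [-s for s in sin]  # the transpose of a rotation negates sin
    # [cos m*theta_0, cos m*theta_0, cos m*theta_1, cos m*theta_1, ...]
    cos_full = [c for c in cos for _ in range(2)]
    sin_full = [s for s in sin for _ in range(2)]
    rotated = []
    for i in range(0, len(x), 2):
        rotated += [-x[i + 1], x[i]]  # [-x_1, x_0, -x_3, x_2, ...]
    return [a * c + b * s for a, b, c, s in zip(x, rotated, cos_full, sin_full)]


x = [0.1, -0.4, 0.7, 0.2, -0.3, 0.5]
y = rope_paper(x, m=7)
x_back = rope_paper(y, m=7, backward=True)  # R^T R x recovers x up to rounding
```

Because the pairs are $(x_{2i}, x_{2i+1})$ rather than $(x_i, x_{i+d/2})$, the two forms agree up to a fixed permutation of head channels (even indices first, then odd indices).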
Testing Done
- Hardware Type: A100-80G-PCIe
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
I've added a paper-form option to the current Liger Kernel RoPE implementation.