Liger-Kernel
Support Z Loss in CE
## Summary
This PR aims to resolve #197.
Implemented z loss in LigerCrossEntropy.
Note: `lse_square_scale` is not exposed at the FLCE level yet; passing it through currently fails the tests.
## Details
For loss:
\begin{align}
L_{total} &= L_{ce} + z\_loss \\
z\_loss &= lse\_square\_scale \cdot lse^2 \\
lse &= \log \sum e^{X_i}
\end{align}
We can use $m = max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from the online softmax algorithm, to calculate $lse$ directly.
\begin{align}
lse &= \log \sum e^{X_i} \\
&= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m \\
&= \log\ e^m \sum e^{X_i - m} = m + \log d
\end{align}
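The identity above can be checked with a short NumPy sketch (illustrative only, not the Triton kernel code): it maintains the running max $m$ and shifted sum $d$ exactly as online softmax does, then recovers $lse = m + \log d$.

```python
import numpy as np

def lse_online(x):
    """Numerically stable log-sum-exp from the running max `m` and
    shifted sum `d` maintained by the online softmax algorithm."""
    m = -np.inf  # running max of x_i seen so far
    d = 0.0      # running sum of exp(x_i - m)
    for xi in x:
        m_new = max(m, xi)
        # rescale the old sum to the new max, then add the new term
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return m + np.log(d)  # lse = m + log d

x = np.array([1.0, 2.0, 3.0])
print(np.isclose(lse_online(x), np.log(np.exp(x).sum())))  # True
```

The shift by $m$ keeps every exponent non-positive, so the sum never overflows even for large logits.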
For gradients:
First, we calculate the derivative of lse
\begin{align}
\frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum e^{x_i}\right) \\
&= \frac{1}{\sum e^{x_i}} \cdot \frac{\partial}{\partial x_i} \sum e^{x_i} \\
&= \frac{e^{x_i}}{\sum e^{x_i}} = softmax(x_i).
\end{align}
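As a quick sanity check of this step, a small NumPy snippet (illustrative) comparing central finite differences of $lse$ against the softmax:

```python
import numpy as np

def lse(x):
    # numerically stable log-sum-exp
    m = x.max()
    return m + np.log(np.exp(x - m).sum())

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

x = np.array([0.5, -1.0, 2.0])
eps = 1e-6
# central finite differences of lse w.r.t. each coordinate
num_grad = np.array([
    (lse(x + eps * np.eye(3)[i]) - lse(x - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])
print(np.allclose(num_grad, softmax(x), atol=1e-6))  # True
```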
Then we can obtain the derivative of z_loss by chain rule.
\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right) = 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i),
and we have the derivative of the cross-entropy loss with label smoothing
\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K} = \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
where $\epsilon$ is `label_smoothing` and $K$ is the number of total classes. Thus, the derivative of the total loss is
\begin{align}
\frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
&= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
&= \begin{cases} (1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
(1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
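The closed-form gradient can be verified against finite differences with a small NumPy reference (hypothetical helper names `total_loss` / `total_grad`, not Liger's API):

```python
import numpy as np

def total_loss(x, y, scale, eps_ls):
    """CE with label smoothing plus z loss, for a single row of logits."""
    K = x.size
    m = x.max()
    lse_val = m + np.log(np.exp(x - m).sum())
    log_p = x - lse_val
    # smoothed target distribution: (1 - eps) on class y, eps/K everywhere
    target = np.full(K, eps_ls / K)
    target[y] += 1.0 - eps_ls
    ce = -(target * log_p).sum()
    return ce + scale * lse_val**2  # z_loss = scale * lse^2

def total_grad(x, y, scale, eps_ls):
    """Closed-form gradient from the derivation above."""
    K = x.size
    m = x.max()
    lse_val = m + np.log(np.exp(x - m).sum())
    p = np.exp(x - lse_val)  # softmax
    g = (1.0 + 2.0 * scale * lse_val) * p - eps_ls / K
    g[y] -= 1.0 - eps_ls
    return g

x = np.array([0.2, -0.7, 1.5, 0.0])
y, scale, eps_ls = 2, 1e-4, 0.1
h = 1e-6
num = np.array([
    (total_loss(x + h * np.eye(4)[i], y, scale, eps_ls)
     - total_loss(x - h * np.eye(4)[i], y, scale, eps_ls)) / (2 * h)
    for i in range(4)
])
print(np.allclose(num, total_grad(x, y, scale, eps_ls), atol=1e-5))  # True
```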
## Reference
- PaLM: Scaling Language Modeling with Pathways
- Chameleon: Mixed-Modal Early-Fusion Foundation Models
## Testing Done
Benchmark gist: negligible difference in the speed benchmark.
This benchmark was run on my machine, so the numbers may not be precise.
- liger ce: 66.123 ms, peak mem: 8.66200832
- liger ce with z loss: 65.991 ms, peak mem: 8.66200832
- liger ce with z loss, returning z loss: 65.951 ms, peak mem: 8.662073856
- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence