
Support Z Loss in CE

Opened by Tcc0403 • 4 comments

Summary

This PR aims to resolve #197

Implemented z loss in LigerCrossEntropy.

Note: `lse_square_scale` is not exposed in FLCE yet; still having issues passing the tests.

Details

For loss:

\begin{align}
L_{total} &= L_{ce} + z\_loss \\
z\_loss &= lse\_square\_scale \cdot lse^2 \\
lse &= \log \sum e^{X_i}
\end{align}
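
As a quick eager-mode reference (not the Triton kernel itself; the function name and signature below are only illustrative), the total loss can be sketched in a few lines of PyTorch:

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss(logits, target, lse_square_scale=0.0):
    """Illustrative sketch: cross entropy plus z loss, mean-reduced over the batch."""
    ce = F.cross_entropy(logits, target)            # L_ce
    lse = torch.logsumexp(logits, dim=-1)           # lse = log sum e^{X_i}, per row
    z_loss = lse_square_scale * (lse ** 2).mean()   # z_loss = lse_square_scale * lse^2
    return ce + z_loss
```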

We can use $m = \max(X_i)$ and $d = \sum e^{X_i - m}$, obtained from the online softmax algorithm, to calculate $lse$ directly.

\begin{align}
lse &= \log \sum e^{X_i} \\
    &= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m \\
    &= \log\left( e^m \sum e^{X_i - m} \right) = m + \log d
\end{align}
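
A quick numerical check of this identity, accumulating the running max $m$ and the shifted sum $d$ chunk by chunk in the style of the online softmax (variable names are illustrative):

```python
import torch

x = torch.randn(4, 32000)

# running max m and shifted sum d, updated chunk by chunk (online softmax style)
m = torch.full((x.size(0),), float("-inf"))
d = torch.zeros(x.size(0))
for chunk in x.split(4096, dim=-1):
    m_new = torch.maximum(m, chunk.max(dim=-1).values)
    d = d * torch.exp(m - m_new) + torch.exp(chunk - m_new[:, None]).sum(dim=-1)
    m = m_new

lse = m + torch.log(d)  # lse = m + log d
assert torch.allclose(lse, torch.logsumexp(x, dim=-1), atol=1e-5)
```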

For gradients:

First, we calculate the derivative of $lse$:

\begin{align}
\frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum e^{x_i}\right) \\
    &= \frac{1}{\sum e^{x_i}} \cdot \frac{\partial}{\partial x_i} \sum e^{x_i} \\
    &= \frac{e^{x_i}}{\sum e^{x_i}} = softmax(x_i).
\end{align}

Then we can obtain the derivative of z_loss by the chain rule:

\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right)  = 2\cdot lse\_square\_scale \cdot lse \cdot  softmax(x_i),

and we have the derivative of the cross-entropy loss with label smoothing:

\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K} = \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}

where $\epsilon$ is `label_smoothing` and $K$ is the total number of classes. Thus, the derivative of the total loss is

\begin{align}
\frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss \\
    &= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2 \cdot lse\_square\_scale \cdot lse \cdot softmax(x_i) \\
    &= \begin{cases} (1 + 2 \cdot lse\_square\_scale \cdot lse)\, softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\
(1 + 2 \cdot lse\_square\_scale \cdot lse)\, softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
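
A small autograd cross-check of this closed form (plain PyTorch, sum reduction to keep the per-sample comparison direct; all names here are illustrative, not the kernel's):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, K = 4, 7
eps, scale = 0.1, 1e-4  # label_smoothing and lse_square_scale
x = torch.randn(N, K, requires_grad=True)
y = torch.randint(0, K, (N,))

# forward: label-smoothed CE + z loss, summed over the batch
lse = torch.logsumexp(x, dim=-1)
loss = F.cross_entropy(x, y, label_smoothing=eps, reduction="sum") + (scale * lse ** 2).sum()
loss.backward()

# analytic gradient from the cases above
softmax = torch.softmax(x.detach(), dim=-1)
grad = (1 + 2 * scale * lse.detach()[:, None]) * softmax - eps / K
grad[torch.arange(N), y] -= 1 - eps

assert torch.allclose(grad, x.grad, atol=1e-5)
```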

Reference

- PaLM: Scaling Language Modeling with Pathways
- Chameleon: Mixed-Modal Early-Fusion Foundation Models

Testing Done

Benchmark gist: negligible difference in the speed benchmark.

This benchmark was run on my own machine, so the numbers are probably not very accurate.

liger ce: 66.123 ms
Peak mem: 8.66200832

liger ce with z loss: 65.991 ms
Peak mem: 8.66200832

liger ce with z loss (returning z loss): 65.951 ms
Peak mem: 8.662073856
  • Hardware Type: <BLANK>
  • [x] run make test to ensure correctness
  • [x] run make checkstyle to ensure code style
  • [x] run make test-convergence to ensure convergence

Tcc0403 — Sep 10 '24, 07:09