
Suboptimal implementation of FusedAdam: two unnecessary divisions

Open jxtps opened this issue 3 years ago • 0 comments

The code in the main for-loop of https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu#L92 (ignoring the weight decay portion):

```cpp
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
```

performs two unnecessary per-element divisions, `r_m[ii] / beta1_correction` and `r_v[ii] / beta2_correction`. The Adam paper explains how these two divisions can be factored out into the learning rate.
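As a sketch of that factoring, assume `beta1_correction` and `beta2_correction` hold the standard Adam bias corrections $c_1 = 1 - \beta_1^t$ and $c_2 = 1 - \beta_2^t$ at step $t$. The symbols are shorthand introduced here, with $m_t$ = `r_m[ii]`, $v_t$ = `r_v[ii]`, $\alpha$ = `lr`, and $\theta$ = `r_p`. Dropping $\epsilon$ for the moment:

```latex
\hat m_t = \frac{m_t}{c_1}, \qquad \hat v_t = \frac{v_t}{c_2}

\frac{\hat m_t}{\sqrt{\hat v_t}}
  = \frac{m_t / c_1}{\sqrt{v_t} / \sqrt{c_2}}
  = \frac{\sqrt{c_2}}{c_1} \cdot \frac{m_t}{\sqrt{v_t}}

\theta_t = \theta_{t-1} - \alpha_t \frac{m_t}{\sqrt{v_t}},
  \qquad \alpha_t = \alpha \, \frac{\sqrt{c_2}}{c_1}
```

Because $\alpha_t$ is the same scalar for every element, it can be computed once per step, which removes both per-element divisions.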

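You also need to adjust epsilon, which the paper doesn't spell out. Multiply eps by sqrt(beta2_correction)/sqrt(beta2_correction) (i.e. by 1), then factor 1/sqrt(beta2_correction) out of the denominator. That factor moves into the learning rate, and epsilon is left scaled by sqrt(beta2_correction). This would produce something like: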

```cpp
// Pre-step: computed once per optimizer step, not per element.
MATH_T beta2_correction_sqrt = sqrtf(beta2_correction);
epsilon *= beta2_correction_sqrt;
lr *= beta2_correction_sqrt / beta1_correction;
```

followed by a simplified for-loop:

```cpp
r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
MATH_T denom = sqrtf(r_v[ii]) + epsilon;
MATH_T update = r_m[ii] / denom;
r_p[ii] = r_p[ii] - (lr * update);
```

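(The pre-step can of course be done on the host before dispatching, with the adjusted `lr` and `epsilon` passed to the AdamFunctor. The functor can then shed a few members, since `beta1_correction` and `beta2_correction` are no longer needed inside the loop.)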
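A quick host-side check that the two loops agree. With epsilon included, sqrt(r_v / beta2_correction) + epsilon equals (sqrt(r_v) + epsilon * sqrt(beta2_correction)) / sqrt(beta2_correction). The update therefore becomes (sqrt(beta2_correction) / beta1_correction) * r_m / (sqrt(r_v) + epsilon * sqrt(beta2_correction)), which is exactly the rescaled `lr` and `epsilon` of the pre-step. This is a sketch in plain C++ rather than CUDA, under several assumptions:

* `AdamState`, `step_reference`, `step_folded`, and `max_divergence` are hypothetical names.
* `MATH_T` is assumed to be `float`.
* The bias corrections are assumed to be `1 - beta^t`.
* Weight decay is omitted, as in the issue.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using MATH_T = float;  // assumption: fp32 math, matching sqrtf in the kernel

// Hypothetical host-side stand-in for r_p, r_m, r_v.
struct AdamState {
  std::vector<MATH_T> p, m, v;
};

// Original loop: divides each moment by its bias correction, per element.
void step_reference(AdamState& s, const std::vector<MATH_T>& g, MATH_T lr,
                    MATH_T beta1, MATH_T beta2, MATH_T epsilon, int t) {
  assert(g.size() == s.p.size());
  const MATH_T beta1_correction = 1 - std::pow(beta1, t);  // assumed 1 - beta1^t
  const MATH_T beta2_correction = 1 - std::pow(beta2, t);  // assumed 1 - beta2^t
  for (std::size_t ii = 0; ii < g.size(); ++ii) {
    s.m[ii] = beta1 * s.m[ii] + (1 - beta1) * g[ii];
    s.v[ii] = beta2 * s.v[ii] + (1 - beta2) * g[ii] * g[ii];
    MATH_T next_m_unbiased = s.m[ii] / beta1_correction;
    MATH_T next_v_unbiased = s.v[ii] / beta2_correction;
    MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
    s.p[ii] -= lr * (next_m_unbiased / denom);
  }
}

// Proposed loop: fold both corrections into lr and epsilon once per step.
void step_folded(AdamState& s, const std::vector<MATH_T>& g, MATH_T lr,
                 MATH_T beta1, MATH_T beta2, MATH_T epsilon, int t) {
  assert(g.size() == s.p.size());
  const MATH_T beta1_correction = 1 - std::pow(beta1, t);
  const MATH_T beta2_correction = 1 - std::pow(beta2, t);
  // Pre-step (would run on the host before dispatch).
  const MATH_T beta2_correction_sqrt = std::sqrt(beta2_correction);
  epsilon *= beta2_correction_sqrt;
  lr *= beta2_correction_sqrt / beta1_correction;
  for (std::size_t ii = 0; ii < g.size(); ++ii) {
    s.m[ii] = beta1 * s.m[ii] + (1 - beta1) * g[ii];
    s.v[ii] = beta2 * s.v[ii] + (1 - beta2) * g[ii] * g[ii];
    MATH_T denom = std::sqrt(s.v[ii]) + epsilon;
    s.p[ii] -= lr * (s.m[ii] / denom);
  }
}

// Runs both variants on identical synthetic gradients and returns the
// largest absolute difference between the resulting parameters.
MATH_T max_divergence(int steps) {
  const std::size_t n = 8;
  AdamState a{std::vector<MATH_T>(n, 1.0f), std::vector<MATH_T>(n, 0.0f),
              std::vector<MATH_T>(n, 0.0f)};
  AdamState b = a;
  for (int t = 1; t <= steps; ++t) {
    std::vector<MATH_T> g(n);
    for (std::size_t i = 0; i < n; ++i)  // deterministic, nonzero gradients
      g[i] = std::sin(0.7f * static_cast<MATH_T>(i + t)) * 0.1f;
    step_reference(a, g, 1e-3f, 0.9f, 0.999f, 1e-8f, t);
    step_folded(b, g, 1e-3f, 0.9f, 0.999f, 1e-8f, t);
  }
  MATH_T worst = 0;
  for (std::size_t i = 0; i < n; ++i)
    worst = std::max(worst, std::fabs(a.p[i] - b.p[i]));
  return worst;
}
```

`max_divergence(20)` should return a small value on the order of float rounding error. It is not necessarily exactly zero, since the two forms round differently.

One caveat outside the issue's scope concerns decoupled (AdamW-style) weight decay. If the decay is applied as `lr * (update + decay * p)`, folding the factor into `lr` also scales the decay term, so `decay` would presumably need to be divided by the same factor.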

jxtps, Jun 13 '22 22:06