Adan
Embedding tensors/weight update unsupported
Hello!
I think I found a bug in the Adan optimizer, which affects embedding tables.
I implemented the Adan optimizer in TensorFlow 2. You can find the implementation here.
I wanted to keep the implementation as close to the original code as possible. However, TensorFlow and PyTorch take different approaches to updating "sparse" tensors. An example of a "sparse" tensor is an embedding matrix. PyTorch treats "sparse" data as if it were dense, while TensorFlow has two methods for making updates: `_resource_apply_dense` for dense tensors and `_resource_apply_sparse` for "sparse" ones.
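To make the dense/"sparse" distinction concrete, here is a small NumPy sketch (not TF or PyTorch code) of the gradient of the "make all embeddings equal to 1" objective. The dense form is a full matrix with zero rows for categories absent from the batch, roughly what PyTorch produces by default; the sparse form keeps only (unique row indices, summed row gradients), similar in spirit to what TensorFlow hands to `_resource_apply_sparse`. The function names are hypothetical.

```python
import numpy as np

def dense_embedding_grad(table, batch_idx):
    """Gradient of sum_i ||table[batch_idx[i]] - 1||^2 as a full (dense) matrix."""
    grad = np.zeros_like(table)
    # np.add.at accumulates correctly when an index repeats within the batch
    np.add.at(grad, batch_idx, 2.0 * (table[batch_idx] - 1.0))
    return grad

def sparse_embedding_grad(table, batch_idx):
    """Same gradient as (unique row indices, per-row gradient values)."""
    unique_idx, inverse = np.unique(batch_idx, return_inverse=True)
    values = np.zeros((len(unique_idx), table.shape[1]))
    np.add.at(values, inverse, 2.0 * (table[batch_idx] - 1.0))
    return unique_idx, values

rng = np.random.default_rng(0)
table = rng.normal(size=(6, 4))   # 6 categories, embedding dim 4
batch = np.array([3, 4, 5, 3])    # category 3 appears twice

dense = dense_embedding_grad(table, batch)
idx, vals = sparse_embedding_grad(table, batch)
print(np.abs(dense[:3]).sum())         # rows 0-2 are absent from the batch: 0.0
print(np.allclose(dense[idx], vals))   # both forms carry the same information: True
```

The dense form silently includes zero rows for absent categories, so an optimizer that processes it row by row cannot tell "absent" from "gradient happens to be zero".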
I decided to test the correctness of my implementation using the following logic:
- Define a function to optimize. For "dense" optimization it's simple linear regression; for "sparse" optimization the goal is to make all embeddings equal to 1 (see `tf_adan/test_adan_*.py`).
- Generate random input data and an initial weights matrix.
- Optimize the weights matrix using the official implementation and mine. Both optimizers use the same hyperparameters.
- Compare the loss history and the weights after optimization. If they are equal, my implementation is correct.
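The comparison step can be written as a small helper. This is a sketch assuming both runs record a per-step loss list and return their final weights as NumPy arrays; `runs_match` and its tolerances are hypothetical choices, not part of either repository.

```python
import numpy as np

def runs_match(losses_a, losses_b, weights_a, weights_b, rtol=1e-5, atol=1e-6):
    """Return True if two optimizer runs agree on loss history and final weights."""
    losses_a, losses_b = np.asarray(losses_a), np.asarray(losses_b)
    if losses_a.shape != losses_b.shape:
        return False  # different number of recorded steps
    same_losses = np.allclose(losses_a, losses_b, rtol=rtol, atol=atol)
    same_weights = np.allclose(weights_a, weights_b, rtol=rtol, atol=atol)
    return bool(same_losses and same_weights)

# Usage: tiny rounding differences should not count as a mismatch
w = np.ones((3, 2))
print(runs_match([1.0, 0.5, 0.25], [1.0, 0.5000001, 0.25], w, w))  # True
print(runs_match([1.0, 0.5, 0.25], [1.0, 0.45, 0.2], w, w))        # False
```

A tolerance-based check is used instead of exact equality because two frameworks rarely produce bit-identical floating-point results.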
I noticed that the loss history and the weights after optimization are the same for dense parameters. However, for embedding parameters my implementation shows a better loss, and the weights after optimization are not the same. This is especially noticeable when a batch contains only a few of the possible categories, for example, when a categorical feature has 2k unique values while the batch size is 100.

I think the source of the bug is the following:
- For "new" gradients, i.e., gradients of categorical values that haven't been updated before, we replace the previous gradient with the current gradient. This logic is implemented here:
https://github.com/sail-sg/Adan/blob/d8646475859e98c2eeee5351b6647613bf5bebeb/adan.py#L130
As I understand it, `prev_grad` for all "new" gradients on step > 1 won't be replaced with the current gradient.
- The other reason is that the gradient params (`exp_avg`, `exp_avg_sq`, `exp_avg_diff`) are updated regardless of whether the category is present in the batch. That means that for categories
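The first point can be illustrated with dense gradients: if `pre_grad` is initialized only on step 1, a row that first appears later still has a zero `pre_grad`, so its difference term equals its full gradient instead of 0. Below is a NumPy sketch of one possible fix that tracks, per row, whether the row has been seen and initializes `pre_grad` for newly seen rows from their first gradient. The `seen` mask and `gradient_difference` are hypothetical and only illustrate the idea.

```python
import numpy as np

def gradient_difference(grad, pre_grad, seen, batch_rows):
    """Return (grad - pre_grad, updated seen mask), with diff 0 for rows seen for the first time.

    grad, pre_grad: dense (num_rows, dim) arrays; seen: (num_rows,) bool mask;
    batch_rows: integer indices of the rows present in the current batch.
    """
    new_rows = batch_rows[~seen[batch_rows]]
    pre_grad = pre_grad.copy()
    pre_grad[new_rows] = grad[new_rows]  # first appearance behaves like step 1
    seen = seen.copy()
    seen[batch_rows] = True
    return grad - pre_grad, seen

num_rows, dim = 6, 2
seen = np.zeros(num_rows, dtype=bool)
seen[[0, 1, 2]] = True                  # rows 0-2 were in the batch on step 1
pre_grad = np.zeros((num_rows, dim))    # rows 3-5 had zero gradient on step 1
pre_grad[[0, 1, 2]] = 0.5               # rows 0-2 had a nonzero gradient on step 1
grad = np.zeros((num_rows, dim))
grad[[3, 4, 5]] = 1.0                   # step 2: the batch contains rows 3-5 only

naive_diff = grad - pre_grad            # rows 3-5 get diff equal to their full gradient
diff, seen = gradient_difference(grad, pre_grad, seen, np.array([3, 4, 5]))
print(naive_diff[3:].max(), diff[3:].max())  # 1.0 0.0
```

Note that rows 0-2 still get a nonzero difference (-0.5) on step 2 even though they are absent, which is related to the second point.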
I'm not sure whether the bug is in your implementation or in mine. I also tested the Adam optimizer in TF and PyTorch, see:
https://github.com/DenisVorotyntsev/Adan/blob/02e66241a98958152315ae5358ee6f364f092f8b/tf_adan/utils.py#L37
The losses for Adam in TF and PyTorch are almost the same.
What do you think? Looking forward to your thoughts.
Hi @DenisVorotyntsev, thank you very much for your contribution and detailed exploration.
- In fact, we do not support sparse parameter updates right now. This is a good starting point for explaining the inconsistency.
- We do not quite understand what you mean by embedding params. Is it some kind of sparse parameter?
- Regarding the sources you mentioned: 1) we set `pre_grad = cur_grad` at the beginning of the parameter update, and that logic seems right. I noticed that your implementation follows the same logic, so do you have any further comments that could help us figure out where the problem is? 2) The description of the second source seems incomplete; we need more details about it.
Looking forward to your feedback. Best regards.
An embedding is a lookup table that maps indexes to tensors. It's commonly used in tabular tasks with categorical values. The implementations are here: torch, tf.
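An embedding lookup is row indexing into the table, which is equivalent to multiplying a one-hot matrix by the table. That equivalence also shows why only the looked-up rows receive a nonzero gradient. A minimal NumPy sketch; sizes and values are arbitrary:

```python
import numpy as np

num_categories, dim = 6, 3
table = np.arange(num_categories * dim, dtype=float).reshape(num_categories, dim)
batch = np.array([0, 1, 2])

lookup = table[batch]                     # embedding lookup = row indexing
one_hot = np.eye(num_categories)[batch]   # shape (batch_size, num_categories)
print(np.array_equal(lookup, one_hot @ table))  # True

# Gradient of sum(one_hot @ table) w.r.t. table is one_hot.T @ ones:
# rows 3-5 never appear in the batch, so their gradient is zero.
grad = one_hot.T @ np.ones((len(batch), dim))
print(grad[3:].sum())  # 0.0
```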
Embedding tensors are only updated for values that are present in the batch, and the same should hold for all gradient params (e.g., `exp_avg`). I made this example to illustrate my point: google colab
Here we train a simple NN with a single embedding layer. The embedding has 6 unique values (0, 1, 2, etc.). We train the model with batch size 3: first on the input [0, 1, 2], then on [3, 4, 5] repeated several times. Before each step we print the gradients and the optimizer state. Pay close attention to the gradients and gradient params (`exp_avg`) of indexes [0, 1, 2].
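The colab experiment can be approximated without a framework. This NumPy sketch assumes a simplified Adan-style dense update with betas (0.98, 0.92, 0.99) and no weight decay; `adan_dense_step` and `grad_for` are hypothetical names, and this is not the official code. As in the dense case described above, every row's state is updated on every step.

```python
import numpy as np

def adan_dense_step(param, grad, state, lr=0.1, betas=(0.98, 0.92, 0.99), eps=1e-8):
    """Simplified dense Adan-style step (no weight decay): every row's state is updated."""
    beta1, beta2, beta3 = betas
    state["step"] += 1
    t = state["step"]
    if t == 1:
        state["pre_grad"] = grad.copy()  # makes diff zero on the first step
    diff = grad - state["pre_grad"]
    state["exp_avg"] = beta1 * state["exp_avg"] + (1 - beta1) * grad
    state["exp_avg_diff"] = beta2 * state["exp_avg_diff"] + (1 - beta2) * diff
    update = grad + beta2 * diff
    state["exp_avg_sq"] = beta3 * state["exp_avg_sq"] + (1 - beta3) * update**2
    denom = np.sqrt(state["exp_avg_sq"] / (1 - beta3**t)) + eps
    direction = (state["exp_avg"] / (1 - beta1**t)
                 + beta2 * state["exp_avg_diff"] / (1 - beta2**t)) / denom
    state["pre_grad"] = grad.copy()
    return param - lr * direction

def grad_for(table, batch):
    """Gradient of sum ||table[i] - 1||^2 over the batch; rows not in the batch get 0."""
    g = np.zeros_like(table)
    np.add.at(g, batch, 2.0 * (table[batch] - 1.0))
    return g

rng = np.random.default_rng(0)
table = rng.normal(size=(6, 2))  # 6 categories, embedding dim 2
state = {"step": 0, **{k: np.zeros_like(table)
                       for k in ("pre_grad", "exp_avg", "exp_avg_diff", "exp_avg_sq")}}

table = adan_dense_step(table, grad_for(table, np.array([0, 1, 2])), state)
m_first, rows_first = state["exp_avg"][:3].copy(), table[:3].copy()
for _ in range(5):
    table = adan_dense_step(table, grad_for(table, np.array([3, 4, 5])), state)

print(np.allclose(state["exp_avg"][:3], m_first * 0.98**5))  # True: exp_avg decays for absent rows
print(np.array_equal(table[:3], rows_first))                 # False: absent rows keep moving
```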
We can see that:
- On any step, gradients are non-zero only for indexes of elements present in the batch.
- Gradient params are updated for all indexes. We can see `exp_avg` shrinking for indexes [0, 1, 2]. This behavior isn't expected. It negatively affects updates for rare categories, i.e., categories that aren't present in the batch. In the worst case, if a category is present in only one batch, its gradient params shrink toward 0 (`exp_avg = exp_avg * beta_1**n_steps`).
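The shrinking follows directly from the `exp_avg` recurrence, assuming the usual form `exp_avg = beta_1 * exp_avg + (1 - beta_1) * grad`. Here $m^{(r)}_t$ is row $r$ of `exp_avg` after step $t$ and $g^{(r)}_t$ its gradient; this notation is introduced only for the derivation. If the category is absent for $n$ consecutive steps, its gradient is zero on each of them:

```latex
m^{(r)}_t = \beta_1\, m^{(r)}_{t-1} + (1 - \beta_1)\, g^{(r)}_t,
\qquad
g^{(r)}_{t+1} = \dots = g^{(r)}_{t+n} = 0
\;\Longrightarrow\;
m^{(r)}_{t+n} = \beta_1^{\,n}\, m^{(r)}_t .
```

For example, assuming $\beta_1 = 0.98$, a category that is absent for 100 consecutive steps keeps only $0.98^{100} \approx 0.13$ of its `exp_avg`.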
OK, things seem clear now. The problem is that our implementation does not support embedding tensors yet: for them, we should only update the rows whose indexes appear in the current batch.
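A "lazy" variant that reads and writes only the rows present in the batch could look like the NumPy sketch below. It assumes the same simplified Adan-style update (betas (0.98, 0.92, 0.99), no weight decay) and keeps a per-row step counter for bias correction and `pre_grad` initialization; that counter is a design assumption, and other lazy optimizers may use a single global step instead. `adan_lazy_step` is a hypothetical name, not an API of either repository.

```python
import numpy as np

def adan_lazy_step(table, rows, row_grads, state, lr=0.1, betas=(0.98, 0.92, 0.99), eps=1e-8):
    """Adan-style step that reads and writes only `rows` (unique indices in the batch).

    `row_grads` has shape (len(rows), dim). A per-row step counter drives both the
    bias correction and the first-appearance pre_grad initialization.
    """
    beta1, beta2, beta3 = betas
    step = state["row_step"][rows] + 1
    first = (step == 1)[:, None]
    pre_grad = np.where(first, row_grads, state["pre_grad"][rows])  # diff = 0 on a row's first step
    diff = row_grads - pre_grad
    m = beta1 * state["exp_avg"][rows] + (1 - beta1) * row_grads
    d = beta2 * state["exp_avg_diff"][rows] + (1 - beta2) * diff
    update = row_grads + beta2 * diff
    v = beta3 * state["exp_avg_sq"][rows] + (1 - beta3) * update**2
    bc1, bc2, bc3 = (1 - b ** step[:, None] for b in betas)
    direction = (m / bc1 + beta2 * d / bc2) / (np.sqrt(v / bc3) + eps)
    # Write back only the selected rows; every other row keeps its state and value.
    state["row_step"][rows] = step
    state["pre_grad"][rows] = row_grads
    state["exp_avg"][rows] = m
    state["exp_avg_diff"][rows] = d
    state["exp_avg_sq"][rows] = v
    new_table = table.copy()
    new_table[rows] -= lr * direction
    return new_table

num_rows, dim = 6, 2
table = np.zeros((num_rows, dim))
state = {k: np.zeros((num_rows, dim)) for k in ("pre_grad", "exp_avg", "exp_avg_diff", "exp_avg_sq")}
state["row_step"] = np.zeros(num_rows, dtype=int)

rows = np.array([0, 1, 2])
table = adan_lazy_step(table, rows, 2.0 * (table[rows] - 1.0), state)
m_first, rows_first = state["exp_avg"][:3].copy(), table[:3].copy()
rows = np.array([3, 4, 5])
for _ in range(5):
    table = adan_lazy_step(table, rows, 2.0 * (table[rows] - 1.0), state)

print(np.array_equal(state["exp_avg"][:3], m_first))  # True: absent rows' state untouched
print(np.array_equal(table[:3], rows_first))          # True: absent rows' values untouched
print(state["row_step"])                              # [1 1 1 5 5 5]
```

With this approach, rows 0-2 keep both their optimizer state and their values while the batch contains only rows 3-5.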
If this is the case, could you give us some advice or references, for example, which PyTorch optimizer supports updating embedding parameters? We believe your implementation supports updating embedding parameters in TF, but our reference (most optimizers implemented in timm) does not seem to support embedding tensors. We may also figure out a solution ourselves.
Before the next step, we need your confirmation of the problem's source. Is our understanding right?
Again, thanks for your discussion and comments.
Yes, I think that's the correct understanding. I don't know how to fix it in PyTorch, but I'd first look at how it's implemented in the Adam optimizer (https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam). Torch Adam and TF Adam behave the same, so embedding updates must be correct there.
Latent vector (`torch.nn.Embedding`) optimization has had a few issues for several years now. Apparently this is because momentum/decay gets applied to all elements of an embedding object, even elements that were not used in the forward/backward pass.
See this bug: https://discuss.pytorch.org/t/issues-training-models-with-nn-embeddings-with-sparse-gratients/19189. DeepSDF bug #51 also describes a similar problem; the workaround there is to manually set the gradient of every element of the embedding vector to None after `opt.step` (I'm not sure whether `opt.zero_grad(set_to_none=True)` is enough). Glancing at Adan's code, this should work there as well, as https://github.com/sail-sg/Adan/blob/main/adan.py#L173 indicates.
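The workaround relies on the common optimizer pattern of skipping parameters whose gradient is `None`, so their state is neither decayed nor applied. A minimal, framework-free sketch of that pattern, with a plain momentum update standing in for Adan; `Param` and `momentum_step` are hypothetical names:

```python
from dataclasses import dataclass, field
from typing import Optional

import numpy as np

@dataclass
class Param:
    value: np.ndarray
    grad: Optional[np.ndarray] = None
    state: dict = field(default_factory=dict)

def momentum_step(params, lr=0.1, beta=0.9):
    """Toy optimizer loop (momentum standing in for Adan) that skips params without a grad."""
    for p in params:
        if p.grad is None:
            continue  # nothing to do: value and state stay exactly as they were
        buf = p.state.setdefault("exp_avg", np.zeros_like(p.value))
        buf *= beta
        buf += (1 - beta) * p.grad
        p.value -= lr * buf

used = Param(np.ones(2), grad=np.ones(2))
unused = Param(np.ones(2), grad=np.ones(2))
momentum_step([used, unused])  # both tensors have gradients on the first step
unused.grad = None             # e.g. grads set to None and no backward pass through it
before = unused.state["exp_avg"].copy()
momentum_step([used, unused])
print(np.array_equal(unused.state["exp_avg"], before))  # True
```

Note that this works at the granularity of whole tensors: if any row of a dense embedding is used in the batch, the whole table receives a gradient and every row's state is still updated.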