
[Preference Comparison] L2 regularization with dynamic regularization coefficient

Open AdamGleave opened this issue 3 years ago • 5 comments

The DRLHP paper states in section 2.2.3 that:

A fraction of 1/e of the data is held out to be used as a validation set for each predictor. We use ℓ2 regularization and adjust the regularization coefficient to keep the validation loss between 1.1 and 1.5 times the training loss. In some domains we also apply dropout for regularization.

A similar approach was used in follow-up work by Ibarz et al, described in a bit more detail in section A.3:

A fraction of 1/e of the data is held out to be used as a validation set. We use L2-regularization of network weights with the adaptive scheme described in Christiano et al. (2017): the L2-regularization weight increases if the average validation loss is more than 50% higher than the average training loss, and decreases if it is less than 10% higher (initial weight 0.0001, multiplicative rate of change 0.001 per learning step).
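For concreteness, the rule quoted above can be sketched as a small pure function. This is a sketch under assumptions: the function name and signature are hypothetical, and it reads "multiplicative rate of change 0.001" as multiplying by (1 ± 0.001), which is an interpretation rather than something the paper spells out.

```python
def ibarz_l2_update(
    coef: float,
    train_loss: float,
    val_loss: float,
    rate: float = 0.001,
    lower: float = 1.1,
    upper: float = 1.5,
) -> float:
    """Return the next L2 weight under the adaptive rule quoted from Ibarz et al.

    Assumption: "multiplicative rate of change 0.001" means multiply by
    (1 + rate) to increase the weight and by (1 - rate) to decrease it.
    """
    ratio = val_loss / train_loss  # assumes train_loss > 0
    if ratio > upper:  # validation loss more than 50% above training loss
        return coef * (1 + rate)
    if ratio < lower:  # validation loss less than 10% above training loss
        return coef * (1 - rate)
    return coef  # inside the band: leave the weight unchanged


coef = 1e-4  # initial weight from the quoted text
coef = ibarz_l2_update(coef, train_loss=0.5, val_loss=0.9)  # ratio 1.8 > 1.5: increase
```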

Currently preference comparisons does not support any form of adaptive L2 regularization, although it does support a constant weight decay in CrossEntropyRewardTrainer. This has a similar effect to constant L2 regularization.

We should add support since it shows a modest performance improvement in ablations (e.g. section 3.3 of the DRLHP paper showed significant improvements in Swimmer, Qbert and Seaquest, though little difference in other environments). Moreover, as it was part of the original paper, it is good to support it so that we can replicate the paper and perform our own ablations. Furthermore, I suspect regularization may also improve the robustness of learned reward functions, aiding transfer, although they don't evaluate that.

Concretely, I think the tasks would be:

  • Support L2 coefficient (or weight decay coefficient) being a variable rather than a constant.
  • Add train/validation split of dataset.
  • Compute validation loss.
  • Add logic to dynamically adjust the L2 coefficient based on validation loss.

AdamGleave, Jul 07 '22 19:07

Thanks for flagging this! I very much support adding this feature.

yawen-d, Jul 09 '22 23:07

Hi both, I have some questions regarding the technical details of the implementation. @AdamGleave @yawen-d (re-posted as I did some significant re-writing after I thought about it for longer).

  • Since L2 regularization is not the same as weight decay, should we implement L2 penalty or a weight decay?
  • If L2 regularization, should this just be a penalty to the cross-entropy loss? This is the most obvious approach, but want to make sure.
  • Presumably there's already a test/train split that is done when training models on preference_comparisons. A bit more context regarding the workflow/pipeline for these models would be valuable so that I can use a similar method for the train/validation split.
  • What logic exactly are you looking for? (see below)

It is not super clear what the "multiplicative rate of change" means in the Ibarz et al paper. I presume that they update the penalty coefficient to $\lambda_t = \lambda_{t-1} \cdot (1 \pm 0.001)$ depending on whether the validation/training loss ratio is above or below the 10-50% range. However, this change seems too small: one is normally uncertain about the order of magnitude of the optimal penalty, and each update only moves it by about $\log_{10} 1.001 \approx 4.3 \times 10^{-4}$ orders of magnitude, so maybe they mean something else.
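Under that reading, the number of consecutive same-direction updates needed to move $\lambda$ by one order of magnitude works out as follows (a back-of-the-envelope check, not something stated in either paper):

$$
n = \frac{\ln 10}{\ln 1.001} \approx \frac{2.3026}{0.0009995} \approx 2304
$$

So whether a rate of 0.001 is reasonable depends heavily on how often an update is applied.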

Further, I am not convinced that the increase/decrease rate parameter should be constant during the training process. Intuitively, a higher coefficient will bring both losses closer together, as the "importance" of the training examples is overpowered by the L2 norm of the weights. However, the update should probably be relative to how different the losses are from each other. Otherwise, a single rate parameter cannot handle two opposing edge cases: either losses that sit right at the boundary will switch between cycles of under- and over-regularization, causing a lot of variance and slowing down training, or losses that are much too close or too far apart will take very long to reach the right level of regularization.
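One way to make the step size depend on how far the losses are from the target band is to scale the multiplicative update by the (log) distance from the band. This is a hypothetical sketch of the idea above, not something either paper describes; `proportional_l2_update` and the `gain` knob are assumed names.

```python
import math


def proportional_l2_update(
    coef: float,
    train_loss: float,
    val_loss: float,
    gain: float = 0.1,
    lower: float = 1.1,
    upper: float = 1.5,
) -> float:
    """Hypothetical rule: the further the loss ratio is outside [lower, upper],
    the larger the multiplicative change to the coefficient."""
    ratio = val_loss / train_loss  # assumes train_loss > 0
    if ratio > upper:
        error = math.log(ratio / upper)  # positive: overfitting, increase coef
    elif ratio < lower:
        error = math.log(ratio / lower)  # negative: over-regularized, decrease coef
    else:
        error = 0.0  # inside the band: no change
    # Outside the band this equals coef * (ratio / bound) ** gain.
    return coef * math.exp(gain * error)
```

Because the step shrinks smoothly to zero at the band edges, a ratio just outside the band gets only a small correction (which should damp the oscillation described above), while a ratio far outside gets a large one.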

On a more theoretical note, I am generally uncertain about the mathematical background behind dynamic regularization. If you have any resources that discuss this clearly, I would like to take a look. By applying this algorithm, we are biasing the training direction using validation data, thus leaking information from the validation set into the training process. Normally, the model's validation loss is an unbiased estimate of its performance during training. But in every training step we are already leaking information, and it is not clear to me that bringing the penalty to zero will "erase the memory" if we have already learned/overfitted on the validation set.

Rocamonde, Jul 20 '22 22:07

Since L2 regularization is not the same as weight decay, should we implement L2 penalty or a weight decay?

Good question, unfortunately the answer seems a bit unclear.

The decoupled weight decay regularization paper claims (page 1) that "L2 regularization is not effective in Adam", but "weight decay is equally effective in both SGD and Adam". This would advocate in favor of using weight decay.

However, this seems tricky from an implementation perspective. PyTorch's AdamW takes weight_decay as a float hyperparameter when the optimizer is constructed, so there doesn't seem to be a dedicated API for changing the weight decay over time.
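A possible workaround, not raised in the thread: PyTorch optimizers keep their hyperparameters in `optimizer.param_groups`, a list of plain dicts (this is how learning-rate schedulers adjust `lr`). In current PyTorch releases AdamW's `step()` reads `group["weight_decay"]` on each call, so overwriting it between steps should take effect; that is worth re-checking for the PyTorch version in use. `set_weight_decay` is a hypothetical helper, and the large `lr` and decay values are only there to make the effect visible.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
# Large lr / weight_decay only so the effect is easy to see.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=0.0)


def set_weight_decay(optimizer: torch.optim.Optimizer, value: float) -> None:
    """Hypothetical helper: overwrite weight_decay in every param group."""
    for group in optimizer.param_groups:
        group["weight_decay"] = value


set_weight_decay(optimizer, 0.5)

# With an all-zero gradient, Adam's moment estimates stay zero, so the only
# effect of this step is AdamW's decoupled decay: w <- w * (1 - lr * wd).
for p in model.parameters():
    p.grad = torch.zeros_like(p)
before = model.weight.detach().clone()
optimizer.step()
```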

Now I don't think it'd be that hard to hack up our own weight decay: we'd basically just decay the weights in between loss.backward() and self.optim.step(), similar to the method discussed by fast.ai under "Implementing AdamW", which is a pretty close match to Algorithm 2 from the decoupled weight decay paper.
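The hand-rolled variant could look roughly like this. It is a minimal sketch, assuming plain `torch.optim.Adam` with its built-in `weight_decay` left at 0; `current_weight_decay()` is a hypothetical hook standing in for whatever supplies the adaptive coefficient. The weights are multiplied by `1 - lr * wd` before the Adam update, which is the same ordering PyTorch's AdamW uses.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
# Plain Adam with no built-in decay; the decay is applied by hand below.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def current_weight_decay() -> float:
    # Hypothetical hook: in practice this would return the adaptive coefficient.
    return 1e-2


x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Decoupled weight decay: shrink the weights (not the gradients) by
    # lr * wd, in between loss.backward() and optimizer.step().
    wd = current_weight_decay()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group["params"]:
                p.mul_(1 - group["lr"] * wd)
    optimizer.step()
```

Since the coefficient is read fresh on every step, it can change between steps without touching the optimizer's configuration.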

Given the purported benefits of weight decay over L2 with Adam (and their equivalence in SGD) and a dynamic version being implementable in ~5 lines of code, I think I lean towards this hand-rolled method. But it definitely feels a bit hacky... and I suspect prior work just used L2 regularization, so we have some evidence that works at least tolerably well. What do you think?

If L2 regularization, should this just be a penalty to the cross-entropy loss? This is the most obvious approach, but want to make sure.

If we went for L2 regularization then yeah just add it as a penalty, although we probably want the computation to not be entangled with cross-entropy loss so we can add it to alternative loss functions. Note https://github.com/HumanCompatibleAI/imitation/pull/460 is introducing a new abstract class representing loss https://github.com/HumanCompatibleAI/imitation/blob/fcee608c520bb64847c229de1a8da6bb124f8389/src/imitation/algorithms/preference_comparisons.py#L664 so we're no longer going to be hard-coded to cross entropy.
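Keeping the penalty separate from the loss could be as simple as a function of the module's parameters that any loss implementation adds on. This is a sketch: `l2_penalty` is a hypothetical name, and it penalizes every parameter including biases, which is a design choice (some implementations exclude biases).

```python
import torch
from torch import nn


def l2_penalty(module: nn.Module) -> torch.Tensor:
    """Sum of squared parameter values, independent of the data loss used."""
    return sum(p.pow(2).sum() for p in module.parameters())


model = nn.Linear(3, 1)
with torch.no_grad():  # fixed weights so the numbers are easy to follow
    model.weight.fill_(1.0)  # 3 weights, each contributing 1.0
    model.bias.fill_(2.0)  # 1 bias contributing 4.0

l2_coef = 1e-4  # would be supplied by the adaptive scheme
data_loss = torch.tensor(0.5)  # stand-in for, e.g., the cross-entropy loss
total_loss = data_loss + l2_coef * l2_penalty(model)  # 0.5 + 1e-4 * 7.0
```

Keeping `data_loss` separate also means the train/validation comparison can use the unpenalized loss.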

Presumably there's already a test/train split that is done when training models on preference_comparisons. A bit more context regarding the workflow/pipeline for these models would be valuable so that I can use a similar method for the train/validation split.

I don't think we have a train/test split right now unfortunately. Note we currently only support synthetic comparisons (based on ground-truth reward function), and we can just sample more of those any time we want, so there's no need to have a dedicated test dataset. That said, we'd definitely want this in the future when working with human data. And it seems useful even in the synthetic case so we can check for overfitting.
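A train/validation split matching the papers' hold-out fraction of 1/e could be sketched as below. The function name and the shuffle-then-slice approach are assumptions; with synthetic comparisons one could instead just sample a fresh validation batch.

```python
import math
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_val_split(
    items: Sequence[T], val_fraction: float = 1 / math.e, seed: int = 0
) -> Tuple[List[T], List[T]]:
    """Randomly hold out `val_fraction` of `items` (default 1/e, about 36.8%)."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = round(len(items) * val_fraction)
    val = [items[i] for i in indices[:n_val]]
    train = [items[i] for i in indices[n_val:]]
    return train, val


train, val = train_val_split(list(range(100)))  # 37 validation, 63 training
```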

What logic exactly are you looking for? (see below)

I think your interpretation of the Ibarz et al paper is correct. My main uncertainty is what they mean by "learning step". They say in A.4 they train for 500 iterations, with each iteration consisting of training for 6250 batches. If by learning step they're referring to each gradient step, then there are 6250*500=3125000 learning steps, which would be plenty for even a 1.001 multiplicative rate to make many orders of magnitude of difference. But computing the training and validation loss every batch seems like a lot of overhead. If instead a learning step is one iteration, then 1.001**500≈1.648, which is way too small. My best guess is they kept a moving average of train/validation loss and kept doing the multiplicative update while the moving average was above/below threshold, which gets you a fast enough cumulative update without massive overheads. But this is pure conjecture...
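The moving-average conjecture (cheap smoothed losses, with the multiplicative update applied every gradient step) can be sketched as follows. Everything here is hypothetical: the class name, the EMA decay of 0.99, and the assumption that validation loss is refreshed only occasionally (e.g. once per iteration) while the coefficient is nudged on every batch.

```python
from typing import Optional


class EMAAdaptiveL2:
    """Hypothetical: nudge the L2 weight every step based on smoothed losses."""

    def __init__(self, coef: float = 1e-4, rate: float = 0.001,
                 ema_decay: float = 0.99, lower: float = 1.1, upper: float = 1.5) -> None:
        self.coef = coef
        self.rate = rate
        self.ema_decay = ema_decay
        self.lower, self.upper = lower, upper
        self.train_ema: Optional[float] = None
        self.val_ema: Optional[float] = None

    def _smooth(self, old: Optional[float], new: float) -> float:
        # The first observation initializes the average.
        return new if old is None else self.ema_decay * old + (1 - self.ema_decay) * new

    def observe(self, train_loss: Optional[float] = None,
                val_loss: Optional[float] = None) -> None:
        """Record losses whenever they are computed (possibly at different rates)."""
        if train_loss is not None:
            self.train_ema = self._smooth(self.train_ema, train_loss)
        if val_loss is not None:
            self.val_ema = self._smooth(self.val_ema, val_loss)

    def step(self) -> float:
        """Call once per gradient step; applies the (1 +/- rate) update."""
        if self.train_ema is not None and self.val_ema is not None:
            ratio = self.val_ema / self.train_ema
            if ratio > self.upper:
                self.coef *= 1 + self.rate
            elif ratio < self.lower:
                self.coef *= 1 - self.rate
        return self.coef


reg = EMAAdaptiveL2()
reg.observe(train_loss=1.0, val_loss=2.0)  # e.g. measured once per iteration
for _ in range(6250):  # one iteration's worth of batches in Ibarz et al.
    reg.step()
# reg.coef is now 1e-4 * 1.001**6250, roughly 5e-2 (about 2.7 orders of magnitude up)
```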

Frustratingly, https://github.com/HumanCompatibleAI/learning-from-human-preferences/blob/3b6b645ace2741b879a20e757f6f2ee4c672046c/README.md does not implement L2 regularization, and as far as I can tell neither does the "reference" implementation: https://github.com/nottombrown/rl-teacher/blob/b2c2201e9d2457b13185424a19da7209364f23df/rl_teacher/teach.py#L123

Given this I think we should just do what seems sensible to us and not stick too rigidly to any prior work, and accept we'll need to tune it in experiments.

Further, I am not convinced that the increase/decrease rate parameter should be constant during the training process. Intuitively, a higher coefficient will bring both losses closer together, as the "importance" of the training examples is overpowered by the L2 norm of the weights.

Good point. We should probably compare the training and validation cross-entropy loss (without the L2 penalty) to adjust the coefficient. What do you think?

On a more theoretical note, I am generally uncertain about the mathematical background behind dynamic regularization. If you have any resources that discuss this clearly, I would like to take a look. By applying this algorithm, we are biasing the training direction using validation data, thus leaking information from the validation set into the training process. Normally, the model's validation loss is an unbiased estimate of its performance during training. But in every training step we are already leaking information, and it is not clear to me that bringing the penalty to zero will "erase the memory" if we have already learned/overfitted on the validation set.

Cross-validation methods in general certainly leak info from the validation set during training, but I'm not sure I see the problem with this: the test set is there to give us an unbiased estimate of performance, and we can then judge how much overfitting has occurred. One way of looking at it is that adaptive regularization may reduce overfitting to the training set at the cost of a small amount of overfitting to the validation set. This may be a worthwhile tradeoff, and we can evaluate whether or not it's beneficial by looking at the test set.

Hope this helps, happy to discuss this more!

AdamGleave, Jul 21 '22 21:07

Hi @AdamGleave , thanks for taking the time to reply to my points above. What do you think about the following approach:

  1. If we only plan to support Adam as an optimizer, we can write a custom optimizer class that wraps Adam and 'cleans up' the hackiness and separates the logic of the adaptive weight decay / L2 (we could support both) from the actual usage of this optimizer in any given network. Even if we plan to support multiple optimizers, we could abstract this logic down the line, prioritizing getting Adam to work at the moment.
  2. The above custom optimizer class could require a function to be passed that takes in the train and validation losses, and returns the scaling factor to apply to the L2/weight decay param. A basic implementation could be a fixed scaling if the val/train losses fall in some range (the Ibarz implementation). More complex implementations could adjust the scaling factor as an arbitrary function of the losses. These functions could be parametrized, and the hyperparameters can be tuned in the regular way. I can come up with one or two stock parametrizations, and we can observe the results. Stock functions could be added or removed later down the line.

The above makes regularization modular and transferrable to other problems, and easy to switch between different approaches/implementations. We can create a new module somewhere in the project (any suggestions?)
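A rough sketch of the proposed wrapper, assuming Adam as the underlying optimizer and a hand-rolled decoupled weight decay. The class name `AdaptiveWeightDecayAdam`, the `scale_fn` signature `(train_loss, val_loss) -> factor`, and the band-based stock function `band_scale` are all hypothetical.

```python
from typing import Callable, Iterable

import torch

# (train_loss, val_loss) -> multiplicative factor for the weight decay
ScaleFn = Callable[[float, float], float]


def band_scale(train_loss: float, val_loss: float, rate: float = 0.001) -> float:
    """Hypothetical stock scale function: fixed step outside a 1.1-1.5 ratio band."""
    ratio = val_loss / train_loss
    if ratio > 1.5:
        return 1 + rate
    if ratio < 1.1:
        return 1 - rate
    return 1.0


class AdaptiveWeightDecayAdam:
    """Hypothetical wrapper: Adam plus an adjustable decoupled weight decay."""

    def __init__(self, params: Iterable[torch.nn.Parameter], scale_fn: ScaleFn,
                 lr: float = 1e-3, weight_decay: float = 1e-4) -> None:
        self.optim = torch.optim.Adam(params, lr=lr)  # built-in decay left at 0
        self.scale_fn = scale_fn
        self.weight_decay = weight_decay

    def update_weight_decay(self, train_loss: float, val_loss: float) -> None:
        self.weight_decay *= self.scale_fn(train_loss, val_loss)

    def zero_grad(self) -> None:
        self.optim.zero_grad()

    @torch.no_grad()
    def step(self) -> None:
        # Shrink the weights by lr * weight_decay, then take the Adam step.
        for group in self.optim.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.mul_(1 - group["lr"] * self.weight_decay)
        self.optim.step()


model = torch.nn.Linear(4, 1)
opt = AdaptiveWeightDecayAdam(model.parameters(), scale_fn=band_scale)
opt.zero_grad()
model(torch.randn(2, 4)).sum().backward()
opt.step()
opt.update_weight_decay(train_loss=1.0, val_loss=2.0)  # ratio 2.0 > 1.5: increase
```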

Let me know if you roughly agree with the above, and I can get started with some basic skeleton of the approach. We can refine it as we go along.

Small technical question: could you advise if it is preferable to update custom tensors (i.e. the L2 coef) in-place or not in-place?

Best wishes

Rocamonde, Jul 23 '22 17:07

  1. If we only plan to support Adam as an optimizer, we can write a custom optimizer class that wraps Adam and 'cleans up' the hackiness and separates the logic of the adaptive weight decay / L2 (we could support both) from the actual usage of this optimizer in any given network. Even if we plan to support multiple optimizers, we could abstract this logic down the line, prioritizing getting Adam to work at the moment.

I like the idea of putting the variable weight decay logic somewhere outside the preference comparison trainer. It's not specific to preference comparisons, after all.

I'm fine with it being a wrapper for Adam, but I think the method I outlined is fairly agnostic as to the underlying optimizer. If that turns out to be true when implementing it, then perhaps a mix-in is the best method? I don't care too much though -- fairly easy to abstract it later if needed.

  2. The above custom optimizer class could require a function to be passed that takes in the train and validation losses, and returns the scaling factor to apply to the L2/weight decay param. A basic implementation could be a fixed scaling if the val/train losses fall in some range (the Ibarz implementation). More complex implementations could adjust the scaling factor as an arbitrary function of the losses. These functions could be parametrized, and the hyperparameters can be tuned in the regular way. I can come up with one or two stock parametrizations, and we can observe the results. Stock functions could be added or removed later down the line.

Where does the train and validation loss come from? That's not something the optimizer necessarily has access to (especially validation loss). I'd lean towards just taking in a zero-argument callable that returns the current scaling factor. This zero-argument callable could be e.g. a class instance that has access to the latest computed train/validation loss. This also supports other use cases, like annealing the decay over time.
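The zero-argument-callable design might look like this: the training algorithm reports its latest losses to an object, and the optimizer only ever calls it with no arguments. `AdaptiveCoefficient` and `annealed` are hypothetical names, and the ratio band reuses the 1.1/1.5 thresholds from the papers.

```python
import itertools
from typing import Callable


class AdaptiveCoefficient:
    """Hypothetical zero-argument callable holding the current decay coefficient."""

    def __init__(self, initial: float = 1e-4, rate: float = 0.001,
                 lower: float = 1.1, upper: float = 1.5) -> None:
        self.value = initial
        self.rate, self.lower, self.upper = rate, lower, upper

    def report_losses(self, train_loss: float, val_loss: float) -> None:
        """Called by the training algorithm whenever it has fresh losses."""
        ratio = val_loss / train_loss
        if ratio > self.upper:
            self.value *= 1 + self.rate
        elif ratio < self.lower:
            self.value *= 1 - self.rate

    def __call__(self) -> float:
        return self.value


def annealed(initial: float, factor: float) -> Callable[[], float]:
    """Hypothetical alternative: geometric annealing, advancing one step per call."""
    steps = itertools.count()
    return lambda: initial * factor ** next(steps)


coef = AdaptiveCoefficient()
coef.report_losses(train_loss=1.0, val_loss=1.0)  # ratio 1.0 < 1.1: decrease
schedule = annealed(1.0, 0.5)
```

Both objects satisfy the same `Callable[[], float]` interface, so the optimizer side does not need to know which one it was given.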

Alternatively it could just take a PyTorch variable that we update, but that's probably less flexible.

The above makes regularization modular and transferrable to other problems, and easy to switch between different approaches/implementations. We can create a new module somewhere in the project (any suggestions?)

I like the idea of making adaptive regularization something we can easily apply to other problems, so I think it makes sense to try and factor that logic into a separate module. Computing the train/validation loss is something that may end up being a bit algorithm-specific, so we'll probably just have to feed those into the adaptive regularization component.

Let me know if you roughly agree with the above, and I can get started with some basic skeleton of the approach. We can refine it as we go along.

I think at the high-level we're on the same page, and my suggestions may well just be misunderstandings, so by all means start on a skeleton.

Small technical question: could you advise if it is preferable to update custom tensors (i.e. the L2 coef) in-place or not in-place?

I'm not sure it even needs to be a tensor. PyTorch doesn't build a static computation graph (the graph is recorded as operations run), so we can have get_l2_coef() return a float and multiply the L2 penalty by it on the fly. But if we do have a PyTorch variable storing the L2 coef, I think it makes sense to update it in place -- that's what variables are there for.
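To illustrate the two options: since PyTorch records the graph as operations run, a plain float read fresh each step works fine; and if the coefficient does live in a tensor, an in-place update keeps every reference to it in sync. `get_l2_coef` is a hypothetical helper.

```python
import torch

model = torch.nn.Linear(3, 1)
l2_coef = 1e-4  # plain Python float


def get_l2_coef() -> float:
    # Hypothetical accessor; could equally return an adaptive value.
    return l2_coef


data_loss = model(torch.randn(5, 3)).pow(2).mean()
penalty = sum(p.pow(2).sum() for p in model.parameters())
# The graph is built on the fly, so a float read at this point is all we need.
loss = data_loss + get_l2_coef() * penalty
loss.backward()

# If the coefficient is stored as a tensor instead, update it in place
# (outside autograd) so every holder of the reference sees the new value.
l2_coef_t = torch.tensor(1e-4)
with torch.no_grad():
    l2_coef_t.mul_(1.001)
```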

AdamGleave, Jul 24 '22 03:07