
L2 weight decay implementation in TF benchmarks

balancap opened this issue on Dec 21 '17 · 4 comments

The existing implementation applies L2 regularization to all model variables (including batch norm variables and biases). This is quite different from the TF-Slim models, which usually regularize only the conv2d weights, and including all variables hurts training quite a lot.

Maybe we should implement the regularization loss using the common TF API. Something like:

weight_decay = self.params.weight_decay
if rel_device_num == 0 and weight_decay:
    # Regularization losses in the present name scope.
    nm_sc = tf.contrib.framework.get_name_scope()
    reg_losses = tf.losses.get_regularization_losses(nm_sc)
    # TODO: fp16 conversion???
    reg_loss = tf.add_n(reg_losses, name='total_regularization_loss')
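
For comparison, a minimal TF 1.x sketch of the Slim-style alternative: attach the regularizer only to the conv2d kernels, so biases and batch-norm variables never reach the collection that the snippet above sums. The weight_decay value and placeholder shape are purely illustrative:

import tensorflow as tf

weight_decay = 1e-4  # illustrative value, not the benchmark's setting
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])

# The regularizer is attached to the kernel only, so the bias (and any
# batch-norm variables elsewhere) add nothing to the regularization losses.
regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
conv = tf.layers.conv2d(inputs, filters=64, kernel_size=3,
                        kernel_regularizer=regularizer)

# Summing the collection then works exactly as in the snippet above.
reg_loss = tf.add_n(tf.losses.get_regularization_losses(),
                    name='total_regularization_loss')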

More generally, I think it would add a lot of value if this benchmark repo could actually reproduce SOTA training on ImageNet for common architectures.

balancap · Dec 21 '17

I agree. In general, we definitely want to reproduce SOTA training for several common architectures, especially ResNet. Currently, we can train ResNet-50 v1 to about 74%, while the Slim models can get 75%.

We should investigate what variables to include in regularization as part of the effort to get the best convergence. You're right that including the biases is probably not helping (and certainly not helping performance), but I don't want to change it without verifying the effect on convergence, and that should be done as part of a larger effort of reproducing SOTA training implementations.
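
One hedged way to run that experiment, without touching the layer definitions, is to build the L2 term manually from the trainable variables and filter by name; the excluded-name tokens below are an assumption about how the variables are named and may need adjusting per model:

import tensorflow as tf

def l2_loss_excluding_bn_and_bias(weight_decay):
    # Skip anything that looks like a bias or a batch-norm parameter.
    excluded = ('BatchNorm', 'batchnorm', 'bias', 'beta', 'gamma')
    weights = [v for v in tf.trainable_variables()
               if not any(tok in v.name for tok in excluded)]
    return weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in weights])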

Assigning to @bignamehyp for now. I should have time to work on convergence in a month or two if he doesn't have time.

reedwm · Dec 21 '17

I have experimented on my side with the Inception v2 model. The current regularization was clearly hurting performance: after removing it and adding a few extras (label smoothing + aux logits), I was able to get to 74% top-1 accuracy.

I can launch a server with ResNet-50 v1 and my modifications to give it a try.

balancap · Dec 21 '17

74% accuracy is good to hear. Feel free to send a pull request with the change. If you do, you might want to send the PR before testing it, so if we have any comments, you can address them and do the convergence test afterwards. Alternatively, you can wait for us to implement the regularization change and label smoothing as part of our future convergence effort.

I believe Inception v2 does not have aux logits. Inception v3 does have aux logits according to the paper, but they are currently not enabled in tf_cnn_benchmarks. The constructor has an auxiliary argument, but we only pass False to it. (@bignamehyp, do you know why that is?) As for label smoothing, the Inception paper mentions it, so we should also do it for Inception, but not for ResNet and other models.
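
For the label smoothing part, a minimal TF 1.x sketch (0.1 is the value from the Inception paper; the 1001-class depth and the placeholders are illustrative stand-ins for the model's labels and logits):

import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 1001])  # hypothetical model output
labels = tf.placeholder(tf.int32, [None])
onehot_labels = tf.one_hot(labels, depth=1001)

# label_smoothing=0.1 follows the Inception paper's setting.
loss = tf.losses.softmax_cross_entropy(
    onehot_labels=onehot_labels, logits=logits, label_smoothing=0.1)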

reedwm · Dec 21 '17

True, the aux logits are not part of Inception v2. I guess they only make a minor difference at the end (around 0.2 points, if I remember the paper correctly). I also forgot to mention that I have been using moving-average variables to improve the accuracy a bit.
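
For context, a rough sketch of how moving-average variables are typically wired up in TF 1.x; the 0.9999 decay and the no-op standing in for the real train_op are assumptions, not the exact setup used here:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.no_op()  # stand-in for the real training op

# Maintain shadow copies of the trainable variables after each step.
ema = tf.train.ExponentialMovingAverage(decay=0.9999, num_updates=global_step)
with tf.control_dependencies([train_op]):
    train_op_with_ema = tf.group(ema.apply(tf.trainable_variables()))

# At eval time, restore the shadow values instead of the raw weights.
saver = tf.train.Saver(ema.variables_to_restore())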

I'll have to test the modification a bit more before submitting (it's quite a dirty hack in its current form!) and see whether it is compatible with fp16 mode.
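
On the fp16 question, one hedged option (matching the TODO in the snippet at the top of the thread) is to cast each regularization term to fp32 before summing, so the small L2 values don't underflow in half precision:

import tensorflow as tf

reg_losses = [tf.cast(l, tf.float32)
              for l in tf.losses.get_regularization_losses()]
reg_loss = tf.add_n(reg_losses, name='total_regularization_loss')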

balancap · Dec 22 '17