
Details of Auto-augment to get the best results

Open XingyuJinTI opened this issue 5 years ago • 4 comments

Hi there,

Thanks for releasing the code for your nice work. To achieve the best results (68.1%), did you directly use the reduced-ImageNet augmentation policy provided by the fast auto-augment repo, or did you re-search the policy yourselves?

If it is the former, could you give some details? I.e., did you keep the same basic augmentation as in this repo (where there is no horizontal flipping or PCA lighting, but random grayscale is added) and then insert their provided augmentation policies, or did you make all of the basic augmentation the same as theirs?

By the way, is the scheduler for the best model MultiStepLR(optimizer, milestones=[90, 135], gamma=0.2) with 150 epochs?

Many thanks, Eddie

XingyuJinTI avatar Nov 06 '19 18:11 XingyuJinTI

The augmentation that we used for our best results looks like this:

from torchvision import transforms

# Lighting, Augmentation, and fa_reduced_imagenet come from the fast auto-augment
# repo (see the note below); _IMAGENET_PCA holds the standard ImageNet PCA
# eigenvalues/eigenvectors used by the Lighting transform.
def get_imagenet_transforms():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(128, scale=(0.2, 0.9), ratio=(0.7, 1.4),
                                     interpolation=3),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
        transforms.RandomGrayscale(p=0.25),
        transforms.ToTensor(),
        Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
        normalize
    ])
    transform_test = transforms.Compose([
        transforms.Resize(146, interpolation=3),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        normalize
    ])
    transform_train.transforms.insert(0, Augmentation(fa_reduced_imagenet()))
    return transform_train, transform_test

The Lighting and Augmentation functions are the same as in the fast auto-augment repo. The fa_reduced_imagenet policy is also the same as in that repo. We didn't do any hyperopt on the policy or on how to combine it with our existing augmentation, so our numbers could probably be bumped up a bit with some search.
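For reference, here is a minimal sketch of the AlexNet-style PCA lighting jitter that Lighting transforms of this kind typically implement; the actual class in the fast auto-augment repo may differ in details, so treat this as illustrative rather than a copy of that code:

import torch

class Lighting(object):
    """AlexNet-style PCA lighting noise (common reference implementation)."""
    def __init__(self, alphastd, eigval, eigvec):
        self.alphastd = alphastd
        self.eigval = torch.as_tensor(eigval)   # 3 eigenvalues of the RGB covariance
        self.eigvec = torch.as_tensor(eigvec)   # 3x3 matrix of eigenvectors

    def __call__(self, img):
        # img is a 3xHxW tensor; this transform runs after ToTensor in the pipeline above
        if self.alphastd == 0:
            return img
        alpha = img.new_empty(3).normal_(0, self.alphastd)
        rgb = (self.eigvec.type_as(img).clone()
               .mul(alpha.view(1, 3).expand(3, 3))
               .mul(self.eigval.view(1, 3).expand(3, 3))
               .sum(1).squeeze())
        return img.add(rgb.view(3, 1, 1).expand_as(img))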

The best model was trained on a schedule like MultiStepLR(optimizer, milestones=[90, 135], gamma=0.2) for 150 epochs. Training for more epochs or further increasing the model size should lead to some gains in the final number. We're currently doing a bit of hyperopt on this since another paper (https://arxiv.org/abs/1906.05849) recently reported a result of 68.4%.
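In code, the schedule described above looks roughly like this (the model, optimizer, and learning rate below are placeholders for illustration, not values taken from this thread):

import torch

model = torch.nn.Linear(128, 128)          # stand-in for the encoder
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[90, 135], gamma=0.2)  # LR drops by 5x at epochs 90 and 135

for epoch in range(150):
    # ... one epoch of training (forward/backward/optimizer.step per batch) ...
    scheduler.step()  # advance the learning rate schedule once per epoch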

Philip-Bachman avatar Nov 06 '19 22:11 Philip-Bachman

Thanks so much for the answers. Did you run into the float-division-by-zero issue below (which only happens when training the large network)?

Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.69759663277e-313
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8.487983164e-314
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4.243991582e-314
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.121995791e-314
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.0609978955e-314
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 5.304989477e-315
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.65249474e-315
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.32624737e-315
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 6.63123685e-316
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 3.3156184e-316
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.6578092e-316
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8.289046e-317
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4.144523e-317
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.0722615e-317
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.036131e-317
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 5.180654e-318
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.590327e-318
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.295163e-318
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 6.4758e-319
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 3.2379e-319
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.61895e-319
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8.095e-320
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4.0474e-320
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.0237e-320
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.012e-320
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 5.06e-321
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2.53e-321
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.265e-321
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 6.3e-322
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 3.16e-322
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1.6e-322
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8e-323
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4e-323
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2e-323
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 1e-323
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 5e-324
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 0.0
Traceback (most recent call last):
  File "train.py", line 110, in <module>
    main()
  File "train.py", line 105, in main
    test_loader, stat_tracker, checkpointer, args.output_dir, torch_device)
  File "/cache/amdim/task_self_supervised.py", line 141, in train_self_supervised
    train_loader, test_loader, stat_tracker, log_dir, device)
  File "/cache/amdim/task_self_supervised.py", line 69, in _train
    mixed_precision.backward(loss_opt, optim_inf)  # backwards with fp32/fp16 awareness
  File "/cache/amdim/mixed_precision.py", line 87, in backward
    scaled_loss.backward()
  File "/home/work/anaconda3/lib/python3.6/contextlib.py", line 88, in __exit__
    next(self.gen)
  File "/home/work/anaconda3/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
    optimizer._post_amp_backward(loss_scaler)
  File "/home/work/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 182, in post_backward_with_master_weights
    models_are_masters=False)
  File "/home/work/anaconda3/lib/python3.6/site-packages/apex/amp/scaler.py", line 117, in unscale
    1./scale)
ZeroDivisionError: float division by zero

I ran into this, and I can avoid it by training the large net without mixed precision, but then I cannot keep the same batch size (1008). How did you deal with this when training your best model? Thanks.

XingyuJinTI avatar Nov 07 '19 16:11 XingyuJinTI

I didn't encounter this issue while running these models, but it's possible that some of the PyTorch defaults (e.g. layer initialization) have changed since I was running the large models. I encountered similar issues from time to time after making major changes to the model, and generally found that, e.g., reducing the learning rate, or changing the scale of the "fake RKHS" vectors' dot products via initialization or direct rescaling here (https://github.com/Philip-Bachman/amdim-public/blob/master/costs.py#L55), was enough to fix the problem.

The instability always seemed to arise from spikes in the magnitude of these "raw_scores". You could also try reducing the strength of logit regularization from 5e-2 to something like 1e-2. This may change final performance slightly (less than 0.5%), but it will reduce some of the gradient spikes caused by large NCE dot products.
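To make that concrete, here is a hedged sketch of the kind of rescaling, soft clipping, and (weaker) logit regularization being described; the function name and constants are illustrative placeholders, and the actual implementation is in costs.py in the repo:

import torch

def stabilize_scores(raw_scores, score_scale=1.0, clip_val=10.0, reg_weight=1e-2):
    """Illustrative sketch: rescale and soft-clip the NCE dot products, and compute a
    weaker logit regularization penalty, to damp the gradient spikes described above."""
    raw_scores = raw_scores / score_scale                    # direct rescaling of the dot products
    clipped = clip_val * torch.tanh(raw_scores / clip_val)   # smooth clamp to (-clip_val, clip_val)
    lgt_reg = reg_weight * (raw_scores ** 2.0).mean()        # e.g. 1e-2 instead of 5e-2
    return clipped, lgt_reg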

In general, the mixed precision is a bit less automatic/magic than advertised, but it was important for scaling the experiments without access to huge compute.
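As an aside, one knob that may be worth trying (an assumption on my part, not something tested in this thread): apex's amp.initialize accepts a min_loss_scale argument that puts a floor under the dynamic loss scale, so repeated overflows cannot drive it all the way to 0.0. A minimal sketch, with a placeholder model and optimizer and assuming a CUDA device:

import torch
from apex import amp

model = torch.nn.Linear(128, 128).cuda()                 # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Floor the dynamic loss scale so it can never decay to 0.0 after repeated overflows.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1',
                                  min_loss_scale=128.0)

loss = model(torch.randn(16, 128).cuda()).pow(2).mean()  # dummy loss for illustration
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()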

Did you see this instability after significant training, or was it right at the start? Also, note that I corrected the number of epochs for the best models in my earlier reply... I just loaded one of our pretrained models while testing some other code, and it said it was trained for 150 epochs.

Philip-Bachman avatar Nov 07 '19 21:11 Philip-Bachman

Hi, thanks for your suggestions. I am going to try a smaller learning rate and a smaller score scale.

The instability happens at a random point: sometimes at epoch 20, sometimes epoch 50, sometimes as early as epoch 2, but NOT in the first couple of iterations.

XingyuJinTI avatar Nov 12 '19 10:11 XingyuJinTI