ijepa

Is batch size per GPU important for reproducing the accuracy reported in the paper?

Open Gus-Guo opened this issue 1 year ago • 4 comments

Hi, I recently ran the vit-14-ep300 config on 16 A100 GPUs. Because the GPUs ran out of memory in the middle of training, I decreased the batch size from 128 to 112 per GPU. However, I obtained a lower linear-probe accuracy (about 2% lower). Is it important to keep the batch size at 128 per GPU? Thank you very much!

Gus-Guo, Jul 03 '23 07:07

Hi @Gus-Guo, batch size is not important for the method, since there is no batch component to the loss. However, I expect that changing the batch size might require you to change other parameters accordingly, such as the learning rate or the EMA range. That said, running out of memory in the middle of training is not great. Would you mind sharing some logs, including:

  • std-out
  • config

MidoAssran, Jul 03 '23 21:07
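As a rough illustration of what "changing the learning rate accordingly" could look like, here is a minimal sketch that applies the linear scaling rule to the `lr` and `start_lr` values from the config posted in this thread. The linear scaling rule is a common heuristic and an assumption here; nothing in this thread confirms it is the right adjustment for I-JEPA, and the function name `scale_lr` is hypothetical.

```python
def scale_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling heuristic: keep lr / global_batch_size constant."""
    return base_lr * new_batch / base_batch


NUM_GPUS = 16                      # from the original report
base_global = 128 * NUM_GPUS       # per-GPU batch size used in the paper config
new_global = 112 * NUM_GPUS        # reduced per-GPU batch size

# lr and start_lr are taken from the posted config; scaled values are only a starting point.
scaled_lr = scale_lr(0.001, base_global, new_global)
scaled_start_lr = scale_lr(0.0002, base_global, new_global)

print(f"global batch: {base_global} -> {new_global}")
print(f"lr: 0.001 -> {scaled_lr:.6f}, start_lr: 0.0002 -> {scaled_start_lr:.6f}")
```

Whether the EMA range also needs adjusting is not something this sketch addresses.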

Thank you

Mido, thank you very much for your reply. My training config is as follows:


data:
  batch_size: 112  # bs is 128 in the paper
  color_jitter_strength: 0.0
  crop_scale:
  - 0.3
  - 1.0
  crop_size: 224
  image_folder: imagenet_full_size/061417/
  num_workers: 2
  pin_mem: true
  use_color_distortion: false
  use_gaussian_blur: false
  use_horizontal_flip: false
  use_fake_data: false
  hdfs_data_path: /workspace/Dataset/imagenet1k/train
  train_dataset_size: 1281167
logging:
  folder: experiments/ijepa/in1k_vith14_epo300_bs112_16gpu_a100_amp_fp16
  write_tag: jepa
mask:
  allow_overlap: false
  aspect_ratio:
  - 0.75
  - 1.5
  enc_mask_scale:
  - 0.85
  - 1.0
  min_keep: 10
  num_enc_masks: 1
  num_pred_masks: 4
  patch_size: 14
  pred_mask_scale:
  - 0.15
  - 0.2
meta:
  copy_data: false
  load_checkpoint: false
  model_name: vit_huge
  pred_depth: 12
  pred_emb_dim: 384
  read_checkpoint: null
  use_bfloat16: true
optimization:
  ema:
  - 0.996
  - 1.0
  epochs: 300
  final_lr: 1.0e-06
  final_weight_decay: 0.4
  ipe_scale: 1.0
  lr: 0.001
  start_lr: 0.0002
  warmup: 40
  weight_decay: 0.04
  num_accumulate_iter: 1

Part of my training logs, from where it ran out of memory:

INFO:root:rank [0],[Mon Jul 3 14:54:28 2023]: [2, 190] loss: 0.104 masks: 68.8 44.1 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5109.5 ms)
INFO:root:[2, 190] grad_stats: [6.88e-02 2.97e-04] (2.59e-04, 9.89e-02)
INFO:root:rank [0],[Mon Jul 3 14:55:18 2023]: [2, 200] loss: 0.103 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5105.9 ms)
INFO:root:[2, 200] grad_stats: [4.58e-02 1.88e-04] (1.67e-04, 7.15e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:08 2023]: [2, 210] loss: 0.102 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5097.9 ms)
INFO:root:[2, 210] grad_stats: [1.13e-02 7.24e-05] (6.50e-05, 1.72e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:59 2023]: [2, 220] loss: 0.101 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5096.2 ms)
INFO:root:[2, 220] grad_stats: [2.06e-02 1.49e-04] (1.11e-04, 3.02e-02)
INFO:root:rank [0],[Mon Jul 3 14:57:50 2023]: [2, 230] loss: 0.100 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5100.3 ms)
INFO:root:[2, 230] grad_stats: [1.29e-01 5.10e-04] (4.46e-04, 1.94e-01)
INFO:root:rank [0],[Mon Jul 3 14:58:38 2023]: [2, 240] loss: 0.099 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5086.8 ms)
INFO:root:[2, 240] grad_stats: [5.05e-02 2.08e-04] (1.88e-04, 6.79e-02)
INFO:root:rank [0],[Mon Jul 3 14:59:30 2023]: [2, 250] loss: 0.099 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5090.7 ms)
INFO:root:[2, 250] grad_stats: [4.58e-02 2.69e-04] (1.77e-04, 8.33e-02)
INFO:root:rank [0],[Mon Jul 3 15:00:16 2023]: [2, 260] loss: 0.098 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5071.0 ms)
INFO:root:[2, 260] grad_stats: [2.65e-02 2.61e-04] (1.72e-04, 4.26e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:05 2023]: [2, 270] loss: 0.097 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5062.7 ms)
INFO:root:[2, 270] grad_stats: [2.16e-02 1.40e-04] (1.14e-04, 3.57e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:56 2023]: [2, 280] loss: 0.096 masks: 68.4 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5064.2 ms)
INFO:root:[2, 280] grad_stats: [2.36e-02 2.30e-04] (1.44e-04, 2.88e-02)
INFO:root:rank [0],[Mon Jul 3 15:02:51 2023]: [2, 290] loss: 0.095 masks: 68.3 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5080.0 ms)
INFO:root:[2, 290] grad_stats: [1.22e-02 1.14e-04] (9.00e-05, 2.19e-02)

  File "/opt/tiger/test_merlin_demo/src/train_iter.py", line 419, in forward_context
    z = predictor(z, masks_enc, masks_pred)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 892, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 319, in forward
    x = blk(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 170, in forward
    x = x + self.drop_path(self.mlp(self.norm2(x)))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 118, in forward
    x = self.fc1(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: CUDA out of memory. Tried to allocate 204.00 MiB (GPU 7; 79.35 GiB total capacity; 70.82 GiB already allocated; 180.19 MiB free; 76.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Gus-Guo, Jul 04 '23 02:07
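The error message above points at `max_split_size_mb` and `PYTORCH_CUDA_ALLOC_CONF`. The following minimal sketch shows one way to set that option from Python and how to read the reserved-versus-allocated gap from the numbers in the message. The value `128` is an illustrative assumption, not a recommendation from this thread, and the variable must be set before the first CUDA allocation, so it is safest to set it before importing `torch`.

```python
import os

# Must be in place before PyTorch initializes its CUDA caching allocator.
# The split size of 128 MiB is an arbitrary example value.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

# Numbers copied from the RuntimeError above (GiB).
reserved_gib = 76.69
allocated_gib = 70.82

# Memory held by the caching allocator but not backing live tensors.
gap_gib = reserved_gib - allocated_gib
gap_fraction = gap_gib / reserved_gib

print(f"reserved - allocated = {gap_gib:.2f} GiB ({gap_fraction:.1%} of reserved)")
```

The gap here is a few GiB rather than a large multiple of the allocated memory, so the allocator setting may or may not help; reducing activation memory is the other lever.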

It's surprising that it's happening already more than an epoch into training. Not sure if there are other processes running on your GPU, but could you try changing enc_mask_scale to the range (0.7, 0.8)? This change should result in smaller context blocks, which will use less memory. I'm somewhat hopeful that this wouldn't require any hyperparam changes.

MidoAssran, Jul 04 '23 20:07
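To see why a smaller `enc_mask_scale` range should reduce memory, here is a rough, area-only estimate of how many patches a context block can cover under the current and the suggested range, using `crop_size: 224` and `patch_size: 14` from the posted config. This is a sketch under simplifying assumptions: it ignores aspect-ratio rounding and the removal of patches that overlap target blocks, so the counts are upper bounds rather than what the encoder actually sees. The helper name `patch_bounds` is hypothetical.

```python
def patch_bounds(scale_range, crop_size=224, patch_size=14):
    """Return (min, max) patch counts for a block covering the given area fractions."""
    grid = crop_size // patch_size          # patches per side: 224 // 14 = 16
    num_patches = grid * grid               # 256 patches in total
    low, high = scale_range
    return int(num_patches * low), int(num_patches * high)


current = patch_bounds((0.85, 1.0))   # enc_mask_scale from the posted config
proposed = patch_bounds((0.7, 0.8))   # range suggested in the reply above

print(f"current  enc_mask_scale (0.85, 1.0): {current[0]}-{current[1]} patches")
print(f"proposed enc_mask_scale (0.7, 0.8):  {proposed[0]}-{proposed[1]} patches")
```

Fewer context patches means shorter token sequences through the encoder, which is where the memory saving would come from.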

Thank you very much. I will give it a try.

Gus-Guo, Jul 05 '23 03:07