ijepa
Is batch size per GPU important to reproduce the accuracy reported in the paper?
Hi, I recently ran the vit-14-ep300 config on 16 A100 GPUs. However, since the GPUs ran out of memory in the middle of training, I decreased the batch size from 128 to 112 per GPU, and I obtained a lower linear-probe accuracy (about -2%). Is it important to preserve the batch size of 128 per GPU? Thank you very much!
Hi @Gus-Guo, batch size is not important for the method, since there is no batch-dependent component in the loss. However, I expect that changing the batch size may require you to change other parameters accordingly, such as the learning rate or the EMA range. That said, running out of memory in the middle of training is not great. Would you mind sharing some logs, including:
- std-out
- config
Thank you
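For what it's worth, the "adjust the learning rate" point is often handled with the linear-scaling heuristic (scale the peak lr by the ratio of the new global batch size to the old one). Below is a minimal sketch, assuming the 16-GPU setup described above and the peak lr of 0.001 from the config posted later in the thread; this is a common heuristic, not something prescribed by the paper, and the variable names are illustrative rather than taken from the ijepa code:

```python
# Back-of-the-envelope linear learning-rate scaling when the global batch size changes.
base_global_bs = 128 * 16   # paper setting: 128 per GPU on 16 GPUs -> 2048
new_global_bs = 112 * 16    # reduced setting: 112 per GPU on 16 GPUs -> 1792

base_lr = 0.001             # peak lr from the shared config
scaled_lr = base_lr * new_global_bs / base_global_bs
print(f"linearly scaled peak lr: {scaled_lr:.6f}")  # -> 0.000875
```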
Mido, thank you very much for your reply. My training config is as follows:
data:
  batch_size: 112  # bs is 128 in the paper
  color_jitter_strength: 0.0
  crop_scale:
  - 0.3
  - 1.0
  crop_size: 224
  image_folder: imagenet_full_size/061417/
  num_workers: 2
  pin_mem: true
  use_color_distortion: false
  use_gaussian_blur: false
  use_horizontal_flip: false
  use_fake_data: false
  hdfs_data_path: /workspace/Dataset/imagenet1k/train
  train_dataset_size: 1281167
logging:
  folder: experiments/ijepa/in1k_vith14_epo300_bs112_16gpu_a100_amp_fp16
  write_tag: jepa
mask:
  allow_overlap: false
  aspect_ratio:
  - 0.75
  - 1.5
  enc_mask_scale:
  - 0.85
  - 1.0
  min_keep: 10
  num_enc_masks: 1
  num_pred_masks: 4
  patch_size: 14
  pred_mask_scale:
  - 0.15
  - 0.2
meta:
  copy_data: false
  load_checkpoint: false
  model_name: vit_huge
  pred_depth: 12
  pred_emb_dim: 384
  read_checkpoint: null
  use_bfloat16: true
optimization:
  ema:
  - 0.996
  - 1.0
  epochs: 300
  final_lr: 1.0e-06
  final_weight_decay: 0.4
  ipe_scale: 1.0
  lr: 0.001
  start_lr: 0.0002
  warmup: 40
  weight_decay: 0.04
  num_accumulate_iter: 1
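For anyone reproducing this setup, here is a minimal sketch of reading a config like the one above with PyYAML and checking the effective global batch size; the config path is hypothetical and the snippet is not the repo's actual launcher:

```python
# Load a YAML config like the one above and compute the global batch size.
import yaml

with open("configs/in1k_vith14_ep300.yaml") as f:  # hypothetical path
    cfg = yaml.safe_load(f)

per_gpu_bs = cfg["data"]["batch_size"]  # 112 in this run, 128 in the paper
num_gpus = 16                           # from the setup described above
print(f"global batch size: {per_gpu_bs * num_gpus}")  # 1792 here vs. 2048 in the paper
```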
Here is the part of my training logs where it runs out of memory:
INFO:root:rank [0],[Mon Jul 3 14:54:28 2023]: [2, 190] loss: 0.104 masks: 68.8 44.1 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5109.5 ms)
INFO:root:[2, 190] grad_stats: [6.88e-02 2.97e-04] (2.59e-04, 9.89e-02)
INFO:root:rank [0],[Mon Jul 3 14:55:18 2023]: [2, 200] loss: 0.103 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.26e-04] [mem: 7.07e+04] (5105.9 ms)
INFO:root:[2, 200] grad_stats: [4.58e-02 1.88e-04] (1.67e-04, 7.15e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:08 2023]: [2, 210] loss: 0.102 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5097.9 ms)
INFO:root:[2, 210] grad_stats: [1.13e-02 7.24e-05] (6.50e-05, 1.72e-02)
INFO:root:rank [0],[Mon Jul 3 14:56:59 2023]: [2, 220] loss: 0.101 masks: 68.7 44.2 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5096.2 ms)
INFO:root:[2, 220] grad_stats: [2.06e-02 1.49e-04] (1.11e-04, 3.02e-02)
INFO:root:rank [0],[Mon Jul 3 14:57:50 2023]: [2, 230] loss: 0.100 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.27e-04] [mem: 7.07e+04] (5100.3 ms)
INFO:root:[2, 230] grad_stats: [1.29e-01 5.10e-04] (4.46e-04, 1.94e-01)
INFO:root:rank [0],[Mon Jul 3 14:58:38 2023]: [2, 240] loss: 0.099 masks: 68.6 44.1 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5086.8 ms)
INFO:root:[2, 240] grad_stats: [5.05e-02 2.08e-04] (1.88e-04, 6.79e-02)
INFO:root:rank [0],[Mon Jul 3 14:59:30 2023]: [2, 250] loss: 0.099 masks: 68.6 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5090.7 ms)
INFO:root:[2, 250] grad_stats: [4.58e-02 2.69e-04] (1.77e-04, 8.33e-02)
INFO:root:rank [0],[Mon Jul 3 15:00:16 2023]: [2, 260] loss: 0.098 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.28e-04] [mem: 7.07e+04] (5071.0 ms)
INFO:root:[2, 260] grad_stats: [2.65e-02 2.61e-04] (1.72e-04, 4.26e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:05 2023]: [2, 270] loss: 0.097 masks: 68.5 44.2 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5062.7 ms)
INFO:root:[2, 270] grad_stats: [2.16e-02 1.40e-04] (1.14e-04, 3.57e-02)
INFO:root:rank [0],[Mon Jul 3 15:01:56 2023]: [2, 280] loss: 0.096 masks: 68.4 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5064.2 ms)
INFO:root:[2, 280] grad_stats: [2.36e-02 2.30e-04] (1.44e-04, 2.88e-02)
INFO:root:rank [0],[Mon Jul 3 15:02:51 2023]: [2, 290] loss: 0.095 masks: 68.3 44.3 [wd: 4.00e-02] [lr: 2.29e-04] [mem: 7.07e+04] (5080.0 ms)
INFO:root:[2, 290] grad_stats: [1.22e-02 1.14e-04] (9.00e-05, 2.19e-02)
File "/opt/tiger/test_merlin_demo/src/train_iter.py", line 419, in forward_context z = predictor(z, masks_enc, masks_pred) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl return forward_call(*input, **kwargs) File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 892, in forward output = self.module(*inputs[0], **kwargs[0]) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl return forward_call(*input, **kwargs) File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 319, in forward x = blk(x) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl return forward_call(*input, **kwargs) File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 170, in forward x = x + self.drop_path(self.mlp(self.norm2(x))) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl return forward_call(*input, **kwargs) File "/opt/tiger/test_merlin_demo/src/models/vision_transformer.py", line 118, in forward x = self.fc1(x) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1118, in _call_impl return forward_call(*input, **kwargs) File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/linear.py", line 103, in forward return F.linear(input, self.weight, self.bias) File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 1848, in linear return torch._C._nn.linear(input, weight, bias) RuntimeError: CUDA out of memory. Tried to allocate 204.00 MiB (GPU 7; 79.35 GiB total capacity; 70.82 GiB already allocated; 180.19 MiB free; 76.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
It's surprising that this is happening more than an epoch into training. I'm not sure if there are other processes running on your GPUs, but could you try changing enc_mask_scale to the range (0.7, 0.8)? This change should result in smaller context blocks, which will use less memory. I'm somewhat hopeful that it won't require any other hyperparameter changes.
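To see roughly why a smaller enc_mask_scale helps, here is a quick back-of-the-envelope calculation (illustrative only; with allow_overlap: false the sampler also removes patches that overlap the target blocks, so the real context is smaller than these upper bounds):

```python
# Rough upper bound on the number of context patches the encoder sees for a
# 224x224 image with 14x14 patches, under two enc_mask_scale ranges.
patch_grid = (224 // 14) ** 2  # 16 * 16 = 256 patches

for lo, hi in [(0.85, 1.0), (0.7, 0.8)]:
    print(f"enc_mask_scale=({lo}, {hi}): ~{int(lo * patch_grid)}-{int(hi * patch_grid)} patches")

# Self-attention activation memory grows roughly quadratically with the number of
# tokens, so trimming the context block gives a noticeable memory saving.
```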
Thank you very much. I will give it a try.