
Support Repeated Augmentation

rentainhe opened this issue 2 years ago · 2 comments

TODO

  • [x] Add Repeated Augmentation (a mini-batch sampling strategy)

This augmentation strategy is used in mainstream ViTs such as Swin-T and DeiT. It brings a noticeable accuracy gain for larger models, while small models are largely unaffected (a framework-free sketch of the sampling idea follows the references below).

Reference

  • DeiT implementation: https://github.com/facebookresearch/deit/blob/main/samplers.py
  • Zhihu explanation: https://zhuanlan.zhihu.com/p/430563265
  • mmcls implementation: https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/samplers/repeat_aug.py
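Before the full OneFlow implementation below, here is a minimal, framework-free sketch of the sampling idea (the helper make_ra_indices is only an illustrative name, not part of libai): every index is repeated num_repeats times, the list is padded to a multiple of the data-parallel world size, and then split round-robin across ranks, so copies of the same sample land on different GPUs and each GPU applies its own random augmentation.

import math
import random


def make_ra_indices(dataset_len, num_repeats=3, world_size=2, rank=0, seed=0):
    """Illustrative sketch of repeated-augmentation index generation for one rank."""
    rng = random.Random(seed)
    indices = list(range(dataset_len))
    rng.shuffle(indices)
    # repeat every index num_repeats times: [i0, i0, i0, i1, i1, i1, ...]
    indices = [i for i in indices for _ in range(num_repeats)]
    # pad so the list splits evenly across ranks
    total_size = math.ceil(dataset_len * num_repeats / world_size) * world_size
    indices += indices[: total_size - len(indices)]
    # round-robin split: consecutive copies of the same index go to different ranks
    return indices[rank:total_size:world_size]


print(make_ra_indices(10, rank=0))  # copies of the same shuffled indices...
print(make_ra_indices(10, rank=1))  # ...spread over the two ranks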

rentainhe · Apr 19 '22

Runnable test snippet

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import oneflow as flow
from oneflow.utils.data import Sampler


class RASampler(Sampler):
    """
    Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.

    It ensures that different different process (GPU) will see a
    different augmented version of the same sample.

    Heavily based on flow.utils.data.DistributedSampler

    Arguments:
        dataset: dataset to be sampled.
        micro_batch_size: batch size for per model instance.
        global_batch_size is micro_batch_size times data_parallel_size.
        shuffle: whether to shuffle the dataset.
        data_parallel_rank: local rank for data parallelism.
        data_parallel_size: the size of data parallelism.
        num_repeats: repeat sampling nums for each sample.
        selected_round: determine the number of samples to select per epoch for each rank
        seed: random seed, used for reproducing experiments (default: ``0``).
    """

    def __init__(
        self,
        dataset,
        micro_batch_size,
        shuffle=True,
        consumed_samples=0,
        data_parallel_rank=0,
        data_parallel_size=1,
        num_repeats=3,
        selected_round=256,
        seed=0,
    ):
        self.data_parallel_size = data_parallel_size
        self.dataset = dataset
        self.rank = data_parallel_rank
        self.num_repeats = num_repeats
        self.micro_batch_size = micro_batch_size
        self.actual_batch_size = self.micro_batch_size * self.data_parallel_size

        # samples per rank: dataset size * num_repeats / data_parallel_size
        self.num_samples = int(
            math.ceil(len(self.dataset) * self.num_repeats / self.data_parallel_size)
        )

        # the total samples after repeat sampling
        self.total_size = self.num_samples * self.data_parallel_size

        # the actual number of samples per rank, not counting the repeats
        if selected_round:
            self.data_size = int(
                math.floor(
                    len(self.dataset) // selected_round * selected_round / self.data_parallel_size
                )
            )
        else:
            self.data_size = int(math.ceil(len(self.dataset) / self.data_parallel_size))

        self.shuffle = shuffle
        self.consumed_samples = consumed_samples

        self.seed = seed

    def __iter__(self):
        batch = []

        while True:
            # deterministically shuffle based on the current epoch, derived from the
            # number of consumed samples so that every new pass over the dataset
            # gets a different shuffle
            epoch = self.consumed_samples // self.data_size
            if self.shuffle:
                g = flow.Generator()
                g.manual_seed(self.seed + epoch)
                indices = flow.randperm(len(self.dataset), generator=g).tolist()
            else:
                # keep the same list-of-int form as the shuffled branch
                indices = flow.arange(start=0, end=len(self.dataset)).tolist()

            # add extra samples to make it evenly divisible
            # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
            indices = [ele for ele in indices for _ in range(self.num_repeats)]
            padding_size: int = self.total_size - len(indices)
            if padding_size > 0:
                indices += indices[:padding_size]
            assert len(indices) == self.total_size

            # subsample with a stride so that repeated copies of a sample are spread across ranks
            indices = indices[self.rank : self.total_size : self.data_parallel_size]
            assert len(indices) == self.num_samples

            indices = iter(indices[: self.data_size])

            for idx in indices:
                batch.append(idx)
                if len(batch) == self.micro_batch_size:
                    self.consumed_samples += self.actual_batch_size
                    yield batch
                    batch = []

    def __len__(self):
        return self.data_size

    def set_consumed_samples(self, consumed_samples):
        """you can recover the training iteration by setting `consumed_samples`."""
        self.consumed_samples = consumed_samples

    def set_epoch(self, epoch):
        """kept for API compatibility; the epoch is derived from ``consumed_samples`` in ``__iter__``."""
        self.epoch = epoch


# print the first micro-batch of indices seen by each data-parallel rank
sampler = RASampler(list(range(2000)), 4, data_parallel_size=2, data_parallel_rank=0)
for batch in sampler:
    print(batch)
    break

sampler = RASampler(list(range(2000)), 4, data_parallel_size=2, data_parallel_rank=1)
for batch in sampler:
    print(batch)
    break
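
A quick sanity check that could be added here (hypothetical, not part of the original test): with num_repeats=3 and data_parallel_size=2, the first batches of the two ranks should share underlying sample indices but not be identical lists, because the repeated copies are distributed round-robin across ranks.

b0 = next(iter(RASampler(list(range(2000)), 4, data_parallel_size=2, data_parallel_rank=0)))
b1 = next(iter(RASampler(list(range(2000)), 4, data_parallel_size=2, data_parallel_rank=1)))
print(set(b0) & set(b1))  # non-empty: both ranks draw copies of the same samples
assert b0 != b1           # but the batches themselves differ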

rentainhe · Apr 21 '22

@L1aoXingyu @CPFLAME @Ldpe2G Could you help me think through this PR as well? I ran a full vit_small training test and it seems that enabling this did not bring any improvement... there may be a problem with my implementation.

rentainhe · Apr 24 '22