straug

Question about training speed

Open littletomatodonkey opened this issue 3 years ago • 0 comments

Thanks for your excellent work! Training becomes very slow when I use straug (about 6x slower than without it). What training speed did you observe in your own tests? My augmentation code is below.

class RecStraugRandAug(object):
    """RandAugment-style text-image augmentation built on the ``straug`` package.

    On each call, ``num_aug`` operations are drawn at random, with replacement.
    For each one, a group (blur, camera, geometry, noise, pattern, process,
    warp or weather) is chosen uniformly, then one op is chosen uniformly
    within that group. Each drawn op receives ``prob`` and a random ``mag``.
    Whether it actually changes the image is decided inside straug.

    NOTE(review): the issue this code was posted with reports roughly 6x
    slower training with this augmenter enabled. The per-sample cost depends
    on which ops are drawn. Profile the individual ops before enabling every
    group.

    Args:
        num_aug: Number of ops applied in sequence to each sample.
        prob: Forwarded unchanged to every straug op as its ``prob`` argument.
        **kwargs: Accepted and ignored. Presumably this lets a config-driven
            builder pass extra keys; TODO confirm.

    Raises:
        ImportError: If ``straug`` is not installed.
    """

    def __init__(self, num_aug=2, prob=0.5, **kwargs):
        super().__init__()
        self.num_aug = num_aug
        self.prob = prob
        # Only the imports live inside the try. Errors raised while constructing
        # the ops themselves must surface as-is rather than being misreported as
        # a missing package. Raising (instead of calling exit()) lets the caller
        # decide how to handle the missing dependency.
        try:
            from straug.blur import GaussianBlur, DefocusBlur, MotionBlur, GlassBlur
            from straug.camera import Contrast, Brightness, JpegCompression, Pixelate
            from straug.geometry import Perspective, Shrink, Rotate
            from straug.noise import GaussianNoise, ShotNoise, ImpulseNoise, SpeckleNoise
            from straug.pattern import Grid, VGrid, HGrid, RectGrid, EllipseGrid
            from straug.process import (
                Posterize, Solarize, Invert, Equalize, AutoContrast, Sharpness, Color,
            )
            from straug.warp import Stretch, Distort, Curve
            from straug.weather import Fog, Snow, Frost, Rain, Shadow
        except ImportError as ex:
            raise ImportError(
                f"exception: {ex}, you can install straug using `pip install straug`"
            ) from ex

        # One inner list per category. __call__ picks a category uniformly
        # first, so ops in small categories are sampled more often than ops in
        # large ones.
        self.augs = [
            [GaussianBlur(), DefocusBlur(), MotionBlur(), GlassBlur()],
            [Contrast(), Brightness(), JpegCompression(), Pixelate()],
            [Perspective(), Shrink(), Rotate()],
            [GaussianNoise(), ShotNoise(), ImpulseNoise(), SpeckleNoise()],
            [Grid(), VGrid(), HGrid(), RectGrid(), EllipseGrid()],
            [Posterize(), Solarize(), Invert(), Equalize(), AutoContrast(), Sharpness(), Color()],
            [Stretch(), Distort(), Curve()],
            [Fog(), Snow(), Frost(), Rain(), Shadow()],
        ]

    def __call__(self, data):
        """Augment ``data["image"]`` in place and return ``data``.

        Args:
            data: Mapping with an ``"image"`` entry that ``Image.fromarray``
                accepts. Presumably a uint8 HxW or HxWxC array; TODO confirm
                against the data loader.

        Returns:
            The same ``data`` object, with ``"image"`` replaced by the
            augmented image converted back with ``np.array``.
        """
        img = Image.fromarray(data["image"])
        for _ in range(self.num_aug):
            group = self.augs[np.random.randint(0, len(self.augs))]
            op = group[np.random.randint(0, len(group))]
            # random.randint is inclusive, so mag is one of -1, 0, 1 or 2.
            # Presumably -1 asks straug to pick a magnitude itself; TODO confirm.
            img = op(img, mag=random.randint(-1, 2), prob=self.prob)
        data["image"] = np.array(img)
        return data

— littletomatodonkey, Jan 19 '22