
Multi-GPU training with max_batch_size > 1 under DDP

xqun3 opened this issue · 2 comments

@lucidrains hi, when I train Imagen on multiple GPUs with DDP, a warning occurs: "UserWarning: Grad strides do not match bucket view strides."

warning

/home/tqd/anaconda3/envs/imagen_pyenv3/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed.  This is not an error, but may impair performance.
grad.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
bucket_view.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at  /opt/conda/conda-bld/pytorch_1656352660876/work/torch/csrc/distributed/c10d/reducer.cpp:312.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

(the same warning is printed two more times)
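
The warning comes from DDP's gradient bucketing: the gradient for a [2048, 1024, 1, 1] conv weight arrives with non-default strides ([1024, 1, 1024, 1024]), while DDP's flattened bucket view is contiguous ([1024, 1, 1, 1]), so the reducer copies the gradient instead of writing it in place. As the message says, this is not an error, only a possible slowdown. A minimal sketch of one commonly suggested mitigation, assuming the mismatch comes from non-contiguous parameters: force every parameter into the default contiguous layout before the trainer (and therefore DDP) is constructed; the helper name is illustrative:

import torch

def make_params_contiguous(model: torch.nn.Module):
    # rewrite each parameter's storage in the default contiguous layout so
    # gradients produced for it match DDP's contiguous bucket views
    for p in model.parameters():
        p.data = p.data.contiguous()

# e.g. make_params_contiguous(imagen) before constructing ImagenTrainer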

train script

import json
from pathlib import Path

import torch
from datasets import load_dataset

# imports below follow imagen-pytorch's package layout; MyCollator is the
# user's own collator (its definition is not included in this issue)
from imagen_pytorch import ImagenTrainer
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
from imagen_pytorch.data import Collator
from imagen_pytorch.utils import safeget
from imagen_pytorch.version import __version__


def main(
    config,
    unet,
    epochs,
    text,
    valid
):
    # check config path

    config_path = Path(config)
    full_config_path = str(config_path.resolve())
    assert config_path.exists(), f'config not found at {full_config_path}'
    
    with open(config_path, 'r') as f:
        config_data = json.loads(f.read())

    assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'
    
    model_path = Path(config_data['checkpoint_path'])
    full_model_path = str(model_path.resolve())
    
    # setup imagen config

    imagen_config_klass = ElucidatedImagenConfig if config_data['type'] == 'elucidated' else ImagenConfig
    imagen = imagen_config_klass(**config_data['imagen']).create()

    trainer = ImagenTrainer(
        imagen = imagen,
        dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks'),
        **config_data['trainer']
    )

    # load pt
    if model_path.exists():
        loaded = torch.load(str(model_path))
        version = safeget(loaded, 'version')
        print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
        trainer.load(model_path)
        
    if torch.cuda.is_available():
        trainer = trainer.cuda()

    size = config_data['imagen']['image_sizes'][unet-1]

    max_batch_size = config_data['max_batch_size'] if 'max_batch_size' in config_data else 1

    channels = 'RGB'
    if 'channels' in config_data['imagen']:
        assert 1 <= config_data['imagen']['channels'] <= 4, 'Imagen only supports 1 to 4 channels (L, LA, RGB, RGBA)'
        if config_data['imagen']['channels'] == 4:
            channels = 'RGBA' # Color with alpha
        elif config_data['imagen']['channels'] == 2:
            channels = 'LA' # Luminance (Greyscale) with alpha
        elif config_data['imagen']['channels'] == 1:
            channels = 'L' # Luminance (Greyscale)


    assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'
    
    # load and add train dataset and valid dataset
    ds = load_dataset(config_data['dataset_name'])
    print(ds)
    # ds = load_dataset("wukong50k")
    trainer.add_train_dataset(
        ds = ds['train'],
        collate_fn = MyCollator(
            image_size = size,
            image_label = config_data['image_label'],
            text_label = config_data['text_label'],
            url_label = config_data['url_label'],
            name = imagen.text_encoder_name,
            channels = channels
        ),
        **config_data['dataset']
    )


    if not trainer.split_valid_from_train and valid != 0:
        assert 'validation' in ds, 'There is no validation split in the dataset'
        trainer.add_valid_dataset(
            ds = ds['validation'],
            collate_fn = Collator(
                image_size = size,
                image_label = config_data['image_label'],
                text_label= config_data['text_label'],
                url_label = config_data['url_label'],
                name = imagen.text_encoder_name,
                channels = channels
            ),
            **config_data['dataset']
        )

    for i in range(epochs):
        loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
        print(f'step {i}, loss: {loss}')

        if valid != 0 and not (i % valid) and i > 0:
            valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
            print(f'step {i}, valid loss: {valid_loss}')

        # periodically sample an image on the main process
        if not (i % 322) and i > 0 and trainer.is_main and text is not None:
            images = trainer.sample(texts = [text], batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
            images[0].save(f'./sample-{i // 322}.png')

    trainer.save(model_path)


if __name__ == "__main__":
    main("./myconfig.json", 1, 400, "a cute bird", 0)


config

{
    "type": "original",
    "imagen": {
        "video": false,
        "timesteps": [1024, 512, 512],
        "image_sizes": [64, 256, 1024],
        "random_crop_sizes": [null, 64, 256],
        "condition_on_text": true,
        "cond_drop_prob": 0.1,
        "text_encoder_name": "IDEA-CCNL/Randeng-T5-Char-700M-Chinese",
        "unets": [
            {
                "dim": 512,
                "dim_mults": [1, 2, 3, 4],
                "num_resnet_blocks": 3,
                "layer_attns": [false, true, true, true],
                "layer_cross_attns": [false, true, true, true],
                "attn_heads": 8
            },
            {
                "dim": 128,
                "dim_mults": [1, 2, 4, 8],
                "num_resnet_blocks": [2, 4, 8, 8],
                "layer_attns": [false, false, false, true],
                "layer_cross_attns": [false, false, false, true],
                "attn_heads": 8
            },
            {
                "dim": 128,
                "dim_mults": [1, 2, 4, 8],
                "num_resnet_blocks": [2, 4, 8, 8],
                "layer_attns": false,
                "layer_cross_attns": [false, false, false, true],
                "attn_heads": 8
            }
        ]
    },
    "trainer": {
        "lr": 1e-4,
        "fp16": true,
        "use_ema": false
    },
    "dataset_name": "conceptual_captions",
    "dataset": {
        "batch_size": 512,
        "shuffle": true
    },
    "max_batch_size": 32,
    "image_label": "img_path",
    "url_label": null,
    "text_label": "caption",
    "checkpoint_path": "./imagen.pt"
}
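
Note the two batch settings in this config: "batch_size" (512) is the dataloader batch, while "max_batch_size" (32) is handed to train_step/valid_step, which splits each batch into smaller chunks and accumulates gradients across them. Illustrative arithmetic, not the library's exact code:

batch_size = 512       # per-step dataloader batch, from config_data['dataset']
max_batch_size = 32    # chunk size passed to trainer.train_step
chunks_per_step = (batch_size + max_batch_size - 1) // max_batch_size  # 16 forward/backward passes per step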

xqun3 · Feb 02 '23 09:02

I still have this problem. Have you solved it?

clearlyzero · Jan 09 '24 12:01

This is an automatic reply from QQ Mail. Hello, your email has been received. Thank you for your email.

xqun3 · Jan 09 '24 12:01