imagen-pytorch
imagen-pytorch copied to clipboard
Multi-GPU training with max_batch_size > 1 under DDP
@lucidrains Hi, when I train Imagen on multiple GPUs, a warning occurs: “UserWarning: Grad strides do not match bucket view strides.”
warning
data:image/s3,"s3://crabby-images/c583e/c583e1ed1f7588de78435b029436df4bed41637c" alt="image"
/home/tqd/anaconda3/envs/imagen_pyenv3/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
bucket_view.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at /opt/conda/conda-bld/pytorch_1656352660876/work/torch/csrc/distributed/c10d/reducer.cpp:312.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/home/tqd/anaconda3/envs/imagen_pyenv3/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
bucket_view.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at /opt/conda/conda-bld/pytorch_1656352660876/work/torch/csrc/distributed/c10d/reducer.cpp:312.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/home/tqd/anaconda3/envs/imagen_pyenv3/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
bucket_view.sizes() = [2048, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at /opt/conda/conda-bld/pytorch_1656352660876/work/torch/csrc/distributed/c10d/reducer.cpp:312.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
train script
def main(
config,
unet,
epoches,
text,
valid
):
# check config path
config_path = Path(config)
full_config_path = str(config_path.resolve())
assert config_path.exists(), f'config not found at {full_config_path}'
with open(config_path, 'r') as f:
config_data = json.loads(f.read())
assert 'checkpoint_path' in config_data, 'checkpoint path not found in config'
model_path = Path(config_data['checkpoint_path'])
full_model_path = str(model_path.resolve())
# setup imagen config
imagen_config_klass = ElucidatedImagenConfig if config_data['type'] == 'elucidated' else ImagenConfig
imagen = imagen_config_klass(**config_data['imagen']).create()
trainer = ImagenTrainer(
imagen = imagen,
dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks'),
**config_data['trainer']
)
# load pt
if model_path.exists():
loaded = torch.load(str(model_path))
version = safeget(loaded, 'version')
print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
trainer.load(model_path)
if torch.cuda.is_available():
trainer = trainer.cuda()
size = config_data['imagen']['image_sizes'][unet-1]
max_batch_size = config_data['max_batch_size'] if 'max_batch_size' in config_data else 1
channels = 'RGB'
if 'channels' in config_data['imagen']:
assert config_data['imagen']['channels'] > 0 and config_data['imagen']['channels'] < 5, 'Imagen only support 1 to 4 channels L, LA, RGB, RGBA'
if config_data['imagen']['channels'] == 4:
channels = 'RGBA' # Color with alpha
elif config_data['imagen']['channels'] == 2:
channels == 'LA' # Luminance (Greyscale) with alpha
elif config_data['imagen']['channels'] == 1:
channels = 'L' # Luminance (Greyscale)
assert 'batch_size' in config_data['dataset'], 'A batch_size is required in the config file'
# load and add train dataset and valid dataset
ds = load_dataset(config_data['dataset_name'])
print(ds)
# ds = load_dataset("wukong50k")
trainer.add_train_dataset(
ds = ds['train'],
collate_fn = MyCollator(
image_size = size,
image_label = config_data['image_label'],
text_label = config_data['text_label'],
url_label = config_data['url_label'],
name = imagen.text_encoder_name,
channels = channels
),
**config_data['dataset']
)
if not trainer.split_valid_from_train and valid != 0:
assert 'validation' in ds, 'There is no validation split in the dataset'
trainer.add_valid_dataset(
ds = ds['validation'],
collate_fn = Collator(
image_size = size,
image_label = config_data['image_label'],
text_label= config_data['text_label'],
url_label = config_data['url_label'],
name = imagen.text_encoder_name,
channels = channels
),
**config_data['dataset']
)
for i in range(epoches):
loss = trainer.train_step(unet_number = unet, max_batch_size = max_batch_size)
print(f'step {i}, loss: {loss}')
if valid != 0 and not (i % valid) and i > 0:
valid_loss = trainer.valid_step(unet_number = unet, max_batch_size = max_batch_size)
print(f'step {i}, valid loss: {valid_loss}')
if not (i % 322) and i > 0 and trainer.is_main and text is not None:
images = trainer.sample(texts = [text], batch_size = 1, return_pil_images = True, stop_at_unet_number = unet)
images[0].save(f'./sample-{i // 322}.png')
trainer.save(model_path)
if __name__ == "__main__":
    # Script entry point: train unet 1 for 400 steps, sampling the given
    # prompt periodically, with validation disabled (valid = 0).
    main(
        config = "./myconfig.json",
        unet = 1,
        epoches = 400,
        text = "a cute bird",
        valid = 0
    )
config
{
"type": "original",
"imagen": {
"video": false,
"timesteps": [1024, 512, 512],
"image_sizes": [64, 256, 1024],
"random_crop_sizes": [null, 64, 256],
"condition_on_text": true,
"cond_drop_prob": 0.1,
"text_encoder_name": "IDEA-CCNL/Randeng-T5-Char-700M-Chinese",
"unets": [
{
"dim": 512,
"dim_mults": [1, 2, 3, 4],
"num_resnet_blocks": 3,
"layer_attns": [false, true, true, true],
"layer_cross_attns": [false, true, true, true],
"attn_heads": 8
},
{
"dim": 128,
"dim_mults": [1, 2, 4, 8],
"num_resnet_blocks": [2, 4, 8, 8],
"layer_attns": [false, false, false, true],
"layer_cross_attns": [false, false, false, true],
"attn_heads": 8
},
{
"dim": 128,
"dim_mults": [1, 2, 4, 8],
"num_resnet_blocks": [2, 4, 8, 8],
"layer_attns": false,
"layer_cross_attns": [false, false, false, true],
"attn_heads": 8
}
]
},
"trainer": {
"lr": 1e-4,
"fp16": true,
"use_ema": false
},
"dataset_name": "conceptual_captions",
"dataset": {
"batch_size": 512,
"shuffle": true
},
"max_batch_size": 32,
"image_label": "img_path",
"url_label": null,
"text_label": "caption",
"checkpoint_path": "./imagen.pt"
}
I still have this problem. Have you solved it?
这是来自QQ邮箱的自动回复邮件。您好~您发送的邮件我已收到。谢谢您的邮件~