
[BUG] Training multiple models with DeepSpeed

Open uygnef opened this issue 1 year ago • 11 comments

Describe the bug I am trying to train a txt2img model (both the text encoder and the UNet) using DeepSpeed. I have made some modifications to the code, but I am encountering an error. The error message suggests an issue with the backward function.

To Reproduce Steps to reproduce the behavior:

    ds_config = {
                  "train_batch_size": 24,
                  "gradient_accumulation_steps": 1,
                  "optimizer": {
                    "type": "Adam",
                    "params": {
                      "lr": 0.01,
                      "betas": [args.adam_beta1, args.adam_beta2],
                      "weight_decay": args.adam_weight_decay,
                      "eps": args.adam_epsilon
                    }
                  },
                  "zero_optimization": {
                    "stage": 3,
                  },
                #    "offload_param": {
                #     "device": "cpu",
                #     "pin_memory": True,
                #     "buffer_count": 5,
                #     "buffer_size": 1e8,
                #     "max_in_cpu": 1e9
                # },
                #   "offload_optimizer": {
                #     "device": "cpu",
                #     "pin_memory": True,
                #     "buffer_count": 4,
                #     "fast_init": False
                # }
                #   "hybrid_engine": {
                #     "enabled": True,
                #     "inference_tp_size": 8,
                #     "release_inference_cache": False,
                #     "pin_parameters": True,
                #     "tp_gather_partition_size": 8,
                # }
                }

    text_encoder, text_encoder_optimizer, _, _ = deepspeed.initialize(model=text_encoder, config_params=ds_config)
    unet, unet_optimizer, _, _ = deepspeed.initialize(model=unet, config_params=ds_config)

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                text_encoder_optimizer.zero_grad()
                unet_optimizer.zero_grad()
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                # text_encoder.backward(loss, retain_graph=True)
                loss.backward()
                # optimizer.step()
                text_encoder_optimizer.step()
                unet_optimizer.step()
                # text_encoder_optimizer.step()
                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

The error is:

  File "/mnt/dolphinfs/hdd_pool/docker/user/abc/src/diffuser/train_.py", line 582, in train
    unet_optimizer.step()
  File "/usr/local/conda/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/deepspeed/runtime/zero/stage3.py", line 1752, in step
    norm_groups = self._get_norm_groups()
  File "/usr/local/conda/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/deepspeed/runtime/zero/stage3.py", line 1568, in _get_norm_groups
    norm_groups.append(self.get_grad_norm_direct(self.averaged_gradients[i], self.fp16_groups[i]))
KeyError: 0

Expected behavior Training multiple models with DeepSpeed should work.

uygnef avatar May 06 '23 09:05 uygnef

ZeRO stage 1 worked with the following approach, but stages 2 and 3 did not.

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        text_encoder.train()
        train_loss = 0.0
        logger.info(f"train dataset shuffle seed: {epoch}")
        train_dataset.shuffle(seed=epoch)
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            # if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            #     if step % args.gradient_accumulation_steps == 0:
            #         progress_bar.update(1)
            #     continue

            text_encoder.zero_grad()
            unet.zero_grad()
            # Convert images to latent space
            latents = vae.encode(batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            text_encoder.backward(loss, retain_graph=True)
            unet.backward(loss)
            for n, lp in unet.named_parameters():
                hp_grad = safe_get_full_grad(lp)
                print(n, hp_grad)
            text_encoder.step()
            unet.step()
            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
            train_loss += avg_loss.item() / args.gradient_accumulation_steps

The error is:

    text_encoder.backward(loss, retain_graph=True)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/engine.py", line 1845, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1901, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 62, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 204, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 810, in reduce_partition_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 1258, in reduce_ready_partitions_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/home/hadoop-hmart-waimai-rank/.local/lib/python3.9/site-packages/deepspeed/runtime/zero/stage_1_and_2.py", line 853, in reduce_independent_p_g_buckets_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel())
AttributeError: 'DeepSpeedZeroOptimizer' object has no attribute 'ipg_index'

uygnef avatar May 06 '23 09:05 uygnef

@tjruwase Hi, could you please help me to identify the issue?

uygnef avatar May 08 '23 06:05 uygnef

@uygnef, can you please share the full code and command line to help repro? Thanks!

tjruwase avatar May 08 '23 18:05 tjruwase

@tjruwase Hi Tjruwase, I replaced my own model with the Hugging Face model, but all other parts of the code are the same. You can run this to reproduce the error. Thank you very much for your help.

git clone https://github.com/uygnef/deepspeed_test.git
cd deepspeed_test
sh run.sh

uygnef avatar May 09 '23 08:05 uygnef

@uygnef, thanks for sharing this. Unfortunately, I got the following error from here:

Traceback (most recent call last):
  File "main.py", line 112, in <module>
    main()
  File "main.py", line 106, in main
    train(args)
  File "/data/users/olruwase/deepspeed/repro/issue_3472/deepspeed_test/train_text_to_image.py", line 305, in train
    dataset = datasets.load_from_disk(
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1886, in load_from_disk
    raise FileNotFoundError(f"Directory {dataset_path} not found")
FileNotFoundError: Directory lambdalabs/pokemon-blip-captions not found

When I replaced it with the corresponding HF code, I got the following error:

Traceback (most recent call last):
  File "main.py", line 112, in <module>
    main()
  File "main.py", line 106, in main
    train(args)
  File "/data/users/olruwase/deepspeed/repro/issue_3472/deepspeed_test/train_text_to_image.py", line 310, in train
    dataset = load_dataset(
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1773, in load_dataset
    builder_instance = load_dataset_builder(
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1502, in load_dataset_builder
    dataset_module = dataset_module_factory(
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1219, in dataset_module_factory
    raise e1 from None
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1186, in dataset_module_factory
    raise e
  File "/opt/conda/lib/python3.8/site-packages/datasets/load.py", line 1160, in dataset_module_factory
    dataset_info = hf_api.dataset_info(
  File "/opt/conda/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 120, in _inner_fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/huggingface_hub/hf_api.py", line 1667, in dataset_info
    hf_raise_for_status(r)
  File "/opt/conda/lib/python3.8/site-packages/huggingface_hub/utils/_errors.py", line 301, in hf_raise_for_status
    raise HfHubHTTPError(str(e), response=response) from e
huggingface_hub.utils._errors.HfHubHTTPError: 504 Server Error: Gateway Time-out for url: https://huggingface.co/api/datasets/lambdalabs/pokemon-blip-captions

Any thoughts?

tjruwase avatar May 09 '23 13:05 tjruwase

@tjruwase Your code changes are correct. Can your computer connect to the Internet? I think it's a network problem. Is it the same error after trying a few more times?

uygnef avatar May 09 '23 14:05 uygnef

@uygnef, you are correct, the problem was due to the network and is now resolved. Can you confirm your GPU memory size? I am getting OOM on my V100-32GB. I hope reducing the batch size will not affect reproducibility.

tjruwase avatar May 09 '23 14:05 tjruwase

My device is 8 x A100 80 GB. Yes, I think a smaller batch size does not matter.

uygnef avatar May 09 '23 14:05 uygnef

@uygnef, you are correct, batch size does not matter. I have repro'd locally. Will update asap.

tjruwase avatar May 09 '23 17:05 tjruwase

@tjruwase hi tjruwase, I was wondering if there has been any progress made on the issue? Thank you for your time and assistance.

uygnef avatar Jun 01 '23 02:06 uygnef

@uygnef, apologies for the delay on this.

The fundamental problem is that the code breaks the gradient partitioning assumptions of ZeRO stage 2/3. In these stages, gradients are partitioned on the fly as they are created, and the assumption is that gradient creation and partitioning for a model is triggered by the backward of its wrapping engine. In this case, however, there are two models (text_encoder and unet) with separate engines, and the loss is computed from both models' forward passes. The gradients of both models are therefore created by a single backward pass over the shared loss (loss.backward()) rather than by their respective engine backward. Note that engine.backward() of one model will also trigger backward and gradient creation for the other model. My investigation so far suggests that supporting this behavior in ZeRO stage 2/3 is a non-trivial effort.
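
To make the shared-graph point concrete, here is a minimal sketch in plain PyTorch, without DeepSpeed. The two `nn.Linear` modules are hypothetical stand-ins for text_encoder and unet, not the real models. It only illustrates that one backward pass over a loss that depends on both modules creates gradients for both.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-ins for the two models (hypothetical, not the real text encoder / UNet).
encoder = nn.Linear(4, 4)
unet = nn.Linear(4, 1)

x = torch.randn(2, 4)
target = torch.randn(2, 1)

# The loss depends on both modules because the encoder output feeds the second module.
loss = nn.functional.mse_loss(unet(encoder(x)), target)

# A single backward pass walks the whole autograd graph ...
loss.backward()

# ... so gradients now exist for the parameters of both modules.
encoder_has_grads = all(p.grad is not None for p in encoder.parameters())
unet_has_grads = all(p.grad is not None for p in unet.parameters())
print(encoder_has_grads, unet_has_grads)
```

With two separate engines, each engine expects gradient creation for its own parameters to happen only during its own backward, which is the assumption described above.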

My colleagues suggested some alternatives, hopefully one is suitable for your use case:

  1. Use ZeRO stage 2 only for the memory-intensive model, and stage 1 for the other.
  2. Fuse the text encoder and unet structure, and then move the code at deepspeed_test/train_text_to_image.py (commit 440a8ad8d9678a8f31e2265d94feb7abe231442e of uygnef/deepspeed_test on github.com) into the forward function.
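
A rough sketch of what option 2 could look like is below. The class name `FusedTextToImage` and the two toy modules are hypothetical and exist only so the example runs on its own; the body of `forward` mirrors the text-encoding, prediction, and loss lines from the training loop posted earlier in this issue. This is an illustration under those assumptions, not a tested fix.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedTextToImage(nn.Module):
    """Hypothetical wrapper: both trainable models live in one module."""

    def __init__(self, text_encoder: nn.Module, unet: nn.Module):
        super().__init__()
        self.text_encoder = text_encoder
        self.unet = unet

    def forward(self, input_ids, noisy_latents, timesteps, target):
        # Same steps as in the training loop, moved inside forward().
        encoder_hidden_states = self.text_encoder(input_ids)[0]
        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")


class ToyTextEncoder(nn.Module):
    """Stand-in that returns a tuple, so indexing with [0] works."""

    def __init__(self, vocab=16, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, input_ids):
        return (self.embed(input_ids),)


class ToyUNet(nn.Module):
    """Stand-in that returns an object with a .sample attribute."""

    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states):
        # timesteps is accepted but unused in this toy model.
        cond = encoder_hidden_states.mean(dim=1)
        return SimpleNamespace(sample=self.proj(noisy_latents + cond))


torch.manual_seed(0)
model = FusedTextToImage(ToyTextEncoder(), ToyUNet())

input_ids = torch.randint(0, 16, (2, 5))
noisy_latents = torch.randn(2, 8)
timesteps = torch.randint(0, 1000, (2,))
target = torch.randn(2, 8)

# One forward, one loss, one backward for all parameters.
loss = model(input_ids, noisy_latents, timesteps, target)
loss.backward()
all_have_grads = all(p.grad is not None for p in model.parameters())
```

In the real script, the fused module would presumably be passed to a single deepspeed.initialize(model=..., config_params=ds_config) call, with the returned engine's backward() and step() replacing the separate text_encoder and unet calls, so that only one engine creates and partitions gradients.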

FYI, @minjiaz and @yaozhewei

tjruwase avatar Jun 02 '23 19:06 tjruwase

Closing for now after prescribing workarounds.

tjruwase avatar Aug 10 '23 10:08 tjruwase

Any idea about "AttributeError: 'DeepSpeedZeroOptimizer' object has no attribute 'ipg_index'"?

catqaq avatar Sep 18 '23 13:09 catqaq

@catqaq, please see explanation https://github.com/microsoft/DeepSpeed/issues/3472#issuecomment-1574202568

tjruwase avatar Sep 19 '23 13:09 tjruwase

Is there any reason I would see this when training a single model? It only occurs with fp16; bf16 and fp32 do not result in this error.

ethansmith2000 avatar Jan 07 '24 11:01 ethansmith2000