
Batch Size Issue in MAISI Generative Model Configuration


Dear Dong Yang (@dongyang0122),

I hope this message finds you well. Thank you in advance for your time and support.

I am currently working with the MAISI generative model and plan to accelerate training by increasing the batch size. However, I have run into an issue: although I modified the batch size in the configuration file, the DataLoader batch size remains 1.

Could you kindly advise on how to resolve this issue?

The log output is as below:

[log screenshot]

The log entries are produced by this code:

if local_rank == 0:
    logger.info(
        "[{0}] epoch {1}, iter {2}/{3}, loss: {4:.4f}, lr: {5:.12f}.".format(
            str(datetime.now())[:19], epoch + 1, _iter, len(train_loader), loss.item(), current_lr
        )
    )

Note that the number of iterations equals the length of train_loader, and the training set contains 1000 samples. In my understanding, a larger batch size should reduce the length of train_loader. However, the length of train_loader is still 1000 (the size of the training set), which suggests the batch size is effectively 1.
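
As a quick sanity check, here is a minimal sketch of the relationship I would expect between batch size and loader length, using a plain torch DataLoader and dummy data (illustrative only; ThreadDataLoader subclasses torch's DataLoader, so I believe len() behaves the same way):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative only: 1000 dummy samples standing in for the 1000 training files.
dummy_ds = TensorDataset(torch.zeros(1000, 1))

for bs in (1, 4, 8):
    loader = DataLoader(dummy_ds, batch_size=bs, shuffle=True)
    # len(loader) == ceil(1000 / bs), so a larger batch size should mean fewer iterations.
    print(bs, len(loader))  # prints: 1 1000, 4 250, 8 125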

Additionally, the corresponding data loader code is in scripts.diff_model_train.py:

def prepare_data(
    train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
) -> ThreadDataLoader:
    """
    Prepare training data.

    Args:
        train_files (list): List of training files.
        device (torch.device): Device to use for training.
        cache_rate (float): Cache rate for dataset.
        num_workers (int): Number of workers for data loading.
        batch_size (int): Mini-batch size.

    Returns:
        ThreadDataLoader: Data loader for training.
    """
    train_transforms = Compose(
        [
            monai.transforms.LoadImaged(keys=["image"]),
            monai.transforms.EnsureChannelFirstd(keys=["image"]),
            monai.transforms.Lambdad(
                keys="top_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["top_region_index"])
            ),
            monai.transforms.Lambdad(
                keys="bottom_region_index", func=lambda x: torch.FloatTensor(json.load(open(x))["bottom_region_index"])
            ),
            monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(json.load(open(x))["spacing"])),
            monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2),
            monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2),
            monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2),
        ]
    )

    train_ds = monai.data.CacheDataset(
        data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
    )
    return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
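
For anyone seeing the same symptom, one quick check is to inspect the loader returned by prepare_data (a sketch, assuming ThreadDataLoader exposes the standard DataLoader attributes, which it should since it subclasses torch.utils.data.DataLoader):

# Hypothetical debugging snippet, not part of the original script:
# inspect what the loader was actually constructed with.
print("loader batch_size:", train_loader.batch_size)   # reports 1 here despite the config change
print("iterations per epoch:", len(train_loader))      # 1000, i.e. one sample per batch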

gexinh (Oct 06 '24)

I found the problem:

In scripts.diff_model_train.py, line 51, the definition is as follows:

def prepare_data(
    train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
) -> ThreadDataLoader:

However, on line 359 the function call omits the num_workers argument, so the value intended for batch_size is bound positionally to num_workers, and batch_size keeps its default of 1.

    train_loader = prepare_data(
        train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
    )
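
A minimal sketch of one possible fix (not necessarily the exact change that ends up in the repository) is to pass the trailing arguments by keyword, so the batch size can no longer be bound to num_workers:

# Sketch of one way to avoid the positional mismatch: use keyword arguments
# for everything after the positional train_files/device pair.
train_loader = prepare_data(
    train_files,
    device,
    cache_rate=args.diffusion_unet_train["cache_rate"],
    batch_size=args.diffusion_unet_train["batch_size"],
)

With keyword arguments, num_workers simply keeps its default of 2 and the configured batch size reaches the ThreadDataLoader.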

gexinh (Oct 09 '24)

Hi @gexinh, thanks for reporting this. It is fixed in this PR: https://github.com/Project-MONAI/tutorials/pull/1857.

KumoLiu (Oct 09 '24)