
Baal seems to ignore eval_batch_size, causing GPU out-of-memory errors

Open hugocool opened this issue 8 months ago • 6 comments

Describe the bug
When setting the batch size to 2 in Baal, it appears to use a batch size of 16 instead, which causes a CUDA out-of-memory error. Despite setting per_device_eval_batch_size and train_batch_size to 2 in TrainingArguments, the predict_on_dataset function seems to use a batch size of 16. I am letting Baal sort 1e6 (1 million) examples; when I run predict_on_dataset, I see the following in the logs:

0%| | 0/62500 [00:00<?, ?it/s] 0%| | 0/62500 [00:01<?, ?it/s]

Since 1,000,000 examples produce 62,500 batches, it is using a batch size of 16 instead of the specified 2. A batch size of 8 would also work (if I manually downsample the input dataframe to 8 inputs).
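The batch size can be inferred from the progress-bar total with simple arithmetic (the dataset size and batch count below come from the logs above; the snippet itself is just an illustration):

```python
# Infer the effective batch size from the tqdm progress-bar total.
n_examples = 1_000_000  # examples handed to predict_on_dataset
n_batches = 62_500      # total reported by the progress bar
effective_batch_size = n_examples // n_batches
print(effective_batch_size)  # -> 16, the Trainer default, not the requested 2
```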

To Reproduce

    from datasets import Dataset
    from transformers import TrainingArguments

    from baal.bayesian.dropout import patch_module
    from baal.transformers_trainer_wrapper import BaalTransformersTrainer

    # `model` and `tokenized_X` are assumed to be loaded/prepared beforehand.
    model = patch_module(model)  # enable MC-Dropout for stochastic predictions

    args = TrainingArguments(output_dir="/", per_device_eval_batch_size=2)
    args = args.set_dataloader(
        train_batch_size=2, eval_batch_size=2, auto_find_batch_size=False
    )

    trainer = BaalTransformersTrainer(
        model=model,
        args=args,
    )

    dataset = Dataset.from_pandas(tokenized_X)
    predictions = trainer.predict_on_dataset(dataset, iterations=30)

which gives:

"CUDA out of memory. Tried to allocate 1.41 GiB (GPU 0; 15.78 GiB total capacity; 14.49 GiB already allocated; 397.75 MiB free; 14.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
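As an aside, the max_split_size_mb knob that the error message suggests can be set via an environment variable before torch makes its first CUDA allocation. This only mitigates allocator fragmentation and does not fix the ignored batch size; the 128 MiB value below is an arbitrary example:

```python
import os

# Must be set before the first CUDA allocation; mitigates fragmentation only.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
print(os.environ["PYTORCH_CUDA_ALLOC_CONF"])  # -> max_split_size_mb:128
```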

Expected behavior
The predict_on_dataset function should respect the batch size specified in TrainingArguments and not cause a CUDA out-of-memory error.

Version (please complete the following information):

  • OS: Ubuntu 22
  • Python: 3.10.4
  • Baal:
    version: 1.9.1
    description: Library to enable Bayesian active learning in your research or labeling work.

dependencies

  • h5py >=3.4.0,<4.0.0
  • matplotlib >=3.4.3,<4.0.0
  • numpy >=1.21.2,<2.0.0
  • Pillow >=6.2.0
  • scikit-learn >=1.0.0,<2.0.0
  • scipy >=1.7.1,<2.0.0
  • structlog >=21.1.0,<22.0.0
  • torch >=1.6.0
  • torchmetrics >=0.9.3,<0.10.0
  • tqdm >=4.62.2,<5.0.0

Additional context
I am running this on AWS Batch on a p3 instance.

hugocool Oct 20 '23 13:10