Baal seems to ignore eval_batch_size, causing GPU memory issues
Describe the bug
When setting the batch_size to 2 in Baal, it appears to use a batch_size of 16 instead, which causes a CUDA out of memory error. Despite setting per_device_eval_batch_size and train_batch_size to 2 in TrainingArguments, the predict_on_dataset function seems to use a batch_size of 16. I am letting Baal sort 1e6 (1 million) examples, and when I run the predict_on_dataset function I see the following in the logs:
0%| | 0/62500 [00:00<?, ?it/s] 0%| | 0/62500 [00:01<?, ?it/s]
meaning it is using a batch_size of 16 instead of the specified 2. A batch size of 8 would also work (if I manually downsample the input dataframe to 8 inputs).
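For reference, a quick arithmetic check of what that progress bar implies (my own reading of the 62,500-step total over the 1e6 examples):

# The tqdm total of 62,500 steps over 1,000,000 examples implies an
# effective batch size of 16 rather than the requested 2.
num_examples = 1_000_000
num_steps = 62_500
print(num_examples / num_steps)  # 16.0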
To Reproduce
from datasets import Dataset
from transformers import TrainingArguments
from baal.bayesian.dropout import patch_module
from baal.transformers_trainer_wrapper import BaalTransformersTrainer

# model is the (already loaded) HF model; tokenized_X is the tokenized pandas dataframe
model = patch_module(model)

args = TrainingArguments(output_dir="/", per_device_eval_batch_size=2)
args = args.set_dataloader(
    train_batch_size=2, eval_batch_size=2, auto_find_batch_size=False
)

trainer = BaalTransformersTrainer(
    model=model,
    args=args,
)

dataset = Dataset.from_pandas(tokenized_X)
predictions = trainer.predict_on_dataset(dataset, iterations=30)
which gives:
"CUDA out of memory. Tried to allocate 1.41 GiB (GPU 0; 15.78 GiB total capacity; 14.49 GiB already allocated; 397.75 MiB free; 14.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
Expected behavior
The predict_on_dataset function should respect the batch_size specified in TrainingArguments and not cause a CUDA out of memory error.
Version (please complete the following information):
- OS: Ubuntu 22
- Python: 3.10.4
- Baal: 1.9.1
  description: Library to enable Bayesian active learning in your research or labeling work.
  dependencies:
- h5py >=3.4.0,<4.0.0
- matplotlib >=3.4.3,<4.0.0
- numpy >=1.21.2,<2.0.0
- Pillow >=6.2.0
- scikit-learn >=1.0.0,<2.0.0
- scipy >=1.7.1,<2.0.0
- structlog >=21.1.0,<22.0.0
- torch >=1.6.0
- torchmetrics >=0.9.3,<0.10.0
- tqdm >=4.62.2,<5.0.0
Additional context
I am running this on AWS Batch on a p3 instance.
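As a stopgap on my side (not a fix), something along these lines should avoid the OOM, based on the observation above that 8 or fewer examples fit in memory; this is a sketch that assumes predict_on_dataset returns a numpy array whose first axis indexes the examples:

import numpy as np

chunk_size = 8  # the largest size observed to fit in GPU memory above
outputs = []
for start in range(0, len(dataset), chunk_size):
    # Dataset.select keeps only the rows belonging to this chunk
    chunk = dataset.select(range(start, min(start + chunk_size, len(dataset))))
    outputs.append(trainer.predict_on_dataset(chunk, iterations=30))
predictions = np.concatenate(outputs, axis=0)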