
“Stop Halving My Batch!” · Default back-off 0.5 → 0.9

Open AmitMY opened this issue 5 months ago • 9 comments

What does this PR do?

1 / Problem

  • find_executable_batch_size() currently halves the batch size after every OOM (typical usage sketched below). Example: request 128 when the real limit is 100 → the old loop jumps straight to 64, losing 36% of the achievable throughput.

  • On heterogeneous clusters (T4 ↔ V100 ↔ A100 ↔ H100) users can’t predict the limit in advance.

  • The escape hatch, reduce_batch_size_fn, is practically unusable:

    • Not exposed by 🤗 Transformers Trainer, so you can’t pass a custom function.
    • Signature is fn() → int — it receives no context (not even the failing batch size), so writing a smart policy is impossible.
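
For context, typical usage looks roughly like this (a simplified sketch following the pattern in the accelerate docs; the model/optimizer setup is elided):

from accelerate import Accelerator
from accelerate.utils import find_executable_batch_size

def training_function(args):
    accelerator = Accelerator()

    @find_executable_batch_size(starting_batch_size=args.batch_size)
    def inner_training_loop(batch_size):
        nonlocal accelerator  # reset state between retries
        accelerator.free_memory()
        # ... build the model, optimizer and dataloaders for `batch_size`,
        # call accelerator.prepare(...) and run the training loop ...
        # If a CUDA OOM is raised here, the decorator catches it and calls
        # this function again with a smaller batch_size (halved today).

    inner_training_loop()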

2 / Fix (this PR)

  • Replace the ×0.5 fallback with a gentler ×0.9 loop: keep shaving 10% off until it fits (a simplified sketch of the loop follows the table below).
  • No public-API change; only an internal constant and the log string change.
| GPU  | VRAM  | Requested BS | Real limit (fits) | Old algo (×0.5) | New algo (×0.9) | Waste ↓ (pp) |
|------|-------|--------------|-------------------|-----------------|-----------------|--------------|
| H100 | 80 GB | 128          | 100               | 64              | 92              | 28           |
| A100 | 40 GB | 128          | 50                | 32              | 46              | 28           |
| V100 | 32 GB | 128          | 40                | 32              | 36              | 10           |
| T4   | 16 GB | 128          | 20                | 16              | 19              | 15           |
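
Conceptually, the change amounts to the following (a standalone sketch of the retry loop, not the actual accelerate internals, which also handle memory cleanup and non-OOM errors):

def run_with_backoff(train_step, starting_batch_size=128, decay=0.9):
    # Shrink the batch by ~10% per OOM instead of halving it,
    # until a trial fits or the batch size reaches 0.
    batch_size = starting_batch_size
    while batch_size > 0:
        try:
            return train_step(batch_size)
        except RuntimeError as exc:  # torch's CUDA OOM is a RuntimeError subclass
            if "out of memory" not in str(exc).lower():
                raise
            batch_size = int(batch_size * decay)  # previously: batch_size //= 2
    raise RuntimeError("No executable batch size found: even batch_size=1 OOMs.")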

3 / Why not rely on reduce_batch_size_fn?

| Issue     | Impact                                                   |
|-----------|----------------------------------------------------------|
| Invisible | Transformers never forwards a user-supplied callable.    |
| Blind     | The callable gets zero information about what just failed. |

4 / Future hook proposal

@accelerate.register_callable("reduce_batch_size_fn")
def my_decay(prev_bs: int, oom_exc: BaseException) -> int:
    return max(1, int(prev_bs * 0.95))
  • Gives the function the actual failing batch size + any OOM detail.
  • Lets libraries or users plug in binary-search, throughput models, etc., without subclassing Accelerator.
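
As an illustration only (register_callable is the hook proposed above, not an existing API), such a policy could for instance combine a gentle decay with rounding to a multiple of 8:

def decay_to_multiple_of_8(prev_bs: int, oom_exc: BaseException) -> int:
    # Hypothetical policy: shave ~10% off the failing batch size, then round
    # down to a multiple of 8 (often friendlier to tensor cores), while
    # guaranteeing a strict decrease so the search terminates.
    candidate = int(prev_bs * 0.9)
    if candidate >= 8:
        candidate -= candidate % 8
    return min(candidate, prev_bs - 1)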

5 / Risk

  • Worst case, a few extra trials before convergence when the real limit is close to the requested size, and on the order of a dozen when it is far below it (counted below for the table's scenarios); this remains negligible against the length of a training run.
  • Behaviour when even bs=1 OOMs is unchanged (fast failure).
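
Counting attempts for the table above (a standalone check that uses the same truncating decay as the numbers in section 2):

def attempts_until_fit(start, limit, decay):
    # Number of batch sizes tried, including the final one that fits.
    bs, tries = start, 1
    while bs > limit:
        bs = int(bs * decay)
        tries += 1
    return tries

print(attempts_until_fit(128, 100, 0.5), attempts_until_fit(128, 100, 0.9))  # H100 row: 2 vs 4
print(attempts_until_fit(128, 20, 0.5), attempts_until_fit(128, 20, 0.9))    # T4 row: 4 vs 17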

6 / Bottom line

Changing one constant rescues 15–45% of otherwise wasted GPU capacity for users on heterogeneous clusters today, while paving the way for a clean plug-in story tomorrow.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

  • Core parts of the library: @BenjaminBossan @SunMarc @zach-huggingface

AmitMY avatar Jun 22 '25 13:06 AmitMY

Indeed, that might be a better default, even if it might take longer to find the wanted batch size. It would also be best to allow users to pass their own function, either with the hook you proposed or by fixing reduce_batch_size_fn in the future.

SunMarc avatar Jun 23 '25 10:06 SunMarc

This sounds okay, however don't we want to enforce some limits on the batch size? I.e., require it to be a power of 2, or at least divisible by 2. Having an arbitrary batch size isn't very good from a performance standpoint, and GPUs in general might struggle with this on some tasks.

True, I had that concern as well, but looking at different posts, I'm not sure this is always the case. Also, the user can already specify a batch size that is not a power of 2, even if that is not the general practice. https://wandb.ai/datenzauberai/Batch-Size-Testing/reports/Do-Batch-Sizes-Actually-Need-To-Be-Powers-of-2---VmlldzoyMDkwNDQx#:~:text=Some%20kinds%20of%20hardware%20achieve,by%20one%20order%20of%20magnitude.

SunMarc avatar Jun 23 '25 13:06 SunMarc

I'm on board with the discussion, and agree that further considerations might be necessary. However, I believe this function is meant to prevent OOMs, not to find the fastest training strategy.

I would be all for:

  1. find_executable_batch_size changing the default strategy + allowing the user to register a custom backoff
  2. a new find_fast_batch_size that could look at hardware-specific heuristics and adjust the batch size accordingly, or actually test different batch sizes and decide on one for the entire run based on on-the-fly benchmarking (increase the batch size, change to a multiple of 8 on A100, etc.)
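
A very rough sketch of what (2) could look like, purely illustrative (find_fast_batch_size and run_step are made-up names, and real benchmarking would need warm-up and several timed steps):

import time

def find_fast_batch_size(run_step, candidates=(16, 32, 48, 64, 96, 128)):
    # Time one training step per candidate batch size (skipping candidates
    # that OOM) and keep the one with the best samples/second.
    best_bs, best_throughput = None, 0.0
    for bs in candidates:
        try:
            start = time.perf_counter()
            run_step(bs)
            throughput = bs / (time.perf_counter() - start)
        except RuntimeError as exc:
            if "out of memory" in str(exc).lower():
                continue  # too big for this GPU, try the next candidate
            raise
        if throughput > best_throughput:
            best_bs, best_throughput = bs, throughput
    return best_bs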

As they say in that wandb blog:

Measuring the actual effect on training speed, accuracy and memory consumption when choosing a batch size should be preferred instead of focusing on powers of 2.

I believe that both could speed up initial model experimentation, but that for a real production training run, both should probably be avoided.

AmitMY avatar Jun 23 '25 14:06 AmitMY

I do agree with what you say; however, I think the main purpose of this is to have a sensible default. Given the research I've done and the link Marc sent, I haven't found strong evidence for power-of-2/even batch sizes. With this, I think it's okay to merge and keep it as you suggest. Would you also be interested in creating the mechanism for custom heuristics/functions, since you started this already?

S1ro1 avatar Jun 23 '25 14:06 S1ro1

Can you also fix the tests since the default changed? @AmitMY

SunMarc avatar Jun 23 '25 14:06 SunMarc

Fixed the tests that seem relevant. One test failed, seemingly because of an old torch version, and it works on my machine. Can you approve again so we can check that it passes on GitHub Actions?

AmitMY avatar Jun 23 '25 15:06 AmitMY

@bot /style

S1ro1 avatar Jun 23 '25 16:06 S1ro1

Style fixes have been applied. View the workflow run here.

github-actions[bot] avatar Jun 23 '25 16:06 github-actions[bot]

I'll recreate the PR, as the GitHub CI is broken on this one.

SunMarc avatar Jul 16 '25 10:07 SunMarc