Variable batch size and LR scheduler

bm-synth opened this issue 1 year ago • 6 comments

Background and rationale

In many use cases, particularly LLMs, one is faced with inputs (sentences) of variable lengths. A common practice is to pack batches by token count rather than by a fixed batch size, i.e. by putting together sentences whose given metric (e.g. sequence length) adds up to a user-provided value. As an example, in Attention Is All You Need, section 5.1:

Sentence pairs were batched together by approximate sequence length. Each training batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000 target tokens.

Dynamic batch sizes have been requested in DeepSpeed issue 1051, DeepSpeed issue 3455, PyTorch Lightning issue 16914 and huggingface issue 2647, and are already available in many libraries, e.g. NVIDIA Triton and Meta FairSeq (implementation here).

The immediate use case is maximizing GPU utilization. Moreover, this is particularly relevant for curriculum learning, where a BxTxE-shaped (Batch x Time x Embedding) input should ideally have a high B and low T in the early curriculum steps (many short sentences packed together in a batch), and a low B and high T in the late steps (few long sentences in the batch). A dynamic T is already supported by DeepSpeed, e.g. in the documentation for pipeline parallelism's reset_activation_shape():

For curriculum learning that changes the seqlen of each sample, we need to call this whenever the seqlen is going to change.

However, a dynamic B is not supported. Changing B requires a corresponding increase or decrease of the learning rate. Scaling the LR with the batch size has been applied before, and the two most common LR scaling rules are:

  1. Linear Scaling Rule: "When the minibatch size is multiplied by k, multiply the learning rate by k", from Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al.
  2. Square Root Scaling Rule: "when multiplying the batch size by k, multiply the learning rate by √k, to keep the variance in the gradient expectation constant", from One weird trick for parallelizing convolutional neural networks, A. Krizhevsky.
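Both rules multiply a reference LR by a function of k = B / B_ref, where B_ref is the batch size the reference LR was tuned for. A minimal sketch follows; scaled_lr and its method names are hypothetical and not the PR's own scale_lr.

```python
import math


def scaled_lr(base_lr: float, base_batch_size: int, batch_size: int, method: str = "linear") -> float:
    """Scale a reference LR (tuned for base_batch_size) to a new batch size.

    "linear": lr * k (Goyal et al.); "sqrt": lr * sqrt(k) (Krizhevsky),
    with k = batch_size / base_batch_size.
    """
    k = batch_size / base_batch_size
    if method == "linear":
        return base_lr * k
    if method == "sqrt":
        return base_lr * math.sqrt(k)
    raise ValueError(f"unknown scaling method: {method!r}")
```

For the same k, the square-root rule changes the LR much less aggressively than the linear rule, which is why the two can behave quite differently when batch sizes vary widely within one run.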

In practice, the user picks the total token count per batch as the metric that drives batching, instead of batching by sentence count. At runtime, the variable batch size is computed and the LR is adjusted accordingly, relative to the reference LR and batch size provided in the config.
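A minimal sketch of token-budget packing, assuming a simple greedy strategy over samples in their given order (sort by length first to group similar lengths). It counts unpadded tokens, as in the 30-token example of this issue. pack_by_token_budget is a hypothetical helper, not the PR's batch_by_size.

```python
from typing import List, Sequence


def pack_by_token_budget(seq_lens: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Greedily group sample ids so that each batch holds at most max_tokens tokens.

    Samples are visited in the given order. A single sample longer than
    max_tokens is placed alone in its own batch.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for sample_id, n in enumerate(seq_lens):
        # Close the current batch if adding this sample would exceed the budget.
        if current and current_tokens + n > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(sample_id)
        current_tokens += n
    if current:
        batches.append(current)
    return batches
```

For instance, ten 3-token sentences followed by four 7-token sentences with max_tokens=30 yield batches of 10 and 4 samples.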

Illustration of dynamic batch size, sequence length and LR

Imagine we picked a limit of 30 tokens per batch, and have set a reference lr=1e-3 for a train_batch_size=2 (in the DeepSpeed config). The batching algorithm for curriculum learning may pack the data into batches of short sentences (left) at the early stages, and batches of long sentences (right) at later stages, e.g.:

[Figure: dynamic batch size and LR]

Above, we collected samples until the batch held at most 30 tokens. The resulting batch sizes (number of samples) are 10 and 4 for the left and right examples, respectively. Using the linear scaling rule, the LRs for those batches become 5e-3 and 2e-3.
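The arithmetic behind those numbers, with B_ref = 2 and η_ref = 10^-3 taken from the config. The square-root values are added only for comparison and are not part of the original example:

```latex
\begin{aligned}
\eta_{\mathrm{lin}}(B) &= \eta_{\mathrm{ref}} \cdot B / B_{\mathrm{ref}} \\
\eta_{\mathrm{lin}}(10) &= 10^{-3} \cdot 10/2 = 5 \times 10^{-3} \\
\eta_{\mathrm{lin}}(4) &= 10^{-3} \cdot 4/2 = 2 \times 10^{-3} \\
\eta_{\mathrm{sqrt}}(B) &= \eta_{\mathrm{ref}} \cdot \sqrt{B / B_{\mathrm{ref}}} \\
\eta_{\mathrm{sqrt}}(10) &= 10^{-3} \sqrt{5} \approx 2.24 \times 10^{-3} \\
\eta_{\mathrm{sqrt}}(4) &= 10^{-3} \sqrt{2} \approx 1.41 \times 10^{-3}
\end{aligned}
```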

Pipeline parallelism

Pipeline parallelism requires the same batch size and the same sequence length across all micro-batches in a batch, as the activation sizes must be fixed between gradient accumulation steps. These may change between batches, as long as engine.reset_activation_shape() is called so that the new shapes are communicated on the first gradient accumulation step of the batch. Enforcing the same BxTxE across all micro-batches of a batch may lead to smaller micro-batches. As an example, below is an illustration of the batching of the same dataset on 2 nodes with 2 gradient accumulation steps (i.e. 4 micro-batches), when preparing data for regular DDP (left) and for pipeline parallelism (right):
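A minimal sketch of how equally shaped micro-batches could be formed. It assumes samples already sorted by length, a padded token budget per micro-batch (B·T ≤ max_tokens), and num_microbatches = data-parallel ranks × gradient accumulation steps. pipeline_batches is hypothetical and is not the PR's batch_by_size.

```python
from typing import List, Sequence, Tuple


def pipeline_batches(
    seq_lens: Sequence[int], max_tokens: int, num_microbatches: int
) -> List[Tuple[List[List[int]], int, int]]:
    """Group sample ids into batches of equally shaped micro-batches.

    Every micro-batch in a batch holds B samples and is padded to T, the
    longest sample in that batch, so all micro-batches share shape B x T.
    B is grown greedily while B * T stays within max_tokens. Returns a list
    of (micro_batches, B, T) tuples; trailing samples that cannot fill
    num_microbatches micro-batches are dropped.
    """
    batches = []
    start = 0
    n = len(seq_lens)
    while start + num_microbatches <= n:
        b = 1  # an over-long sample still gets B = 1
        while start + num_microbatches * (b + 1) <= n:
            t = max(seq_lens[start:start + num_microbatches * (b + 1)])
            if (b + 1) * t > max_tokens:
                break
            b += 1
        end = start + num_microbatches * b
        t = max(seq_lens[start:end])
        micro = [list(range(start + j * b, start + (j + 1) * b)) for j in range(num_microbatches)]
        batches.append((micro, b, t))
        start = end
    return batches
```

Padding every micro-batch to the batch-wide T is what forces the smaller micro-batches: one long sample raises T for all micro-batches in its batch.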

[Figure: dynamic batch size and LR, micro-batching for DDP (left) and pipeline parallelism (right)]

We can see that the pipeline use case (right) has the same BxTxE shape across all 4 micro-batches in the same batch, and, in order to respect that, it packs fewer samples per batch than the standard use case (left-hand side).

Attention Head

For an input of size BxTxE, the attention matrix has shape TxT if the same mask applies to all samples (e.g. samples of the same length), or BxTxT if each sample needs its own mask (when samples have different lengths, as in the dataset above). This 3D attention matrix can be illustrated for DDP micro-batch 1 (top-left of the picture above, 4 sentences) as:

[Figure: dynamic batch size and LR, BxTxT attention matrix]

Note the memory savings: the attention head has a size of BxTxT, i.e. a linear memory dependency on the batch size B and quadratic memory dependency on the largest sequence length T in the (micro-) batch. Thus, supporting a dynamic size T allows for an increase of B.
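A per-sample padding mask makes the BxTxT shape concrete. This stdlib-only sketch builds a padding-only (non-causal) mask that marks which query/key pairs are valid for each sample, with T taken as the longest sequence in the micro-batch. padding_attention_mask is a hypothetical helper; a real model would build the same mask as a tensor.

```python
from typing import List, Sequence


def padding_attention_mask(seq_lens: Sequence[int]) -> List[List[List[bool]]]:
    """Build a B x T x T boolean mask (True = may attend) for a padded batch.

    T is the longest sequence in the (non-empty) batch. For sample b,
    query i and key j are both valid only if i, j < seq_lens[b];
    all other positions are padding and are masked out.
    """
    t = max(seq_lens)
    return [
        [[i < n and j < n for j in range(t)] for i in range(t)]
        for n in seq_lens
    ]
```

For lengths [2, 3], T = 3 and the mask holds 2 x 3 x 3 = 18 entries, of which 4 + 9 = 13 are valid. The masked entries are the memory overhead of mixing lengths in one micro-batch.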

PR overview

This PR implements dynamic batching and LR scaling. The required dataloader and LR scheduler can be retrieved by calling get_dataloader_and_lr_scheduler_for_variable_batch_size. A brief explanation of that function follows:

  • The logic behind the LR scaling algorithms is in scale_lr;
  • The partitioning of samples into batches is done by batch_by_size;
  • For pipeline parallelism, all micro-batches in a pipeline pass must have the same activation shapes. This is enabled by setting the following parameters to True:
    • required_microbatches_of_same_sizes, which forces the B dimension to be the same across all gradient accumulation steps of all dataloaders in a batch;
    • required_microbatches_of_same_lengths, which forces the T dimension to be the same across all gradient accumulation steps. It works by calling the user-provided sample_padding_fn(sentence, len), which pads a given sentence to the given length;
  • batch_by_size returns microbatch_sample_ids (the list of sample ids per micro-batch), batch_sizes (the effective batch sizes), and batch_max_seqlens (the longest sequence across all micro-batches in a batch);
  • dataloader_for_variable_batch_size relies on microbatch_sample_ids and will iterate/collate/pad samples for every batch and return a dataloader that iterates over the final (variable-size) batches;
  • lr_scheduler_for_variable_batch_size relies on batch_sizes to compute the learning rate for each effective batch: it takes the reference batch size and LR from the config file and scales the LR according to the size of each effective batch and the chosen scaling rule (linear, square root, etc.).
    • Note that the returned lr_scheduler accepts either:
      1. a user-provided Optimizer, whose learning rates (in its param groups) are scaled at every batch, or
      2. a user-defined LRScheduler, in which case the learning rate is first obtained from the scheduler and then scaled accordingly.
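A minimal sketch of the first option (scaling an optimizer's param-group LRs per effective batch), assuming only that the optimizer exposes a param_groups list of dicts with an "lr" key, as torch.optim.Optimizer does. VariableBatchLRScaler is a hypothetical class, not the object returned by lr_scheduler_for_variable_batch_size.

```python
import math
from typing import Sequence


class VariableBatchLRScaler:
    """Rescale an optimizer's learning rates for each effective batch.

    The LRs found in optimizer.param_groups at construction are treated as
    the reference LRs for base_batch_size.
    """

    def __init__(self, optimizer, batch_sizes: Sequence[int], base_batch_size: int, method: str = "linear"):
        if method not in ("linear", "sqrt"):
            raise ValueError(f"unknown scaling method: {method!r}")
        self.optimizer = optimizer
        self.batch_sizes = list(batch_sizes)
        self.base_batch_size = base_batch_size
        self.method = method
        # Capture reference LRs once so repeated scaling never compounds.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.batch_idx = 0
        self._apply()

    def _factor(self) -> float:
        k = self.batch_sizes[self.batch_idx] / self.base_batch_size
        return k if self.method == "linear" else math.sqrt(k)

    def _apply(self) -> None:
        factor = self._factor()
        for group, base_lr in zip(self.optimizer.param_groups, self.base_lrs):
            group["lr"] = base_lr * factor

    def step(self) -> None:
        """Move to the next effective batch; wraps around so the same batching repeats each epoch."""
        self.batch_idx = (self.batch_idx + 1) % len(self.batch_sizes)
        self._apply()
```

For the second option, one could presumably refresh base_lrs from the wrapped scheduler before scaling at each step; that variant is not shown here.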

Example

An example of the use case with and without pipelining is provided in deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr_example.py. The example shows an attention head with a variable-sized BxTxT attention per batch, followed by a fixed-size feed-forward network. These are the main blocks of a Large Language Model. The feed-forward (linear) layer that follows the attention head requires a constant input size M, equal to the length of the longest sentence in the whole dataset, so the output of the attention must be padded (see the comment "feedforward: needs to convert BxTxE to BxMxE by padding extra tokens" in the code).
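A minimal stdlib-only sketch of that BxTxE to BxMxE padding on nested lists. pad_time_dim is a hypothetical helper and is not taken from the example file.

```python
from typing import List

Tensor3D = List[List[List[float]]]


def pad_time_dim(x: Tensor3D, max_len: int, pad_value: float = 0.0) -> Tensor3D:
    """Pad a B x T x E nested list along the T (time) axis to B x M x E, M = max_len.

    Each sample must contain at least one token (so its embedding size E is
    known) and at most max_len tokens. The input is not modified.
    """
    padded: Tensor3D = []
    for sample in x:  # sample has shape T x E
        if not sample or len(sample) > max_len:
            raise ValueError(f"sample length {len(sample)} not in [1, {max_len}]")
        embed_dim = len(sample[0])
        padding = [[pad_value] * embed_dim for _ in range(max_len - len(sample))]
        padded.append([list(token) for token in sample] + padding)
    return padded
```

With PyTorch tensors, the same padding of a BxTxE tensor x is torch.nn.functional.pad(x, (0, 0, 0, M - T)), since F.pad lists padding amounts starting from the last dimension.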

bm-synth, Mar 08 '24 00:03

@bm-synth This is a completely new technique. For this kind of contribution, based on the guideline (https://github.com/microsoft/DeepSpeed/blob/master/CONTRIBUTING.md#new-feature-contribution-guidelines) we would need to first judge the value based on some formal evaluation (ideally an arxiv paper).

conglongli, Mar 09 '24 22:03

@conglongli a question related to LR scheduling: the LR scheduler documentation says:

if the schedule is supposed to execute at every training step, then the user can pass the scheduler to deepspeed.initialize when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore.

if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly.

However, I believe this is no longer the case. In deepspeed/runtime/engine.py, the function step(self, lr_kwargs=None) contains:

        # Update the model when we reach gradient accumulation boundaries
        if self.is_gradient_accumulation_boundary():
            [...]
            if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()
                    and self.quantizer.any_precision_switch()):
                self._take_model_step(lr_kwargs, self.block_eigenvalue)
            else:
                self._take_model_step(lr_kwargs)

So lr_scheduler.step() is not called at every iteration/step/micro-batch, but only at the end of every batch (at the gradient accumulation boundary). Can you confirm or comment, please?
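Assuming the behavior described in the question, i.e. that the engine steps the scheduler only at gradient accumulation boundaries, the micro-steps after which the LR would change can be sketched as follows. scheduler_step_indices is hypothetical; it models the control flow only, not DeepSpeed's code.

```python
from typing import List


def scheduler_step_indices(num_micro_steps: int, gradient_accumulation_steps: int) -> List[int]:
    """0-based micro-step indices after which lr_scheduler.step() would run,
    assuming the scheduler is stepped only at gradient-accumulation boundaries."""
    return [
        i for i in range(num_micro_steps)
        if (i + 1) % gradient_accumulation_steps == 0
    ]
```

With 8 micro-batches and gradient_accumulation_steps=4, the scheduler would step twice, after micro-steps 3 and 7, so a per-batch LR table needs one entry per effective batch rather than per micro-batch.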

bm-synth, Mar 13 '24 14:03

@conglongli @mrwyattii I added some information to this PR in line with the new contributions page you sent. The logic for this PR is done, and the example in *_example.py works. However, I am struggling to integrate this with your curriculum module. This curriculum allows for several metrics, and I don't understand how the code iterates over and groups samples at every curriculum step when there are multiple metrics.

  • As an example: if "seq_len" is a metric with 10 buckets (values from 0 to 1000, packing every interval of size 100), and "rarity" is a metric with 20 buckets (packing 5% at every step), how do you group the samples in the curriculum?
  • Do you have a simple example of this running? I saw the one in tests/unit/runtime/test_data_efficiency.py, but it has "index_to_sample_path": "dummy" and only a single metric.

Finally, this new Data Efficiency curriculum learning is based on files that must be produced beforehand by the DataAnalyzer, and it requires all nodes in the network to access the same shared storage, which is not ideal. The legacy curriculum learning module did not require user-defined paths with files, and does not perform a DataAnalyzer-style map-reduce:

  • Where (in the code) does it get the samples ordered by ascending seqlen, as required for curriculum learning?
  • Will it be deprecated for good?

bm-synth, Mar 15 '24 17:03

great work and waiting for this.

npuichigo, Apr 10 '24 09:04

Thank you @npuichigo. For the time being, you can use this as in the example (first initialize DeepSpeed to get the DeepSpeed engine, then call get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed to get the data loader and LR scheduler). Ultimately, though, the goal is to have deepspeed.initialize return the correct data loader and LR scheduler for the variable batch size use case. I'll implement this once I hear from @conglongli @mrwyattii @loadams or another dev.

bm-synth, Apr 10 '24 09:04

@bm-synth First of all, you still haven't provided any evidence that "variable batch size and LR scheduler help improve model quality". Anyway, I understand some users just want to do it, so we can accept this PR.

Regarding your question about curriculum learning: (1) Handling of multiple metrics is at https://github.com/microsoft/DeepSpeed/blob/aaaf8bc5e07535e263f83733f8905400bf6f5aca/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py#L184-L201. (2) Legacy curriculum learning does not reorder data; it works by truncating data to different lengths in user-side code.
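A minimal sketch of truncation-based (legacy-style) curriculum learning, where a seqlen schedule grows linearly and user code truncates each batch. The parameter names mirror the legacy curriculum config keys (min_difficulty, max_difficulty, total_curriculum_step, difficulty_step), but the exact rounding DeepSpeed applies is an assumption here, and fixed_linear_seqlen and truncate_batch are hypothetical helpers.

```python
from typing import List, Sequence


def fixed_linear_seqlen(step: int, min_difficulty: int, max_difficulty: int,
                        total_curriculum_step: int, difficulty_step: int) -> int:
    """Grow seqlen linearly from min_difficulty to max_difficulty over
    total_curriculum_step steps, rounded down to a multiple of difficulty_step."""
    if step >= total_curriculum_step:
        return max_difficulty
    value = min_difficulty + (max_difficulty - min_difficulty) * step / total_curriculum_step
    value = int(value) // difficulty_step * difficulty_step
    return max(min_difficulty, min(value, max_difficulty))


def truncate_batch(token_batch: Sequence[Sequence[int]], seqlen: int) -> List[List[int]]:
    """User-side truncation: keep only the first seqlen tokens of every sample."""
    return [list(tokens[:seqlen]) for tokens in token_batch]
```

Because the data order never changes, this kind of curriculum only varies T; the batch size B stays fixed unless the batching itself is also made dynamic.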

conglongli, Apr 18 '24 20:04