Why multiply by the amount of gradient accumulation steps?
The AdditionalState consists of per-metric lists which need to be summarised. This makes some sense: for example, if you report metrics inside a forward() and you have N gradient accumulations, then you also get N calls to forward() before advancing the global_step and so you have N measurements to summarise.
You choose to take their mean, i.e. sum(measurements) / len(measurements). If a metric is reported once every forward() call, then you can expect len(measurements) to equal gradient_accumulation_steps. If it is reported more or less often, then len(measurements) will differ from gradient_accumulation_steps, but whether the two are equal is irrelevant: all that matters for taking an average is how many measurements there are.
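To make that concrete, here is a minimal sketch of such a summary with made-up measurement values; note that gradient_accumulation_steps never appears:

```python
# Hypothetical measurements collected for one metric across one optimizer
# step: one value per forward() call, however many of those there were.
measurements = [0.9, 1.1, 1.0, 1.2]

# The average depends only on how many measurements exist, not on
# gradient_accumulation_steps.
mean = sum(measurements) / len(measurements)
```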
Yet, strangely, the summarisation code does use gradient_accumulation_steps twice. Once to divide values by it:
https://github.com/zipzou/hf-multitask-trainer/blob/8c40c2615d91816ae02a18fbf133ba5107a3f9d4/hf_mtask_trainer/state.py#L63-L65
...and then, later on, to multiply values by it:
https://github.com/zipzou/hf-multitask-trainer/blob/8c40c2615d91816ae02a18fbf133ba5107a3f9d4/hf_mtask_trainer/state.py#L73-L76
Why is this? It seems like an identity transform to me. The only value types for which the division doesn't happen are those that aren't an int, float, tensor or array:
https://github.com/zipzou/hf-multitask-trainer/blob/8c40c2615d91816ae02a18fbf133ba5107a3f9d4/hf_mtask_trainer/state.py#L69-L70
I don't really know what kinds of values those would be (remember that they still need to be summable...), but assuming they exist, they would be multiplied by gradient_accumulation_steps without ever being divided. That is: if we report exactly len(v) == 1 such value (call it X) in the model whilst doing N > 1 accumulations, then my dashboard would show the value N*X. I don't see any application for this.
During each forward pass in HF’s Trainer, the global loss is divided by accumulation_steps to ensure that the accumulated gradients are averaged when step() is called. Most of the time, len(v) and accumulation_steps are equal, but there is a special case: the number of forward passes in the last accumulation cycle can be smaller than accumulation_steps. For example, if the last cycle only runs 3 steps because there is no more data from the dataloader, but accumulation_steps is set to 4 and the loss is divided by 4, then the summed loss, i.e. sum(v), should be multiplied by 4/3. Otherwise, the reported loss will be underestimated.
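If I follow this reasoning, it can be sketched with made-up numbers, under the assumption that the reported values already carry HF's division by accumulation_steps:

```python
accumulation_steps = 4

# Per-microbatch mean losses; only 3 of the 4 planned microbatches ran
# because the dataloader was exhausted (made-up numbers).
raw_losses = [1.2, 0.8, 1.0]

# Assumption: HF's Trainer divides each loss by accumulation_steps
# before it is reported.
hf_scaled = [x / accumulation_steps for x in raw_losses]

# The claim: sum(v) must be scaled back up by 4/3 to recover the true mean.
corrected = sum(hf_scaled) * accumulation_steps / len(raw_losses)
true_mean = sum(raw_losses) / len(raw_losses)
```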
I think you may have misunderstood where the identity transform takes place. As I pointed out in my intro, the value of gradient_accumulation_steps is irrelevant: you have a list of values and you need to take an average. If you substitute line 65 into line 74, you get
sum(v / accs) / (len(v) / accs)
== sum(v) / accs × accs / len(v)
== sum(v) / len(v)
I did not mean that len(v) and gradient_accumulation_steps cancel each other out. I meant that dividing by gradient_accumulation_steps and later multiplying by gradient_accumulation_steps does nothing.
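A quick numerical check of that cancellation, mimicking what the summarisation code effectively does with arbitrary made-up values:

```python
accumulation_steps = 4

# Arbitrary reported values; deliberately fewer than accumulation_steps,
# as in a truncated final accumulation cycle.
v = [0.9, 1.1, 1.0]

# Divide each value by accumulation_steps, then divide the sum by
# len(v) / accumulation_steps, as the summarisation code does.
summarised = sum(x / accumulation_steps for x in v) / (len(v) / accumulation_steps)

# A plain mean that never touches accumulation_steps.
plain_mean = sum(v) / len(v)
```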
> During each forward pass in HF’s Trainer, the global loss is divided by accumulation_steps
I am confused by this statement. Isn't it your code that is doing the dividing on line 65? And also, isn't all of this for reporting metrics that are specifically not the final loss, which HF already tracks?
> then the summed loss, i.e. sum(v), should be multiplied by 4/3.
I understand the reasoning, but (1) you won't be reporting global loss with this module, and (2) even if you did call report_metrics(loss) where loss is downscaled by the usual batch size, your code will take an average of those downscaled values.
Let's say the unnormalised losses for 4 accumulations are L1, L2, L3 and L4. Let's say the effective batch size is 1000 (i.e. there should be 1000 examples viewed across the 4 accumulations), then HF would produce v = L1/1000, L2/1000, L3/1000 and L4/1000. The goal would be to get (L1 + L2 + L3 + L4)/N where N ≤ 1000 is the actual batch size, so you want sum(v) × 1000/N. Say N is a perfect multiple of 250 like 750, then you want sum(v) × 4/3. Yet, your code instead computes sum(v / 4) / (3/4) == sum(v) / 4 × (4/3) == sum(v) / 3 == (L1/1000 + L2/1000 + L3/1000)/3 != (L1 + L2 + L3)/750.
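Plugging made-up numbers into that example shows the discrepancy directly:

```python
effective_batch = 1000      # examples intended across 4 accumulations
accumulation_steps = 4
actual_examples = 750       # only 3 microbatches of 250 actually ran

# Unnormalised losses for the 3 microbatches that ran (made-up numbers).
L = [300.0, 450.0, 375.0]

# HF divides by the effective batch size, so these are the reported values.
v = [x / effective_batch for x in L]

# What we want: the true mean loss over the 750 examples actually seen,
# i.e. sum(v) * 1000/750 == sum(v) * 4/3.
desired = sum(L) / actual_examples

# What the code computes: sum(v / 4) / (3 / 4) == sum(v) / 3.
code_result = sum(x / accumulation_steps for x in v) / (len(v) / accumulation_steps)
```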
I understand what you mean now, and you are right that there is an unnecessary division in calculating the average loss. However, I think the current code still meets my expectations despite the unnecessary division.
> Let's say the unnormalised losses for 4 accumulations are L1, L2, L3 and L4. Let's say the effective batch size is 1000 (i.e. there should be 1000 examples viewed across the 4 accumulations), then HF would produce v = L1/1000, L2/1000, L3/1000 and L4/1000. The goal would be to get (L1 + L2 + L3 + L4)/N where N ≤ 1000 is the actual batch size, so you want sum(v) × 1000/N. Say N is a perfect multiple of 250 like 750, then you want sum(v) × 4/3. Yet, your code instead computes sum(v / 4) / (3/4) == sum(v) / 4 × (4/3) == sum(v) / 3 == (L1/1000 + L2/1000 + L3/1000)/3 != (L1 + L2 + L3)/750.
For this, we can redefine the losses in another, clearer way. We set the global batch size to 64, and gradient_accumulation_steps is 4. Then the batch size per step is 16. So we will get 4 loss values, i.e. v = [l1, l2, l3, l4], with l1 = L1 / 16 / 4 (gradient_accumulation_steps = 4), l2 = L2 / 16 / 4, and so on.
L1, L2, L3, L4 are the summed losses of each step, and l1, l2, l3, l4 are the average losses per step (16 examples), scaled down by accumulation_steps.
We expect to calculate the mean loss based on v, so we compute v_ = sum(v) / (len(v) / 4), as in the code. If the dataloader provides 64 examples in this case, then v_ = (l1 + l2 + l3 + l4) / (4/4) = 1/64 * (L1 + L2 + L3 + L4), as expected. Now, assume the dataloader provides only 48 examples; then v_ = (l1 + l2 + l3) / (3/4) = (l1 + l2 + l3) * (4/3) = (4/3) * (1/64) * (L1 + L2 + L3) = 1/48 * (L1 + L2 + L3), as expected.
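Checking that arithmetic with concrete made-up numbers, and assuming each li already carries both divisions as defined above:

```python
accumulation_steps = 4
per_step_batch = 16

# Summed losses for the 3 steps that actually ran (made-up numbers),
# i.e. the dataloader provided only 48 of the expected 64 examples.
L = [24.0, 28.8, 19.2]

# li = Li / 16 / 4, as defined above.
l = [Li / per_step_batch / accumulation_steps for Li in L]

# v_ = sum(v) / (len(v) / accumulation_steps), as in the code.
v_ = sum(l) / (len(l) / accumulation_steps)

# Expected: mean loss over the 48 examples actually seen.
expected = sum(L) / (len(L) * per_step_batch)
```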
> So we will get 4 loss values, i.e. v=[l1, l2, l3, l4], and l1 = L1 / 16 / 4 (gradient_accumulation_steps=4), l2 = L2 / 16 / 4, and so on.
Again, the example is irrelevant since HfMultitaskTrainer doesn't report the global loss produced by Trainer, but if what you mean is this line in the Trainer, then using that value for the loss means you are dividing one too many times.
In your example, l1 ... l4 would be what HF gives you, already divided by gradient_accumulation_steps. You then divide it again by that number, and then you multiply by the number you divide by, cancelling out your division and leaving in place HF's division.
I don't disagree that you need a correction at the end of training for an incomplete batch, but (1) for metrics that aren't the loss, the average already fixes this -- the correction you desire is exactly the dynamic value given by len(v) -- and (2) for HF's loss, which is irrelevant to report_metrics, you're also computing an average, which would be wrong if it weren't irrelevant anyway.
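Concretely, assuming HF hands over values that are already divided by accumulation_steps (made-up numbers):

```python
accumulation_steps = 4

# Raw per-microbatch losses (made-up numbers); only 3 of 4 microbatches ran.
raw = [2.0, 3.0, 4.0]

# Assumption: HF's Trainer has already divided each loss by accumulation_steps.
hf_values = [x / accumulation_steps for x in raw]

# The summarisation code's own divide-then-multiply cancels itself out,
# so HF's division survives in the result: we get mean(raw) / 4, not mean(raw).
reported = sum(x / accumulation_steps for x in hf_values) / (len(hf_values) / accumulation_steps)

true_mean = sum(raw) / len(raw)
```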
What also confused me is that my two loss components add up to n while the total loss adds up to something much bigger (e.g. n*4 when using 4 accumulation steps). It took me a while to figure out where the mismatch was coming from.