Megatron-DeepSpeed

Add BitFit

Open · Muennighoff opened this issue 1 year ago · 0 comments

This PR adds support for BitFit (bias-only fine-tuning). I'd like to try BitFit + MTF (multitask fine-tuning) to retain multilinguality. Empirical evidence from this paper:
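BitFit trains only the bias terms and freezes every other weight. Below is a minimal pure-Python sketch of that selection rule. It assumes bias parameters can be identified by names ending in `.bias`. The layer layout, parameter names, and hidden size `h` are hypothetical, loosely modeled on GPT-style Megatron naming, and are not taken from this PR's code:

```python
# Minimal BitFit sketch: only bias terms stay trainable.
# Parameter names, shapes, and the hidden size are HYPOTHETICAL illustrations.
from math import prod
from typing import Dict, Tuple


def is_bitfit_trainable(name: str) -> bool:
    """BitFit rule (assumed naming convention): train a parameter only if it is a bias."""
    return name.endswith(".bias")


def split_params(shapes: Dict[str, Tuple[int, ...]]) -> Tuple[int, int]:
    """Return (trainable, frozen) element counts under the BitFit rule."""
    trainable = frozen = 0
    for name, shape in shapes.items():
        n = prod(shape)
        if is_bitfit_trainable(name):
            trainable += n
        else:
            frozen += n
    return trainable, frozen


h = 1024  # hypothetical hidden size
layer = {
    "input_layernorm.weight": (h,),
    "input_layernorm.bias": (h,),
    "self_attention.query_key_value.weight": (3 * h, h),
    "self_attention.query_key_value.bias": (3 * h,),
    "self_attention.dense.weight": (h, h),
    "self_attention.dense.bias": (h,),
    "post_attention_layernorm.weight": (h,),
    "post_attention_layernorm.bias": (h,),
    "mlp.dense_h_to_4h.weight": (4 * h, h),
    "mlp.dense_h_to_4h.bias": (4 * h,),
    "mlp.dense_4h_to_h.weight": (h, 4 * h),
    "mlp.dense_4h_to_h.bias": (h,),
}

trainable, frozen = split_params(layer)
print(f"trainable={trainable} frozen={frozen} "
      f"fraction={trainable / (trainable + frozen):.4%}")
```

In a PyTorch model, the same predicate would typically be applied by looping over `model.named_parameters()` and setting each parameter's `requires_grad` to `is_bitfit_trainable(name)`. The way this PR wires BitFit into Megatron-DeepSpeed may differ.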

[Screenshot: results from the paper]

Note that adapters add parameters to the model and increase inference complexity in Transformers, so BitFit (BF) is the best option IMO. Also see this paper, though it doesn't try BitFit.
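The adapter comparison can be made concrete. A bottleneck adapter in the style of Houlsby et al. (2019) adds a down-projection and an up-projection, each with a bias, at every insertion point. These extra matmuls also run at inference. BitFit only changes the values of existing biases, so it adds nothing. The sketch below counts the added parameters. The model sizes are hypothetical, and two adapters per layer is an assumption about the adapter variant:

```python
# Parameter overhead of bottleneck adapters vs. BitFit.
# ASSUMED adapter layout: d -> m down-projection, nonlinearity, m -> d up-projection,
# inserted twice per transformer layer. All sizes below are HYPOTHETICAL.


def adapter_params(d: int, m: int) -> int:
    """Parameters in one bottleneck adapter, including biases: 2*m*d + d + m."""
    down = d * m + m  # weight + bias of the d -> m projection
    up = m * d + d    # weight + bias of the m -> d projection
    return down + up


def added_params(num_layers: int, d: int, m: int, adapters_per_layer: int = 2) -> int:
    """Extra parameters (each adapter also adds extra matmuls at inference time)."""
    return num_layers * adapters_per_layer * adapter_params(d, m)


def bitfit_added_params() -> int:
    """BitFit retrains existing bias vectors in place, so the architecture is unchanged."""
    return 0


print(added_params(num_layers=24, d=1024, m=64))
print(bitfit_added_params())
```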

Automatic Tests: Happy to add one if we decide to merge this 🤗

Manual Tests (PP = pipeline-parallel size, TP = tensor-parallel size):

  • 1 node, PP=2, TP=2
  • 2 nodes, PP=2, TP=2

The logs below show that the grad norm decreases, as expected, because fewer parameters receive gradients. I would also expect iteration time to decrease due to less communication, though probably only with more nodes. Memory usage also decreases because fewer optimizer states need to be stored.
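The logged grad norm shrinks for a simple reason. The global norm is the square root of the sum of squared gradient entries over all trainable parameters, so freezing the weights removes non-negative terms from that sum. Below is a toy sketch with random, hypothetical gradients. It also includes a memory estimate that assumes 12 bytes of mixed-precision Adam state per trainable parameter (fp32 master weight plus first and second moments). The real accounting in Megatron-DeepSpeed may differ:

```python
import math
import random
from typing import Dict, List


def global_grad_norm(grads: Dict[str, List[float]]) -> float:
    """Global L2 norm over every gradient entry: sqrt of the sum of squares."""
    return math.sqrt(sum(g * g for grad in grads.values() for g in grad))


# HYPOTHETICAL toy gradients (random numbers, not taken from the logs).
random.seed(0)
grads = {
    "mlp.dense_h_to_4h.weight": [random.gauss(0.0, 0.01) for _ in range(4096)],
    "mlp.dense_h_to_4h.bias": [random.gauss(0.0, 0.01) for _ in range(64)],
}

full_norm = global_grad_norm(grads)
# Under BitFit only bias tensors have gradients: fewer non-negative squared
# terms enter the sum, so the norm can only stay equal or shrink.
bitfit_norm = global_grad_norm({k: v for k, v in grads.items() if k.endswith(".bias")})
print(f"full={full_norm:.4f} bitfit={bitfit_norm:.4f}")

# ASSUMED mixed-precision Adam accounting: fp32 master weight + exp_avg + exp_avg_sq
# = 12 bytes per *trainable* parameter; frozen parameters need no optimizer state.
BYTES_PER_TRAINABLE_PARAM = 12


def optimizer_state_bytes(num_trainable: int) -> int:
    return BYTES_PER_TRAINABLE_PARAM * num_trainable


print(optimizer_state_bytes(4096 + 64), optimizer_state_bytes(64))
```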

With BitFit (2 nodes, PP=2, TP=2):

[default3]: iteration        2/  868457 | consumed samples:          384 | consumed tokens:       786432 | elapsed time per iteration (s): 12.86 | learning rate: 6.291E-07 | global batch size:   192 | lm loss: 1.244176E+01 | loss scale: 4096.0 | grad norm: 0.065 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 14.925 | TFLOPs: 13.65 |
[default3]: iteration        3/  868457 | consumed samples:          576 | consumed tokens:      1179648 | elapsed time per iteration (s): 12.62 | learning rate: 9.437E-07 | global batch size:   192 | lm loss: 1.244014E+01 | loss scale: 4096.0 | grad norm: 0.062 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 15.209 | TFLOPs: 13.91 |

Without BitFit:

[default3]: iteration        2/  868457 | consumed samples:          384 | consumed tokens:       786432 | elapsed time per iteration (s): 12.62 | learning rate: 6.291E-07 | global batch size:   192 | lm loss: 1.244176E+01 | loss scale: 4096.0 | grad norm: 0.291 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 15.214 | TFLOPs: 13.91 |
[default3]: iteration        3/  868457 | consumed samples:          576 | consumed tokens:      1179648 | elapsed time per iteration (s): 12.63 | learning rate: 9.437E-07 | global batch size:   192 | lm loss: 1.244006E+01 | loss scale: 4096.0 | grad norm: 0.309 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 | samples per second: 15.201 | TFLOPs: 13.90 |

Muennighoff · Jul 10 '22 18:07