ENH: Improve torch.compile support in MetaMath
The MetaMathQA benchmark already supported enabling torch.compile, but the implementation had several issues. The new changes are:
- call compile after applying PEFT, not before
- compile with `dynamic=True`
- avoid `model.eval()` + `model.train()` calls
These changes prevent graph breaks and recompiles. A context manager is now used to ensure that those don't happen.
Some unrelated changes:
- improve some type annotations
- use the `dtype` argument instead of the deprecated `torch_dtype`
Note to self: How to deal with dropout? It's not deactivated by torch.inference_mode().
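A quick demonstration of that point, in plain PyTorch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# inference_mode() turns off autograd bookkeeping, but not dropout:
with torch.inference_mode():
    y = drop(x)
print((y == 0).any().item())  # True: elements are still being dropped

# Only eval() switches dropout off:
drop.eval()
with torch.inference_mode():
    z = drop(x)
print((z == x).all().item())  # True: dropout is now a no-op
```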
After some testing: when running evaluation, we would generally want to put the model into eval mode (to disable dropout). However, this triggers a recompile when the model is put back into train mode (i.e. a total of two compiles happen). We could skip this train/eval toggle during training to avoid the recompile. This would mean that the model is in train mode when evaluating, but arguably that is not a very big deal. Obviously, when it comes to the test set, we do put the model in eval mode first.
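The recompile can be reproduced in isolation; a sketch with a toy module and the eager backend (names are illustrative, not the benchmark code):

```python
import torch
import torch.nn as nn
from torch._dynamo import config as dynamo_config

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
compiled = torch.compile(model, dynamic=True, backend="eager")

x = torch.randn(2, 4)
compiled(x)  # first compile, in train mode

# The eval switch flips module.training, which the compiled graph
# guards on (dropout behaves differently), so the next call recompiles:
model.eval()
recompiled = False
try:
    with dynamo_config.patch(error_on_recompile=True):
        compiled(x)
except Exception:
    recompiled = True
print(recompiled)  # True: the mode switch invalidates the compiled graph
```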
Here are the numbers for no compile, compile with train/eval switch, and compile without train/eval switch:
| metric | no compile | compile w/ train/eval switch | compile w/o switch |
|---|---|---|---|
| reserved mem max | 22.3 GB | 16.4 GB | 16.4 GB |
| reserved mem avg | 14.4 GB | 11.2 GB | 11.2 GB |
| reserved mem 99th | 20.1 GB | 14.6 GB | 14.6 GB |
| number of compiles | 0 | 2 | 1 |
| train time / step | ~29 sec | ~24 sec | ~24 sec |
| final train loss (sanity check) | 0.60717 | 0.60690 | 0.60695 |
| total time | 994 sec | 943 sec | |
Validation accuracy varies a bit, but that's to be expected with a rather small validation set size and measuring on generations.
I tried a couple of mitigations, such as running eval with `torch.compiler.disable` or compiling the eval function separately (which wouldn't really save time, as it replaces a recompile with another compile), but nothing I tried helped.
@githubnemo What would you prefer: Live with the recompilation or avoid the train/eval switch?
Thanks for investigating!
If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.
> If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.
Yes, we could do that. It means, however, that we have to remove the `error_on_recompile` context manager, which could prevent us from detecting other recompilation issues. LMK if that sounds acceptable.
> If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.
Done, please review again.