
ENH: Improve torch.compile support in MetaMath

Open: BenjaminBossan opened this pull request 1 month ago • 1 comment

The MetaMathQA benchmark already supported enabling torch.compile, but the implementation had some shortcomings. The new changes are:

  • call compile after applying PEFT, not before
  • compile with dynamic=True
  • avoid model.eval() + model.train() calls

These changes prevent graph breaks and recompiles. A context manager is now used to ensure that those don't happen.
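Roughly, the new setup looks like this (the model name, data loader, and loop are placeholders, not the actual benchmark code):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# load with the `dtype` argument instead of the deprecated `torch_dtype`
base_model = AutoModelForCausalLM.from_pretrained("some-base-model", dtype=torch.bfloat16)

# apply PEFT first ...
model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))
# ... then compile, with dynamic=True so varying sequence lengths don't force recompiles
model = torch.compile(model, dynamic=True)

# error out instead of silently recompiling, so graph breaks/recompiles are caught early
with torch._dynamo.config.patch(error_on_recompile=True):
    for batch in train_dataloader:  # placeholder data loader
        loss = model(**batch).loss
        loss.backward()
```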

Some unrelated changes:

  • improve some type annotations
  • use dtype argument instead of deprecated torch_dtype

BenjaminBossan · Nov 06 '25

Note to self: How to deal with dropout? It's not deactivated by torch.inference_mode().
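For reference, a quick standalone check (not benchmark code) showing that inference_mode does not turn dropout off:

```python
import torch

drop = torch.nn.Dropout(p=0.5)

drop.train()  # training mode: dropout stays active ...
with torch.inference_mode():  # ... even inside inference_mode
    print(drop(torch.ones(8)))  # elements randomly zeroed, the rest scaled by 1/(1-p)

drop.eval()  # only eval() turns dropout into a no-op
with torch.inference_mode():
    print(drop(torch.ones(8)))  # all ones
```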

BenjaminBossan · Nov 20 '25

After some testing: when running evaluation, we would generally want to put the model into eval mode (to deactivate dropout). However, this triggers a re-compile when the model is put back into train mode (i.e. a total of two compiles happen). We could skip this train/eval toggle during training to avoid the re-compile. This would mean that the model stays in train mode while evaluating, but arguably that is not a big deal. Obviously, when it comes to the test set, we do put the model into eval mode first.
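Roughly, the two evaluation variants look like this (names are purely illustrative):

```python
import torch

def evaluate_with_switch(model, eval_batches):
    # "correct" variant: dropout is off during eval, but flipping the training
    # flag gives the compiled model a second graph (two compiles in total)
    model.eval()
    with torch.inference_mode():
        for batch in eval_batches:
            model.generate(**batch)
    model.train()

def evaluate_without_switch(model, eval_batches):
    # recompile-free variant: the model stays in train mode during intermediate
    # evals (dropout remains active); only the final test-set evaluation
    # switches to eval mode
    with torch.inference_mode():
        for batch in eval_batches:
            model.generate(**batch)
```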

Here are the numbers for no compile, compile with train/eval switch, and compile without train/eval switch:

| metric | no compile | compile w/ train/eval switch | compile w/o switch |
|---|---|---|---|
| reserved mem max | 22.3 GB | 16.4 GB | 16.4 GB |
| reserved mem avg | 14.4 GB | 11.2 GB | 11.2 GB |
| reserved mem 99th percentile | 20.1 GB | 14.6 GB | 14.6 GB |
| number of compiles | 0 | 2 | 1 |
| train time / step | ~29 sec | ~24 sec | ~24 sec |
| final train loss (sanity check) | 0.60717 | 0.60690 | 0.60695 |
| total time | 994 sec | 943 sec | |

Validation accuracy varies a bit, but that's to be expected given the rather small validation set and the fact that accuracy is measured on generations.

I tried a couple of mitigations, like running eval with torch.compiler.disable or compiling the eval function separately (which wouldn't really save time, as we would just replace a recompile with another compile), but nothing I tried helped.
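For reference, the torch.compiler.disable attempt looked roughly like this (a sketch, not the actual code):

```python
import torch

@torch.compiler.disable  # keep dynamo away from the eval path entirely
def run_validation(model, eval_batches):
    model.eval()
    with torch.inference_mode():
        for batch in eval_batches:
            model.generate(**batch)
    model.train()
```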

@githubnemo What would you prefer: Live with the recompilation or avoid the train/eval switch?

BenjaminBossan · Nov 24 '25

Thanks for investigating!

If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.

githubnemo · Nov 25 '25

> If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.

Yes, we could do that. It means, however, that we have to remove the error_on_recompile context, which could prevent us from detecting other recompilation issues. LMK if that sounds acceptable.

BenjaminBossan · Nov 25 '25

> If I understand correctly there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that since dropout is only one potential candidate for train/eval mismatches.

Done, please review again.

BenjaminBossan · Dec 02 '25