
Perf / accuracy metrics comparison with NeMo for SFT / reasoning distillation scenario

vadimkantorov opened this issue 7 months ago · 2 comments

Do you by chance have any perf / accuracy metrics comparison with NeMo for basic fine-tuning?

Thanks :)

vadimkantorov · May 22 '25 15:05

We don't have specific benchmark numbers for any other library. You can see some of our numbers here and try to put together a comparison yourself though.

@ebsmothers do you have any additional thoughts?

pbontrager · May 22 '25 20:05

Basically wondering if torchtune is a good reproducible research tool for challenges like https://github.com/huggingface/open-r1, which does SFT on math reasoning traces (or otherwise on traces from https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/tree/main/data) and then GRPO...
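Something like the following sketch is what I have in mind for the SFT half, i.e. just supervised training on (prompt, trace) pairs. The field names (`question`, `reasoning`, `answer`) are made up for illustration, not the actual schema of either dataset:

```python
# Rough sketch of SFT data prep on reasoning traces; field names
# ("question", "reasoning", "answer") are hypothetical, not the real schema.
import json

def format_trace(record: dict) -> dict:
    prompt = f"Question: {record['question']}\nThink step by step."
    completion = f"{record['reasoning']}\n\nFinal answer: {record['answer']}"
    return {"prompt": prompt, "completion": completion}

# "reasoning_traces.jsonl" is a hypothetical local dump of reasoning traces.
with open("reasoning_traces.jsonl") as f:
    sft_examples = [format_trace(json.loads(line)) for line in f]
```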

If there were reproducible logs published (logs and wandb curves) for common hardware like 8xH100 machines, it would be extremely valuable, including as a simple, PyTorch-native baseline.

vadimkantorov · May 22 '25 22:05

@vadimkantorov sorry I missed this until now. We do provide perf and memory numbers running SFT on an instruct dataset on a single A100 in our readme here. But I think it would be good for us to add similar numbers for a full node as well. Regarding reproducible logs, is there a certain set of runs you are interested in seeing here?

ebsmothers · May 27 '25 20:05

@RedTachyon may be able to provide more info specifically about the GRPO recipe. But we haven't done comprehensive comparisons/benchmarks for recipes in our dev folder.

pbontrager · May 30 '25 16:05

Hi, I don't think there's any comprehensive benchmarking effort, but I have a bunch of curves from a heavily modified* version of the GRPO recipe.

Note that all of these are mostly intermediate benchmarks from other research, so there are some crashes, some weird noisy runs, etc. But hopefully it's better than nothing. Each model was trained on 8xA100.

*The algorithm should be substantially the same, but we have an internal version that uses vLLM for data generation, so it's probably much faster than the current public version. I think we also changed advantage normalization and the length normalization of the loss. At some point I can get around to cleaning up and publishing the vLLM stuff, especially if there's some demand from the community.
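For the curious, the two tweaks I mentioned look roughly like this. This is a sketch of the standard formulations (group-mean/std advantages, per-sequence length normalization), not our exact internal code:

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Standard GRPO advantages; rewards has shape (num_prompts, group_size)."""
    # Center each completion's reward on its prompt-group mean and scale by the
    # group std; dropping the std division is one common normalization variant.
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

def length_normalized_loss(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-sequence length normalization; mask is 1 on completion tokens."""
    # Each sequence's loss is divided by its own length before averaging,
    # as opposed to normalizing by total token count across the batch
    # (the DAPO-style token-level variant).
    per_seq = (per_token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return per_seq.mean()
```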

NB: the Llama models are whichever Llama 3 variants exist at that size; the Qwens are Qwen 2.5.

GSM8k with {llama3b, llama8b, qwen3b, qwen7b}:

[two training-curve plots]

MATH with {llama3b, llama8b, qwen3b, qwen7b}:

[two training-curve plots]

DAPO training, MATH-500 evaluation with qwen7b:

[evaluation-curve plot]

RedTachyon · May 30 '25 20:05