
VRAM Usage / Training Time in comparison to Huggingface

bdytx5 opened this issue 1 year ago • 5 comments

Sorry to bother again... I was able to use a batch size of 17 for Llama3 on a single A6000, whereas I could only reach a batch size of 1 with Huggingface (with LoRA, and all other parameters equal). However, torchtune was only about 25% faster in training speed. Just wondering if this makes any sense? -Brett

bdytx5 · Apr 21 '24 19:04

@bdytx5 thanks so much for giving this a try! Glad to hear that torchtune is showing some impressive memory wins here.

Do you mind sharing more information about your training? Is batch size the only change you're making? And can you also share the training speed you're seeing? A simple iter/sec or sec/iter from the tqdm logs would be enough. I imagine you have similar information from HF as well?

Generally, there are a number of things that go into trading off memory for perf. Setting the maximum possible batch size isn't always the most efficient choice, because under the hood the allocator might be constantly hitting malloc retries, so you'll have to play around with a few different batch_sizes. The default config also enables activation checkpointing, which you might want to disable by setting enable_activation_checkpointing=False. AC helps you fit a bigger batch size, but it recomputes the activations in every backward pass, which slows down training. So there's a lot going on. I can help collect some of these numbers on an A6000, but I'd love to learn more about the information above first.
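For reference, here's a minimal sketch of how these knobs can be overridden from the command line, assuming the stock single-device Llama3 LoRA recipe and config (run tune ls to check the exact names on your install; batch_size=4 is just an illustrative value to experiment with):

    # Sketch: bump the batch size and disable activation checkpointing via
    # CLI overrides. Recipe/config names assume the stock Llama3 LoRA
    # single-device setup; batch_size=4 is only an example value.
    tune run lora_finetune_single_device \
        --config llama3/8B_lora_single_device \
        batch_size=4 \
        enable_activation_checkpointing=False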

kartikayk · Apr 21 '24 20:04

Thanks for the quick response! I am writing an article for wandb comparing the two. I disabled activation checkpointing like you mentioned, and otherwise the only real change between the two runs was that I used 17 gradient accumulation steps for Huggingface, to make the effective batch sizes a bit more even. One epoch on the stock Alpaca dataset finished in about 105 minutes for torchtune and 132 minutes for HF. I'll share the article here once it's finished (if you want to leave this issue open until then)! I pretty much followed the stock training run for Llama3 LoRA on a single A6000.
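For the curious, here is a rough sketch of what the Huggingface side of that setup might look like (only the batch size and gradient accumulation values come from the run described above; output_dir and everything else are placeholders):

    # Sketch of the HF TrainingArguments used to even out the comparison:
    # per-device batch size 1 with 17 gradient accumulation steps gives an
    # effective batch size of 1 * 17 = 17, matching torchtune's batch_size=17.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="llama3-lora-hf",       # placeholder path
        per_device_train_batch_size=1,     # largest per-device size that fit
        gradient_accumulation_steps=17,    # effective batch size = 17
        num_train_epochs=1,
    )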

bdytx5 · Apr 21 '24 20:04

Sounds great! Looking forward to it.

BTW, there's an open PR for adding selective checkpointing to torchtune, which will likely impact throughput, but it probably won't land till Monday.

kartikayk · Apr 21 '24 20:04

Gotcha. I'll make sure to make a note of this!

bdytx5 · Apr 21 '24 20:04

https://wandb.ai/byyoung3/mlnews2/reports/Fine-Tuning-Llama-3-with-LoRA-TorchTune-vs-Huggingface--Vmlldzo3NjE3NzAz?utm_campaign=Fine-Tuning+Llama-3&utm_source=twitter&utm_medium=social&utm_content=Llama3

Here's the article. Hope you enjoy!

bdytx5 · May 01 '24 01:05