torchtitan
A PyTorch native library for large-scale model training
Can it train a big model where the GPU cannot even fit batch size = 1?
Summary: Use the stateful_dataloader from torchdata (https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) for storing the token buffer and iteration data order. It requires a dependency on the nightly build of torchdata >= 20240426. Also make...
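The sketch below illustrates the checkpoint/resume pattern that StatefulDataLoader enables; the toy dataset and step counts are placeholders, not torchtitan's actual data pipeline.

```python
import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class ToyDataset(Dataset):
    # Stand-in for the real tokenized dataset.
    def __len__(self):
        return 1000

    def __getitem__(self, i):
        return torch.tensor(i)

loader = StatefulDataLoader(ToyDataset(), batch_size=8, num_workers=2)

state = None
for step, batch in enumerate(loader):
    if step == 10:
        # Capture iteration state (including worker progress) mid-epoch,
        # so a resume neither replays nor skips samples.
        state = loader.state_dict()
        break

# On resume: build an identical loader and restore the saved position.
resumed = StatefulDataLoader(ToyDataset(), batch_size=8, num_workers=2)
resumed.load_state_dict(state)
next(iter(resumed))  # continues after step 10 rather than from the start
```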
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #280 Per the suggestion in #274: this PR removes the embedding from the number-of-parameters calculation, because the embedding op doesn't do a matmul. This PR...
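A minimal sketch of the adjusted count; the `tok_embeddings` attribute name is an assumption for illustration, not necessarily torchtitan's.

```python
import torch.nn as nn

def num_params_excluding_embedding(model: nn.Module, embedding_attr: str = "tok_embeddings") -> int:
    """Parameter count minus the token-embedding table.

    The embedding is a table lookup, not a matmul, so excluding it gives a
    more meaningful basis for FLOPs/MFU estimates. The attribute name is an
    assumption; adjust it to match your model.
    """
    total = sum(p.numel() for p in model.parameters())
    embedding = getattr(model, embedding_attr)
    return total - sum(p.numel() for p in embedding.parameters())
```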
This way we could temporarily enable 2-D parallel compile, and it might make sense to do per-transformer-block compile in the future with PP (we'll see). We should figure...
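A rough sketch of per-block compile, assuming the blocks live in a `model.layers` container (that layout is an assumption made for illustration):

```python
import torch
import torch.nn as nn

def compile_transformer_blocks(model: nn.Module) -> None:
    # Compile each block separately instead of the whole model, so compiled
    # regions line up with the per-block boundaries that PP splits on.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, torch.compile(block))
```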
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #250
The lr scheduler currently maintains two global states to implement the full lr warmup and decay. We want to remove these: "nit: we can make these two arguments still as function...
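One hedged sketch of how the globals could be removed, passing the two step counts into the schedule via `functools.partial` (the exact schedule shape here is illustrative):

```python
import functools
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup_linear_decay(warmup_steps: int, decay_steps: int, current_step: int) -> float:
    # Linear warmup to the base lr, then linear decay toward zero.
    if current_step < warmup_steps:
        return float(current_step + 1) / float(max(1, warmup_steps))
    progress = float(current_step - warmup_steps) / float(max(1, decay_steps))
    return max(0.0, 1.0 - progress)

def build_lr_scheduler(optimizer, warmup_steps: int, decay_steps: int) -> LambdaLR:
    # functools.partial bakes both step counts into the lr lambda, so no
    # module-level globals are needed.
    lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
    return LambdaLR(optimizer, lr_lambda=lr_lambda)
```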
The issue comes from the backward computation of `aten.mul` of two complex numbers from DTensors: the result will be `b + a*i` when it should be `a + b*i`. Not...
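For reference, the correct behavior on plain (non-DTensor) complex tensors: (a + bi)(c + di) = (ac - bd) + (ad + bc)i, and for a real-valued loss the gradient w.r.t. `x` is `conj(y)`. The bug described above is specific to the DTensor backward path.

```python
import torch

x = torch.tensor([1.0 + 2.0j], requires_grad=True)
y = torch.tensor([3.0 + 4.0j])

out = x * y  # (1+2j)(3+4j) = (1*3 - 2*4) + (1*4 + 2*3)j = -5+10j
out.real.sum().backward()

print(out)     # tensor([-5.+10.j], grad_fn=...)
print(x.grad)  # tensor([3.-4.j]) == conj(y), the expected gradient
```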
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #180 * #285 * #161 * __->__ #172 Adds a new command, ./create_seed_checkpoint.sh, which largely reuses code inside train.py to create the model and...
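The gist of a seed checkpoint, sketched with a toy model; the real script reuses train.py's model construction and checkpointing, so everything below is a stand-in.

```python
import torch
import torch.nn as nn

# Initialize the full model deterministically in a single, unparallelized
# process and save the weights, so runs with any parallelism layout can load
# identical starting parameters. The tiny model and path are placeholders.
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
torch.save(model.state_dict(), "step-0-seed.pt")

# Each later training run loads the same seed weights before sharding.
restored = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
restored.load_state_dict(torch.load("step-0-seed.pt"))
```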
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #308 * __->__ #161 ---- - uses pipeline tracer frontend to extract a graph and partition it into chunks per stage - hardcodes...
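A sketch of the tracer-frontend idea using the `torch.distributed.pipelining` API available in recent PyTorch (the PR itself used the earlier PiPPy interface; the model and split point below are placeholders):

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import pipeline, SplitPoint

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        return self.layers(x)

model = Toy()
x = torch.randn(8, 16)  # one microbatch of example inputs for tracing

# Trace the model into a graph and split it into two stages at the start of
# layers.2; each stage's submodule can then be placed on its own rank.
pipe = pipeline(model, mb_args=(x,), split_spec={"layers.2": SplitPoint.BEGINNING})
print(pipe.num_stages)            # 2
stage0 = pipe.get_stage_module(0) # submodule for the first stage
```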
Purpose of this PR is to show: 1. One line change needed -- remove this line:
```
self.freqs_cis = self.freqs_cis.to(h.device)
```
Reason 1: compile does not support in-place attribute mutation....
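One common way to make that removal safe is to register `freqs_cis` as a non-persistent buffer, so it follows the module across devices with no attribute mutation in forward(). The toy module below sketches that pattern; the dimensions and rotary math are illustrative, not torchtitan's exact code.

```python
import torch
import torch.nn as nn

class ToyTransformer(nn.Module):
    def __init__(self, dim: int = 64, max_seq_len: int = 256):
        super().__init__()
        # Precompute rotary frequencies once; registering them as a
        # non-persistent buffer means model.to(device) moves them along with
        # the parameters, so forward() needs no attribute mutation.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        freqs_cis = torch.polar(torch.ones(max_seq_len, dim // 2), torch.outer(t, inv_freq))
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, seqlen, dim). self.freqs_cis is already on h's device,
        # so the removed line is unnecessary here.
        seqlen = h.size(1)
        hc = torch.view_as_complex(h.float().reshape(*h.shape[:-1], -1, 2))
        return torch.view_as_real(hc * self.freqs_cis[:seqlen]).flatten(2)
```

With this, `ToyTransformer().to("cuda")` places the buffer on the GPU together with the parameters, and torch.compile sees no in-place attribute mutation.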