Andrew Gu
For `type == "right"`, we do not need to compute `A`, and for `type == "left"`, we do not need to compute `B`. (The variables are unused.) We can avoid...
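A minimal sketch of the idea (the function name `apply_transform` and the weights `w_a`, `w_b` are hypothetical; `A`, `B`, and `type` stand in for the PR's variables): compute each intermediate only when the chosen `type` actually consumes it, so the unused branch is never materialized.

```python
import torch

def apply_transform(x, w_a, w_b, type="both"):
    # Hypothetical sketch: only materialize the intermediates the chosen
    # `type` actually uses, instead of computing both unconditionally.
    A = x @ w_a if type != "right" else None  # A is unused for type == "right"
    B = x @ w_b if type != "left" else None   # B is unused for type == "left"
    if type == "left":
        return A
    if type == "right":
        return B
    return A + B
```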
This PR:
- In forward, this saves 4 `aten::slice` kernels.
- In backward, this saves 4 `aten::fill` and 4 `aten::copy_` kernels with shape `(bs, seq_len, n_kv_heads, head_dim)`.

See https://github.com/pytorch/torchtitan/pull/418 for details...
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #567
* __->__ #551

Compiling the loss improves performance. Moving the `.float()` upcast to inside this compiled loss further improves performance.
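A minimal sketch of the pattern, assuming a standard next-token cross-entropy loss (the PR's exact loss is not shown here): keeping the `.float()` upcast inside the compiled function lets the compiler fold the fp32 conversion into the loss kernels rather than materializing a full fp32 copy of the logits up front.

```python
import torch
import torch.nn.functional as F

def loss_fn(logits, labels):
    # The upcast happens *inside* the compiled region, so it can be
    # fused into the loss computation instead of allocating a separate
    # fp32 logits tensor in eager mode.
    return F.cross_entropy(logits.flatten(0, 1).float(), labels.flatten(0, 1))

compiled_loss_fn = torch.compile(loss_fn)
```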
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #533
* #532

Credit: @felipemello1 for the previous token chunked cross entropy
Credit: @Chillee for the new token chunked cross entropy

Running...
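A minimal sketch of token-chunked cross entropy under stated assumptions (no `ignore_index` handling; `num_chunks` is illustrative): chunking along the token dimension bounds how many fp32 logits are live at once, and the per-chunk sums are normalized at the end to match an unchunked mean reduction.

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits, labels, num_chunks=4):
    # Flatten (batch, seq) into a single token dimension, then compute
    # the loss chunk by chunk so that only one chunk's logits are upcast
    # to fp32 at a time. Accumulate sums and normalize at the end.
    logits = logits.flatten(0, 1)
    labels = labels.flatten(0, 1)
    total = logits.new_zeros((), dtype=torch.float32)
    for logit_chunk, label_chunk in zip(
        logits.chunk(num_chunks), labels.chunk(num_chunks)
    ):
        total = total + F.cross_entropy(
            logit_chunk.float(), label_chunk, reduction="sum"
        )
    return total / labels.numel()
```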
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #533
* __->__ #532
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #467

**Current UX**
- We use a `saved_tensors_hooks` context manager, which should be wrapped around `module.forward`. The context lets us override pack...
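A minimal sketch of this context-manager UX, using PyTorch's public `torch.autograd.graph.saved_tensors_hooks`; the offload-to-CPU pack hook is purely illustrative (real hooks might compress, quantize, or offload and restore activations instead).

```python
import torch

def pack_hook(t):
    # Illustrative pack: offload the saved activation to CPU.
    return t.to("cpu")

def unpack_hook(packed):
    # Illustrative unpack: a real implementation would move the tensor
    # back to its original device before backward uses it.
    return packed

x = torch.randn(4, 4, requires_grad=True)
# The context is wrapped around the forward computation; every tensor
# saved for backward inside it goes through pack_hook/unpack_hook.
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()
y.backward()
```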
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #459
* #382
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #488
* #487
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #459
* __->__ #382

This requires https://github.com/pytorch/pytorch/pull/127786.

**Experiment**
- Llama3-8B on 8xH100, 1D FSDP, local batch size 2, selective op AC, `compiled_rmsnorm`, `torch.compile`...
For memory snapshots, we usually only need to take a snapshot on one of the first few iterations (e.g. step 2 or 3), after the optimizer step has run on...
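A minimal sketch of that pattern, using PyTorch's CUDA memory-history APIs (note these are underscore-prefixed and may change between releases; `snapshot_step` and the output path are illustrative): start recording at the beginning of training, dump once on an early step after the optimizer states have been allocated, then stop recording.

```python
import torch

def maybe_dump_snapshot(step, snapshot_step=3, path="memory_snapshot.pickle"):
    # Record allocator history from the first step, dump a single
    # snapshot on an early step (after the optimizer step has allocated
    # its states), then stop recording to avoid ongoing overhead.
    if not torch.cuda.is_available():
        return
    if step == 0:
        torch.cuda.memory._record_memory_history(max_entries=100_000)
    if step == snapshot_step:
        torch.cuda.memory._dump_snapshot(path)
        torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```

The resulting pickle can be inspected with PyTorch's memory-viz tooling.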