Easy-Transformer
[Proposal] Memory efficient causal mask implementation
Proposal
[Relatively minor proposal - considered making it a bug, but it's not really a bug.]
In the initialization of each `Attention` module, we register a `causal_mask` buffer. This buffer is a boolean tensor of shape `(self.cfg.n_ctx, self.cfg.n_ctx)`.
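To make the cost concrete, here is a minimal sketch of per-layer buffer registration (hypothetical code in the spirit of the current implementation, not the repo's actual `Attention` class):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Sketch: each layer registers its own full-size causal mask."""
    def __init__(self, n_ctx: int):
        super().__init__()
        # Lower-triangular boolean mask of shape (n_ctx, n_ctx):
        # n_ctx**2 bytes per layer, duplicated in every layer.
        self.register_buffer(
            "causal_mask",
            torch.tril(torch.ones(n_ctx, n_ctx, dtype=torch.bool)),
        )

attn = Attention(n_ctx=1024)
print(attn.causal_mask.shape)           # torch.Size([1024, 1024])
print(attn.causal_mask.element_size())  # 1 byte per bool element
```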
This is quite inefficient for two reasons:
- Most of the time, the prompt's context length is much smaller than `self.cfg.n_ctx` (which represents the maximum context length).
- The same buffer is stored separately for every layer.
This hasn't been a visible issue so far for models with smallish context lengths. But consider a model like Qwen 72B, which has a maximum context length of 32768 and 80 layers. With the current implementation, a boolean tensor of shape `(32768, 32768)` is initialized for each layer, resulting in 32768 * 32768 * 1 byte * 80 layers ≈ 86 GB of overhead.
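The overhead figure is straightforward to verify (using the numbers quoted above for Qwen 72B):

```python
n_ctx = 32768     # maximum context length
n_layers = 80     # number of attention layers
bytes_per_mask = n_ctx * n_ctx * 1   # bool tensor: 1 byte per element

total_bytes = bytes_per_mask * n_layers
print(total_bytes / 1e9)  # ≈ 85.9 GB
```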
As a temporary fix, we could just cap `n_ctx` on models to be less than some reasonable value (2048 or 4096). But I think the ideal solution is to compute the attention mask on the fly, sized to the particular context length.
Note that the same inefficiency exists with rotary embeddings (we precompute sin and cos tensors of length `n_ctx`). But it's not nearly as bad, since those grow as O(n_ctx), whereas the mask grows as O(n_ctx^2).
Checklist
- [x] I have checked that there is no similar issue in the repo (required)