Easy-Transformer
Construct causal mask on-the-fly
Description
Previously, we were allocating causal masks of size `(n_ctx, n_ctx)` for every instantiation of `AbstractAttention`, where `n_ctx` corresponds to the maximum context length. For models with a large maximum context length, this leads to wasteful memory consumption.
This PR
I tried to make the change as conservatively as possible: I took the existing logic for creating the causal mask from `AbstractAttention.__init__` and put it in `AbstractAttention.apply_causal_mask`. The causal mask is constructed at inference time, and its shape is `(cur_ctx_length, cur_ctx_length)`, which is generally much smaller than `(n_ctx, n_ctx)`.
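For reference, here is a minimal sketch of what on-the-fly construction inside an `apply_causal_mask`-style method could look like. This is illustrative rather than the actual TransformerLens code; the function name and the `ignore_value` default are assumptions.

```python
import torch

def apply_causal_mask_on_the_fly(attn_scores: torch.Tensor, ignore_value: float = -1e5) -> torch.Tensor:
    """attn_scores: [batch, head_index, query_pos, key_pos]."""
    query_ctx, key_ctx = attn_scores.shape[-2], attn_scores.shape[-1]
    # Build a mask sized to the current sequence instead of slicing a
    # preallocated (n_ctx, n_ctx) buffer.
    causal_mask = torch.tril(
        torch.ones(query_ctx, key_ctx, dtype=torch.bool, device=attn_scores.device),
        # Offset so the same logic works during cached generation,
        # where key_ctx > query_ctx.
        diagonal=key_ctx - query_ctx,
    )
    return attn_scores.masked_fill(~causal_mask, ignore_value)
```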
I think the causal mask is lightweight enough that this should not cause performance issues. However, there is opportunity for further optimization: we could tie a causal mask buffer to each instantiation of `AbstractAttention`, initialize it with some shape (say `(128, 128)`), and then grow it as needed (i.e. if we're doing a forward pass and the context length is 129, we can regenerate the buffer to be shape `(256, 256)`). This solution would avoid constructing new masks on each forward pass. I didn't implement it here, but we can explore it if we feel it's necessary.
Fixes #479
Type of change
- [x] Bug fix (non-breaking change which fixes an issue)
Checklist:
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
- [x] I have not rewritten tests relating to key interfaces which would affect backward compatibility
Having written out the alternative solution of a cached attention mask that grows as needed, I'm thinking maybe that's better...
It does have the following drawback: if you run a very long sequence, the model will construct a very large causal mask, which then stays cached for the model's lifetime. But this doesn't seem so bad (thanks @collingray for pointing this out): if the mask can fit on the device initially, then it's probably fine to cache it for subsequent usage as well.
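A rough sketch of that grow-as-needed cache is below. The class name, buffer name, and power-of-two growth policy are my assumptions for illustration; this is not what the PR currently implements.

```python
import torch
import torch.nn as nn

class GrowingCausalMask(nn.Module):
    """Caches a lower-triangular mask and regrows it when a longer context appears."""

    def __init__(self, initial_size: int = 128):
        super().__init__()
        self.register_buffer(
            "mask", torch.tril(torch.ones(initial_size, initial_size, dtype=torch.bool))
        )

    def get(self, ctx_len: int, device: torch.device) -> torch.Tensor:
        if ctx_len > self.mask.shape[0]:
            # Grow to the next power of two that covers ctx_len
            # (e.g. ctx_len == 129 -> a (256, 256) mask), then keep it cached.
            new_size = 1 << (ctx_len - 1).bit_length()
            self.mask = torch.tril(
                torch.ones(new_size, new_size, dtype=torch.bool, device=device)
            )
        return self.mask[:ctx_len, :ctx_len].to(device)
```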
I ran the following benchmarks to measure perf impact. The difference in perf doesn't seem significant to me, so I think this simple implementation seems ok. Let me know if you think there's some other benchmark that would be good to check.
```python
#%%
import torch
import time
import gc
from tqdm import tqdm
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    'qwen-7b-chat',
    device='cuda',
    fp16=True,
    dtype=torch.float16,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)
tokenizer = model.tokenizer
torch.set_grad_enabled(False)

#%%
# Testing forward pass perf
runtimes = []
for i in tqdm(range(50)):
    rand_toks = torch.randint(0, model.cfg.d_vocab-20, (1, 2048))
    start = time.time()
    logits = model(rand_toks)
    end = time.time()
    runtimes.append(end-start)
    torch.cuda.empty_cache(); gc.collect()
print(f"Mean: {sum(runtimes)/len(runtimes)} seconds")
print(f"Std: {torch.std(torch.tensor(runtimes))}")

# new impl (on-the-fly):
# Mean: 0.6870033359527588 seconds
# Std: 0.034684497863054276

# old impl (persistent buffer):
# Mean: 0.6738856601715087 seconds
# Std: 0.023556165397167206

#%%
# Testing generate perf
runtimes = []
for i in tqdm(range(20)):
    rand_toks = torch.randint(0, model.cfg.d_vocab-20, (1, 1024))
    start = time.time()
    generation = model.generate(rand_toks, max_new_tokens=32, stop_at_eos=False)
    end = time.time()
    runtimes.append(end-start)
    torch.cuda.empty_cache(); gc.collect()
print(f"Mean: {sum(runtimes)/len(runtimes)} seconds")
print(f"Std: {torch.std(torch.tensor(runtimes))}")

# new impl (on-the-fly):
# Mean: 3.7325199961662294 seconds
# Std: 0.13069213926792145

# old impl (persistent buffer):
# Mean: 3.7241495132446287 seconds
# Std: 0.21790307760238647
```
Thanks for looking into this!
I guess the most efficient way would be to construct it once per model rather than once per head? However, this would potentially break some forms of model parallelism (e.g. with DeepSpeed).
Also pinged you directly with a potential hacky (but more efficient) fix using a static property.
@andyrdt & @alan-cooney Do either of you recall where this was left? @andyrdt, I just merged your branch into the most recent main branch. If you remember the advice Alan gave you and have time to implement it, we can get this merged relatively quickly. Otherwise, if you want to convey that information here, I am happy to make the changes and get this into the main branch.
Hi @bryce13950 - thanks for pinging on this.
The currently-implemented solution in this PR is to construct attention masks for each attention component (i.e. at each layer) on-the-fly. This solution is simple, safe, and doesn't impact perf much (see benchmarks above), but feels a bit hacky since we're reconstructing/reallocating the same mask across many layers.
I think probably the best solution would be to construct a single attention mask at the model level at the beginning of a forward pass, and then pass this attention mask around. This is what I see most recent HuggingFace model implementations do.
For example, the HF Llama implementation constructs a mask at the model level inside of `forward()`, and then plumbs it through the `forward()` of subcomponents, eventually passing it to the attention component's `forward(..., attention_mask, ...)`.
There is a complication here with certain models that don't use the same attention mask at every layer (e.g. GPT-Neo, which alternates between local and global attention across layers). One way to deal with this would be to pass the global attention mask around, and then modify it locally in attention components that are configured to use local attention.
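To make the proposal concrete, here is a hedged sketch of the model-level approach. All names here (`build_causal_mask`, `narrow_to_local`, `window_size`, `uses_local_attn`) are hypothetical and not TransformerLens internals.

```python
import torch

def build_causal_mask(ctx_len: int, device: torch.device) -> torch.Tensor:
    # One (ctx_len, ctx_len) causal mask per forward pass, built at the model level.
    return torch.tril(torch.ones(ctx_len, ctx_len, dtype=torch.bool, device=device))

def narrow_to_local(mask: torch.Tensor, window_size: int) -> torch.Tensor:
    # For layers configured with local attention (e.g. GPT-Neo's alternating
    # pattern), additionally restrict each query to the last `window_size` keys.
    ctx_len = mask.shape[-1]
    band = torch.triu(
        torch.ones(ctx_len, ctx_len, dtype=torch.bool, device=mask.device),
        diagonal=-(window_size - 1),
    )
    return mask & band

# Hypothetical usage inside a model-level forward():
#   mask = build_causal_mask(tokens.shape[-1], tokens.device)
#   for block in blocks:
#       block_mask = narrow_to_local(mask, block.window_size) if block.uses_local_attn else mask
#       resid = block(resid, attention_mask=block_mask)
```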
Let me know what you think about this proposed solution. Also please feel free to take a stab at implementing it (either editing this PR or submitting a new one) - I don't think I'll have time to properly revisit this over the next couple of weeks.
Thanks for getting the details in here. I am in the process of getting all PRs up to date with the main branch and merging whatever is ready, or wrapping up quick final touches on anything that needs them. After that, I want to go through issues as well, in order to clean up the irrelevant ones and organize the rest. This PR is definitely looking like it will be a bit more time-consuming than some of the other PRs so far. When I am done with these upkeep tasks, if this is still pending, I can definitely get back to it and wrap it up. If you find time in the next couple weeks and there still hasn't been any movement on this, then any time you can give to it will definitely help. If you do get it revised, I will make sure to merge it right away. If not, it will just have to wait a little bit longer, but that's not the end of the world.
Abandoning this PR.