[RFC] Supporting KV-cache toggling
Problem
Currently, when we use `model.setup_caches()`, KV-caches are always updated for every subsequent forward pass on the model. We have valid use cases for using `model.setup_caches()`, but then not updating the KV-cache on forward passes. The most immediate use case we have for this is in the eval recipe (see https://github.com/pytorch/torchtune/issues/1621). When specifying multiple tasks in the recipe:
- If one task is a generation task and another a log-likelihood task and
- The generation task is evaluated before the log-likelihood task and
- KV-caching is enabled
Then, KV-caching will still be enabled for the log-likelihood task, which is incorrect behaviour for a number of reasons. In this example, we have two kinds of forward passes: one that should run with KV-caching enabled and one that should run with it disabled.
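To make the failure mode concrete, here is a minimal sketch of that eval-recipe scenario (the task structure and `generate` arguments are illustrative, not the recipe's actual code):

```python
# Illustrative sketch only -- not the eval recipe's actual code.
model.setup_caches(batch_size=8, dtype=torch.bfloat16)

# Task 1: generation task -- KV-caching is exactly what we want here.
tokens = generation.generate(model, prompt, max_generated_tokens=256)

# Task 2: log-likelihood task -- we only want logits over full sequences,
# but because caches were set up above, every forward pass below still
# writes into the KV-cache.
logits = model(batch["tokens"])
```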
Another use case is for my work on improving our RLHF offerings.
- For the current PPO recipe, the overall structure looks something like:
```python
batch = batch.to(device)

with torch.no_grad():
    ...
    completions = policy_model.generate(...)

loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()
```
This isn't 1:1 but you get the gist. Here, we have two kinds of forward passes: one under `torch.no_grad()` with KV-caching, and one with `requires_grad=True` and without KV-caching (in fact, we should never really have the case of `requires_grad=True` + KV-caching, but this part is kind of relevant for compile).
- For a LoRA PPO recipe I've been working on, the structure is a bit different:
```python
batch = batch.to(device)

with torch.no_grad():
    ...
    completions = policy_model.generate(...)
    with torchtune.modules.peft.disable_adapter(policy_model):
        ref_logits = policy_model(completions)

loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()
```
So now there are three kinds of forward passes: two under `torch.no_grad()` (one with KV-caching and one without), and one with `requires_grad=True` and without KV-caching.
We want to support these use cases in a compile-friendly manner.
Solution
1) Context-managing KV-caching
To me this feels like the most user-friendly solution. We define a context manager like:
```python
import contextlib


@contextlib.contextmanager
def enable_kv_cache(model):
    """Enable KV-cache updates for the duration of the block."""
    if not model.caches_are_enabled():
        raise ValueError("KV-caches are not set up; call model.setup_caches() first.")
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = False
```
This relies on a small modification to `MultiHeadAttention`:
```python
class MultiHeadAttention(...):
    def __init__(...):
        ...
        self.cache_enabled = False

    def forward(...):
        ...
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
```
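With both pieces in place, usage in a recipe would look roughly like the sketch below (the `generate` call and its arguments are illustrative):

```python
# Illustrative usage sketch of the proposed context manager.
model.setup_caches(batch_size=1, dtype=torch.bfloat16)

with enable_kv_cache(model):
    # KV-cache updates happen only inside this block.
    tokens = generation.generate(model, prompt, max_generated_tokens=256)

# Outside the block, forward passes leave the KV-cache untouched.
logits = model(batch["tokens"])
```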
The UX for this change would rely on users always having to use `with enable_kv_cache(model):` if they wanted to use KV-caching. This is nice because the behaviour is quite explicit. However, specifically for the eval recipe, KV-caching is configurable, so we maybe don't want to error out. Two options:
- Inspired by `torch.inference_mode(mode: bool)`:
```python
@contextlib.contextmanager
def kv_cache_mode(model, mode=True):
    if not model.caches_are_enabled():
        raise ValueError("KV-caches are not set up; call model.setup_caches() first.")
    for layer in model.layers:
        layer.attn.cache_enabled = mode
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = not mode
```
```python
with kv_cache_mode(self._model, mode=self._enable_kv_cache):
    toks, _ = generation.generate(
        self._model,
        maybe_padded_context,
        max_generated_tokens=self.max_gen_toks,
        temperature=temperature,
        top_k=None,  # do_sample is not supported currently
        stop_tokens=self._tokenizer.stop_tokens,
    )
```
I'm not 100% sure about this one because I can't imagine many use cases where `model.caches_are_enabled()` is `True` but `mode=False`. Maybe if your default is inference/KV-caching and you want to disable it for a single pass?
- Don't error out on `model.caches_are_enabled()`. If you try using this without `setup_caches()`, it is a no-op and nothing will happen:
```python
@contextlib.contextmanager
def kv_cache_mode(model):
    if not model.caches_are_enabled():
        # maybe warn here?
        yield
        return
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = False
```
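With this no-op variant, the eval recipe can wrap generation unconditionally and let `setup_caches()` alone decide whether caches are updated. A rough sketch (the `batch_size`/`dtype` attributes here are illustrative, not the recipe's actual fields):

```python
# Illustrative sketch: cache setup stays configurable, the wrapper is unconditional.
if self._enable_kv_cache:
    self._model.setup_caches(batch_size=self.batch_size, dtype=self._dtype)

# No-op when caches were never set up; enables cache updates otherwise.
with kv_cache_mode(self._model):
    toks, _ = generation.generate(
        self._model,
        maybe_padded_context,
        max_generated_tokens=self.max_gen_toks,
    )
```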
I can confirm the above solutions work with compile (I reserve the right to retract this claim in light of any future knowledge).
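For intuition on why this plays well with compile: `cache_enabled` is a plain Python bool, so Dynamo guards on it and compiles one graph per value rather than graph-breaking. A toy, torchtune-free sketch of that behaviour:

```python
import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cache_enabled = False
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        if self.cache_enabled:  # stand-in for the kv_cache.update() branch
            y = y + 1
        return y


toy = Toy()
compiled = torch.compile(toy)
x = torch.randn(2, 4)

compiled(x)               # compiles the "cache disabled" graph
toy.cache_enabled = True  # what the context manager flips
compiled(x)               # guard miss -> one extra specialization, not a graph break
```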
2) Rely on `inference_mode`
We could differentiate forward passes which require KV-caching by using `torch.inference_mode()` (which we currently decorate `_generation.generate` with). Then the change would simply be:
```python
class MultiHeadAttention(...):
    def __init__(...):
        ...

    def forward(self, x, ...):
        ...
        if self.kv_cache is not None and x.is_inference():
            k, v = self.kv_cache.update(k, v)
```
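This works because tensors produced under `torch.inference_mode()` are inference tensors, and `Tensor.is_inference()` reports that; a quick illustration:

```python
import torch

x = torch.randn(2, 4)
print(x.is_inference())  # False: built outside inference mode, cache update skipped

with torch.inference_mode():
    y = torch.randn(2, 4)
    print(y.is_inference())  # True: inference tensor, cache update would run
```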
However, this will only work with compile on nightlies until the next release (see https://github.com/pytorch/pytorch/pull/136450).
This is a very minimal change but very non-obvious, so I don't really like it. Open to thoughts though.