
[RFC] Supporting KV-cache toggling


Problem

Currently, when we use model.setup_caches(), KV-caches are always updated for every subsequent forward pass on the model. We have valid use cases for using model.setup_caches(), but then not updating the KV-cache on forward passes. The most immediate use-case we have for this is in the eval recipe (see https://github.com/pytorch/torchtune/issues/1621). When specifying multiple tasks in the recipe:

  • If one task is a generation task and another a log-likelihood task and
  • The generation task is evaluated before the log-likelihood task and
  • KV-caching is enabled

Then, KV-caching will still be enabled for the log-likelihood task, which is incorrect behaviour for a number of reasons. In this example we have two kinds of forward passes: one that should run with KV-caching enabled and one that should run with it disabled.
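To make the failure concrete, here is a rough sketch of what happens today (the identifiers and the setup_caches(batch_size, dtype) call are illustrative stand-ins for what the recipe actually passes around, not the real recipe code):

model.setup_caches(batch_size=1, dtype=torch.bfloat16)  # needed for the generation task

# Task 1: generation -- the KV-caches are used, as intended.
toks, _ = generation.generate(model, prompt, max_generated_tokens=256, stop_tokens=stop_tokens)

# Task 2: log-likelihood -- the caches are still live, so every forward pass
# appends to them, which is not what this task wants.
logits = model(tokens)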

Another use case is for my work on improving our RLHF offerings.

  1. For the current PPO recipe, the overall structure looks something like:
batch = batch.to(device)
with torch.no_grad():
    ...
    completions = policy_model.generate(...)

loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()

This isn't 1:1 but you get the gist. Here, we have two kinds of forward passes: one under torch.no_grad() with KV-caching, and one with requires_grad=True and without KV-caching (in fact, we should never really have the case where requires_grad=True + KV-caching, but this part is kind of relevant for compile).

  2. For a LoRA PPO recipe I've been working on, the structure is a bit different:
batch = batch.to(device)
with torch.no_grad():
    ...
    completions = policy_model.generate(...)
    with torchtune.modules.peft.disable_adapter(policy_model):
        ref_logits = policy_model(completions)

loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()

So now there are three kinds of forward passes: two under torch.no_grad() (one with KV-caching and one without), and one with requires_grad=True and without KV-caching.

We want to support these use cases in a compile-friendly manner.

Solution

1) Context-managing KV-caching

To me this feels like the most user-friendly solution. We define a context manager like:

import contextlib

@contextlib.contextmanager
def enable_kv_cache(model):
    if not model.caches_are_enabled():
        raise ValueError("KV-caches have not been set up; call model.setup_caches() first.")
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = False

This relies on a small modification to MultiHeadAttention:


class MultiHeadAttention(...):
    def __init__(...):
        ...
        self.cache_enabled = False

    def forward(...):
        ...
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
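For illustration, here is a rough sketch of how the PPO loop above might use these two pieces together (the variable names and the setup_caches arguments are assumed, not the actual recipe code):

policy_model.setup_caches(batch_size=batch_size, dtype=dtype)

with torch.no_grad(), enable_kv_cache(policy_model):
    # generation: kv_cache.update() runs because cache_enabled is True here
    completions = policy_model.generate(...)

# back outside the context manager, cache_enabled is False again, so the
# gradient-carrying forward pass inside _ppo_step never touches the caches
loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()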

The UX for this change would rely on users always having to use with enable_kv_cache(model): if they wanted to use KV-caching. This is nice because the behaviour is quite explicit. However, specifically for the eval recipe, KV-caching is configurable, so we maybe don't want to error out. Two options:

  1. Inspired by torch.inference_mode(mode: bool):
@contextlib.contextmanager
def kv_cache_mode(model, mode=True):
    if not model.caches_are_enabled():
        raise ValueError()
    for layer in model.layers:
        layer.attn.cache_enabled = mode
    try:
        yield
    finally:
        # flip back on exit: with mode=True the caches are disabled again,
        # with mode=False they are re-enabled
        for layer in model.layers:
            layer.attn.cache_enabled = not mode


with kv_cache_mode(self._model, mode=self._enable_kv_cache):
    toks, _ = generation.generate(
        self._model,
        maybe_padded_context,
        max_generated_tokens=self.max_gen_toks,
        temperature=temperature,
        top_k=None,  # do_sample is not supported currently
        stop_tokens=self._tokenizer.stop_tokens,
    )

I'm not 100% sure about this one because I can't imagine many use cases where model.caches_are_enabled() is True but mode=False. Maybe if your default is inference/KV-caching and you want to disable it for a single pass?
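As a hypothetical sketch of that last case, where caches are set up and normally in use but one forward pass should skip them:

# caches were set up earlier and are used by default in this recipe
with kv_cache_mode(model, mode=False):
    # cache_enabled is False for every layer, so kv_cache.update() is skipped
    logits = model(tokens)
# on exit, cache_enabled flips back to True for subsequent cached passes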

  2. Don't error out on model.caches_are_enabled(). If you try using this without setup_caches, it's a no-op and nothing will happen:
@contextlib.contextmanager
def kv_cache_mode(model):
    if not model.caches_are_enabled():
        # maybe warn here?
        yield
        return
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = False
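With this variant, the eval recipe could wrap generation unconditionally: if the config never asked for KV-caching, setup_caches() was never called and the context manager is just a no-op (a sketch, reusing the recipe snippet from above):

with kv_cache_mode(self._model):
    toks, _ = generation.generate(
        self._model,
        maybe_padded_context,
        max_generated_tokens=self.max_gen_toks,
        temperature=temperature,
        top_k=None,  # do_sample is not supported currently
        stop_tokens=self._tokenizer.stop_tokens,
    )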

I can confirm the above solutions work with compile (I reserve the right to retract this claim in light of any future knowledge).

2) Rely on inference_mode

We could differentiate forward passes which require KV-caching by using torch.inference_mode() (as we currently do by decorating _generation.generate). Then the change would simply be:


class MultiHeadAttention(...):
    def __init__(...):
        ...

    def forward(self, x, ...):
        ...
        if self.kv_cache is not None and x.is_inference():
            k, v = self.kv_cache.update(k, v)
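To spell out the resulting behaviour on the caller's side (identifiers are illustrative): nothing is toggled explicitly, and cache updates are driven entirely by whether the activations are inference tensors:

# generation.generate already runs under torch.inference_mode(), so inside it
# x.is_inference() is True and the KV-caches are updated
toks, _ = generation.generate(model, prompt, max_generated_tokens=256)

# a regular forward pass outside inference_mode: x.is_inference() is False,
# so kv_cache.update() is skipped even though the caches are set up
logits = model(tokens)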

However, this will only work with compile on nightlies until the next release (see https://github.com/pytorch/pytorch/pull/136450).

This is a very minimal change but very non-obvious, so I don't really like it. Open to thoughts though.

SalmanMohammadi · Sep 25 '24