Add continuous generation
This PR is for solving https://github.com/outlines-dev/outlines/issues/667.
API
```python
import outlines
from outlines.generate import continuous

generator = outlines.generate.text(model)
generator_c = continuous(generator)
response = generator_c(prompt, max_tokens=30)
```
continuous
`continuous` wraps any `SequenceGenerator` object; it could be:

- `outlines.generate.choice`
- `outlines.generate.text`
- `outlines.generate.json`
- ...
The `continuous` wrapper allows the generator to save the state of a sequence. This means that if you continuously generate a sequence as shown below:
```python
import outlines
from outlines.generate import continuous

generator = outlines.generate.text(model)
generator_c = continuous(generator)
response_1 = generator_c(prompt, max_tokens=100)
response_2 = generator_c(response_1)
```
the KV cache will be saved (under some conditions). Algorithms such as beam search can then be used to optimize the whole sequence rather than each part separately:
```python
import outlines
from outlines.generate import continuous
from outlines.samplers import BeamSearchSampler

generator = outlines.generate.text(model, sampler=BeamSearchSampler(3))
generator_c = continuous(generator)
response_1 = generator_c(prompt, max_tokens=100)
response_2 = generator_c(response_1)
```
You can mix different types of `SequenceGenerator` objects:
```python
import outlines
from outlines.generate import continuous

generator_text = outlines.generate.text(model)
generator_choice = outlines.generate.choice(model, ["Positive", "Negative"])

generator_text_c = continuous(generator_text)
generator_choice_c = continuous(generator_choice)

response_1 = generator_text_c(prompt, max_tokens=100)
response_2 = generator_choice_c(response_1)
```
Once a prompt is given to the `continuous` wrapper, the result is a `SequenceState` object:
```python
class SequenceState:
    token_ids: torch.Tensor        # generated token ids
    weights: torch.Tensor          # cumulative sequence weights
    attention_masks: torch.Tensor  # attention masks for the token ids
    kv_cache: torch.Tensor         # saved key/value cache, when available
    tokenizer: "Tokenizer"         # tokenizer used to decode the state
```
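For illustration, a minimal sketch of inspecting these fields on a state returned by the wrapper (assuming the attributes are exposed exactly as named above, and reusing `generator_c` and `prompt` from the API example):

```python
state = generator_c(prompt, max_tokens=10)

print(state.token_ids.shape)  # token ids for each batch entry and sample
print(state.weights)          # cumulative weights of the sequences
```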
SequenceState
Indexing
Each `SequenceState` has three dimensions: `SequenceState[batch_key: Union[int, slice], sample_key: Union[int, slice], ids_size_key: Union[int, slice]]`.
However, there are three cases where this is handled differently:
- `batch_size == 1` and `sample_size == 1`: use `SequenceState[ids_size_key: Union[int, slice]]` instead of `SequenceState[0, 0, ids_size_key: Union[int, slice]]`.
- `batch_size == 1`: use `SequenceState[sample_key: Union[int, slice], ids_size_key: Union[int, slice]]` instead of `SequenceState[0, sample_key: Union[int, slice], ids_size_key: Union[int, slice]]`.
- `sample_size == 1`: use `SequenceState[batch_key: Union[int, slice], ids_size_key: Union[int, slice]]` instead of `SequenceState[batch_key: Union[int, slice], 0, ids_size_key: Union[int, slice]]`.
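A hypothetical illustration of these rules (assuming `state` was produced by the wrapper):

```python
# Full three-dimensional indexing: batch entry 0, sample 1, first part of the ids.
part = state[0, 1, :50]

# batch_size == 1 and sample_size == 1: only the ids key is needed.
part = state[:50]

# batch_size == 1: sample key plus ids key.
part = state[1, :50]

# sample_size == 1: batch key plus ids key.
part = state[2, :50]
```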
Operations
You can apply two operations on a `SequenceState`:

- Slicing
- Adding (a `SequenceState` to a `SequenceState`, and a `SequenceState` to a prompt)
Adding
- `SequenceState` to a `SequenceState`

  This won't save the first part of the KV cache for the moment, but it does accumulate the weights of both sequences. I don't see how to implement it: the KV cache implementation from HuggingFace accepts either (1) a `None` value or (2) a KV cache whose context size is exactly one less than that of the `token_ids`. I ran an experiment where I use the model to compute (or complete) the KV cache of the second sequence so as to satisfy (2); the function is called `complete_kv_cache_from_token_ids`, but it isn't used because it's slow.
- `SequenceState` to a prompt

  This will reinitialize everything.
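A minimal sketch of both additions (hypothetical, assuming `+` is the operator exposed for these operations):

```python
# SequenceState + SequenceState: the weights accumulate, but the first part
# of the KV cache is not kept (see the note above).
combined = state_1 + state_2

# SequenceState + prompt: the state is reinitialized from scratch.
restarted = state_1 + "Now answer in one word:"
```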
Slicing
Conditions under which the KV cache is saved:

- The slice considers only one element (`batch_size_after_the_slice == 1`, `sample_size_after_the_slice == 1`); slicing more than one element will reset the KV cache. This condition includes the base case where `(batch_size == 1, sample_size == 1)`.
- The slice starts from the first index of the prompt (`SequenceState[..., :M]`, `SequenceState[..., 0:M]`).
There are some technical intricacies that prevent saving the KV cache even under conditions 1 and 2; see the `[NOTE]` and `[SPECIAL CASE]` flags in the `token_level_slice_from_string_level_slice` utility.

It's also one of the reasons I didn't try to make the KV cache work when `batch_size > 1` and `num_samples > 1`: the complexity-to-usefulness tradeoff seems way off to me.
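For example (hypothetical, for a state with `batch_size == 1` and `sample_size == 1`):

```python
# Starts at the first index and yields a single element: the KV cache is preserved.
kept = state[:100]

# Does not start at the first index: the KV cache is reset.
reset = state[20:100]
```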
Using list(SequenceState)
`list(SequenceState)` converts the `SequenceState` object into a list of strings.
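For instance (hypothetical):

```python
# Decode the state into a list of strings.
texts = list(state)
print(texts[0])
```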
Exceptions
Three types of exceptions can be raised while using the `continuous` wrapper:

- `SampleMismatch`: raised when (1) a sequence's samples are sliced and the result is then passed back to the wrapper (a mismatch between the number of samples in the sequence and the number in the generator), or (2) two sequences with different numbers of samples are added.
- `BatchMismatch`: raised when two sequences with different batch sizes are added.
- `SlicingError`: raised when the slice doesn't allow the KV cache to be saved; it is handled by resetting the KV cache.
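A sketch of catching one of these (the import path below is a guess; adjust it to wherever the exception classes are actually defined):

```python
# Hypothetical import path for the exception classes.
from outlines.generate.continuous import BatchMismatch

try:
    combined = state_a + state_b
except BatchMismatch:
    # The two states have different batch sizes; handle or regenerate here.
    ...
```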
FLAGs
You will see multiple flags that I've put in the code comments:

- `[NOTE]`: general notes explaining how I approached some problems.
- `[QUESTION]`: questions I had while coding certain mechanisms.
- `[POTENTIAL BUG]`: lines of code that could potentially trigger bugs.
Modifications
- `GenerationState` returns `attention_masks` as well.
- `sequence_generator` takes `kv_cache` as a keyword argument with `None` as the default value.
test_continuous.py
These are some tests I added for the different parts of the code.
PS: I use the name `SequenceState` instead of `Sequence` just because it made the coding more obvious to me; tell me if you want to switch it back to `Sequence`.
This is cool, I really like the `continuous` function, to the point that I'd consider doing the same thing for `stream`. I'll review the PR soon, please bear with me.