
Adds the ability to influence token prediction with a logit bias

paolorechia opened this pull request.

Related issue (created by me): https://github.com/turboderp/exllama/issues/103

paolorechia, Jun 26 '23

This PR was squashed today. I added a couple of modifications that made it easier to integrate into guidance. Here's the PR in the guidance repository: https://github.com/microsoft/guidance/pull/298

I wouldn't be surprised if the PR is blocked by Microsoft Research, given that I'm not entirely sure what I'm doing 😂 But it's worth trying 😄

paolorechia, Jul 02 '23

Getting Guidance to work with Exllama is a pretty significant step, since getting structured content out of an LLM is one of the biggest hurdles to application development. (That I know of, at least; the other is context length, which sounds like it may be very nearly solved.)

tensiondriven, Jul 02 '23

@tensiondriven Agreed, hopefully someone more knowledgeable than me will pick up the work so far and help complete the integration. Let’s see 🙂

paolorechia, Jul 02 '23

Just saw this as I sat down to do the same thing. Amazing, thanks @paolorechia.

What specifically needs more work?

qeternity, Jul 06 '23

Hey, @qeternity, my first attempt to support the guidance ‘select’ command was quite hacky. It uses a second trie instead of the inference logprobs, which is different from the transformers implementation. That’s probably the biggest thing to fix in the open PR.
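For contrast with the trie approach, a select based on inference log-probabilities scores each candidate by the sum of the log-probabilities the model assigns to its tokens and picks the highest. The sketch below is a toy under stated assumptions: `next_token_logprobs` and the `TOY_PROBS` table are hypothetical stand-ins for a real forward pass, and this is not the guidance or transformers implementation.

```python
import math

# Hypothetical toy "model": maps a context (tuple of token ids) to the
# probabilities of possible next tokens. A real model would run a forward pass.
TOY_PROBS = {
    (): {1: 0.6, 2: 0.4},
    (1,): {3: 0.9, 4: 0.1},
    (2,): {3: 0.5, 4: 0.5},
}


def next_token_logprobs(context):
    """Return {token_id: log p(token | context)}; unknown contexts give {}."""
    probs = TOY_PROBS.get(tuple(context), {})
    return {tok: math.log(p) for tok, p in probs.items()}


def score_option(prompt_ids, option_ids):
    """Sum the log-probabilities of option_ids, conditioning on each prior token."""
    context = list(prompt_ids)
    total = 0.0
    for tok in option_ids:
        total += next_token_logprobs(context).get(tok, -math.inf)
        context.append(tok)
    return total


def select(prompt_ids, options):
    """Return the option (a list of token ids) with the highest total log-prob."""
    return max(options, key=lambda opt: score_option(prompt_ids, opt))


if __name__ == "__main__":
    print(select([], [[1, 3], [2, 4], [1, 4]]))  # -> [1, 3]
```

Summing log-probabilities favors shorter options, since every extra token can only lower the total; normalizing by option length is one common adjustment.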

Other than that, there’s a lot of work left: covering all parameters, token healing, and general cleanup of the code.

I was hoping to hear back from Microsoft before investing more time, but nothing so far.

paolorechia, Jul 06 '23

I don't think Microsoft feedback is necessary, this will be incredibly useful to a lot of people even if it lives on as a fork.

I think this issue alone already has a handful of people who are likely going to duplicate each other's work otherwise.

qeternity, Jul 06 '23

@qeternity Yes, you’re absolutely right. I suppose either way I’m not very confident in the results of my PR, so this is still very experimental work and needs more testing / refinement.

paolorechia, Jul 06 '23

Anyone else make any progress on this? Was about to start looking at it where you left off, @paolorechia :)

fblissjr, Jul 12 '23

@fblissjr thanks for the interest!

I believe @qeternity made some good progress in his fork.

I’d recommend reading the discussion on https://github.com/microsoft/guidance/pull/298 to pick up some ideas and understand our current problems.

IMHO you could try continuing his fork, or experiment with writing an adapter model class directly in the transformers library.

Let me know if there’s something specific you would like to discuss.

paolorechia, Jul 12 '23

Quoting @qeternity: "this will be incredibly useful to a lot of people even if it lives on as a fork."

I second this. This functionality is crucial for me, since I am working on a C# library that uses the same tricks as guidance, so the feature is useful beyond guidance itself. And it's a small change. @paolorechia @turboderp any chance we can check this in?

(I probably only need the gen_single_token change, currently experimenting with it.)

zmarty, Aug 10 '23

@zmarty I can’t answer that, as I’m not a maintainer of this repo. Guidance integration ended up being harder than we expected, and we all kind of gave up for the time being, but the logit bias feature is rather simple IMO.

Guidance actually uses a custom token sampling algorithm, so it’s more than the simple bias I had imagined.

Would love to see your library code if it’s open source; please share :)

paolorechia, Aug 11 '23

WARNING: I don't know what I'm doing.

@paolorechia Here's what I found so far.

Inside generator.py I added:

    def generate_token_with_bias(self, prefix, logit_bias, startNewRequest = True):
        self.end_beam_search()

        if prefix and len(prefix) > 0:
            ids, mask = self.tokenizer.encode(prefix, return_mask = True, max_seq_len = self.model.config.max_seq_len)

            if (startNewRequest):
                self.gen_begin(ids, mask = mask)
            else:
                self.gen_accept_token(ids)

        token = self.gen_single_token(logit_bias=logit_bias)

        text = self.tokenizer.decode(token)
        return text[0]

The idea here is that from an external client we make an initial request with prefix set to the preliminary prompt, startNewRequest set to True to initialize with gen_begin, and logit_bias set to the token(s) we want to boost.

This initial call works, and if I keep calling it without any prefix or bias, it continues to generate coherent text.

However, it breaks down in the subsequent cases where I want to set startNewRequest to False.

Here is what I want to achieve:

  • First call:
    • prefix: "Once upon a time, "
    • logit_bias: set to letter 't'
    • startNewRequest: True
    • Result: t
  • Second call:
    • prefix: "rolls lived under the "
    • logit_bias: set to letter 'b'
    • startNewRequest: False. This part is important: when it is False, the intent is to append the prefix to the ongoing prompt using self.gen_accept_token(ids)
    • Result: b
  • Subsequent calls:
    • prefix: None
    • logit_bias: None
    • startNewRequest: False
    • Result: garbage starts, because it kind of forgets what happened before

I think the problem is self.gen_accept_token(ids). This function does not seem to update the cache, and it sounds to me like the cache is important?
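A toy model of this failure may help. It assumes (an assumption, not verified against exllama's source) that sampling a token forwards only the last token of the sequence and relies on the cache for everything before it. Under that assumption, appending several tokens without a forward pass leaves a gap in the cache, while a feed step that forwards the new tokens keeps sequence and cache in sync. `ToyGenerator` and its methods are hypothetical names, not the exllama API.

```python
class ToyGenerator:
    """Hypothetical stand-in for a generator with a KV cache.

    `sequence` holds every token the generator thinks is in context;
    `cache` holds the tokens that have actually been run through the model.
    Assumption: sampling forwards only sequence[-1] and trusts the cache
    for everything before it.
    """

    def __init__(self):
        self.sequence = []
        self.cache = []

    def begin(self, ids):
        # New request: preprocess the whole prompt except its last token.
        self.sequence = list(ids)
        self.cache = list(ids[:-1])

    def accept_tokens(self, ids):
        # Append to the sequence only. No forward pass, so the cache lags.
        self.sequence.extend(ids)

    def feed_tokens(self, ids):
        # Append, then forward every not-yet-cached token except the last,
        # which the next sampling step forwards itself.
        self.sequence.extend(ids)
        self.cache.extend(self.sequence[len(self.cache):-1])

    def sample(self):
        # Forward the last token, then append a dummy "sampled" token id.
        self.cache.append(self.sequence[-1])
        token = 100 + len(self.sequence)
        self.sequence.append(token)
        return token

    def in_sync(self):
        # After sampling, the model should have seen everything but the new token.
        return self.cache == self.sequence[:-1]


for use_feed in (False, True):
    g = ToyGenerator()
    g.begin([1, 2, 3])
    g.sample()
    (g.feed_tokens if use_feed else g.accept_tokens)([4, 5])
    g.sample()
    print("feed" if use_feed else "accept", g.in_sync())
# accept False
# feed True
```

In the accept path the toy cache ends up as [1, 2, 3, 5]: the sampled token and the first appended token never reach the model, which matches the "forgets what happened before" symptom.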

zmarty, Aug 11 '23

Just for completeness, below is the SocketIO API that calls the code above.

I call this API from a separate C# project.

    from model import ExLlama, ExLlamaCache, ExLlamaConfig
    from tokenizer import ExLlamaTokenizer
    from generator import ExLlamaGenerator
    import os, glob
    import torch

    from flask import Flask, render_template
    from flask_socketio import SocketIO

    # Directory containing model, tokenizer, generator

    model_directory =  "/ai/ooba/text-generation-webui/models/TheBloke_Llama-2-13B-chat-GPTQ/"

    # Locate files we need within that directory

    tokenizer_path = os.path.join(model_directory, "tokenizer.model")
    model_config_path = os.path.join(model_directory, "config.json")
    st_pattern = os.path.join(model_directory, "*.safetensors")
    model_path = glob.glob(st_pattern)[0]

    # Create config, model, tokenizer and generator

    config = ExLlamaConfig(model_config_path)               # create config from config.json
    config.model_path = model_path                          # supply path to model weights file

    model = ExLlama(config)                                 # create ExLlama instance and load the weights
    tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file

    cache = ExLlamaCache(model)                             # create cache for inference
    generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

    # Configure generator

    generator.disallow_tokens([tokenizer.eos_token_id])

    generator.settings.token_repetition_penalty_max = 1.2
    generator.settings.temperature = 0.95
    generator.settings.top_p = 0.65
    generator.settings.top_k = 100
    generator.settings.typical = 0.5

    app = Flask(__name__)
    socketio = SocketIO(app)

    @socketio.on('from_client')
    def handle_message(msg):
        startNewRequest = bool(msg['startNewRequest'])
        prompt = msg.get('prompt')
        boostedTokens = msg.get('boostedTokens')

        logit_bias = None

        if boostedTokens:
            logit_bias = torch.zeros([1, 1, config.vocab_size])

            for key, value in boostedTokens.items():
                logit_bias[:, :, int(key)] += value

        output = generator.generate_token_with_bias(prompt, logit_bias=logit_bias, startNewRequest=startNewRequest)

        send_message_to_client(output)

    def send_message_to_client(msg):
        socketio.emit('from_server', msg)

    if __name__ == '__main__':
        socketio.run(app, host='0.0.0.0')
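The `boostedTokens` handling above expands a sparse mapping into a dense bias with one entry per vocabulary token. The same idea in plain Python, without torch, is sketched below; `dense_logit_bias` and `apply_bias` are illustrative names, not part of exllama. Keys go through `int()` because JSON object keys always arrive as strings, which is presumably why the handler above does the same.

```python
import math


def dense_logit_bias(boosted_tokens, vocab_size):
    """Expand a sparse {token_id: bias} mapping into one bias per vocab entry.

    Keys may be strings (JSON object keys always are), so they go through int().
    """
    bias = [0.0] * vocab_size
    for key, value in boosted_tokens.items():
        bias[int(key)] += float(value)
    return bias


def apply_bias(logits, bias):
    """Add the bias to one step's logits, element by element."""
    return [logit + b for logit, b in zip(logits, bias)]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 0.5, 0.0]                      # toy logits, vocab of 4
bias = dense_logit_bias({"2": 5.0, "1": -100.0}, 4)
biased = apply_bias(logits, bias)
print(biased.index(max(biased)))                    # -> 2
```

A positive bias makes a token more likely after the softmax, and a large negative bias effectively removes it from consideration.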

zmarty, Aug 11 '23

Ideas I have to continue fixing the cache:

  • Use gen_begin_reuse
  • Or use constraints var in gen_single_token to force it to ingest "rolls lived under the " into the cache token by token.
  • Use gen_feed_tokens

zmarty, Aug 11 '23

OK, the last idea, using gen_feed_tokens, worked. Phew.

Here's the updated code that generates coherent output so far:

    def generate_token_with_bias(self, prefix, logit_bias, startNewRequest = True):
        self.end_beam_search()

        if prefix and len(prefix) > 0:
            ids, mask = self.tokenizer.encode(prefix, return_mask = True, max_seq_len = self.model.config.max_seq_len)

            if (startNewRequest):
                self.gen_begin(ids, mask = mask)
            else:
                self.gen_feed_tokens(ids, mask)

        token = self.gen_single_token(logit_bias=logit_bias)

        text = self.tokenizer.decode(token)
        return text[0]

So what this theoretically allows us to do is:

  • Bias the next token in some direction(s)
    • Optionally feed it tokens that we know it would need to generate anyway, to speed up the process. However, since gen_feed_tokens calls self.model.forward, I have no idea whether this saves any computation. The call to forward sets preprocess_only to True, and inside model.py it looks like that flag skips some work, but I'm not sure how much it saves.
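One way to frame the cost question: the known tokens still have to pass through the model to populate the cache, so any saving comes mainly from processing them in a single batched forward call instead of one sequential call (plus a sampling step) per token. The counter below only illustrates that difference in call count; `CountingModel` is a hypothetical stub, leaving the last fed token for the next sampling step is an assumption, and real timings would need to be measured.

```python
class CountingModel:
    """Hypothetical model stub that only records forward() calls."""

    def __init__(self):
        self.calls = 0
        self.tokens_processed = 0

    def forward(self, token_ids):
        self.calls += 1
        self.tokens_processed += len(token_ids)


def feed_known(model, known):
    # One batched pass over the known tokens; the last one is assumed to be
    # forwarded by the next sampling step instead.
    if len(known) > 1:
        model.forward(known[:-1])


def generate_one_by_one(model, n_tokens):
    # Sequential decoding: one forward pass (and one sample) per token.
    for _ in range(n_tokens):
        model.forward([0])  # dummy token id


fed, generated = CountingModel(), CountingModel()
feed_known(fed, [7, 8, 9, 10, 11])
generate_one_by_one(generated, 5)
print(fed.calls, fed.tokens_processed)              # -> 1 4
print(generated.calls, generated.tokens_processed)  # -> 5 5
```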

zmarty, Aug 11 '23

I see your problem, @zmarty

I had not thought about your use case when I wrote this bias example, interesting.

Unfortunately I’m also not familiar with how the exllama cache works internally, so I don’t think I’ll be of much help here, though it seems like you found a solution.

One caveat I think is worth mentioning: putting a bias just on the first letter does not achieve the same output quality as the original guidance library.

If I’m not mistaken, LLaMA uses a SentencePiece tokenizer, which splits the input into subword units.

In guidance there’s some additional glue between the tokenizer and the bias generation that I didn’t figure out back then, but I think for optimal performance the bias has to be applied to the longest subword units that belong to the tokenizer vocabulary.

I think the trick may be to tokenize what you want to boost, take the first token of the result, and apply the bias to that token (rather than to the first character of the word).

Of course, if you’re choosing between two options, they might share substrings, and handling this efficiently (skipping common tokens) can get complex.
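A sketch of the "bias the first token, then narrow down" idea for choosing between options that share a prefix: tokenize every option once, and at each step allow only the token ids that keep the generated ids a prefix of some option. Shared leading tokens collapse into a single allowed id, and the options only branch where their tokens differ. The token ids below are made up rather than real LLaMA/SentencePiece ids, `allowed_next_tokens` and `constraint_bias` are hypothetical helpers, and a trie would answer the same query without rescanning every option; the linear scan just keeps the example short.

```python
def allowed_next_tokens(options, generated):
    """Token ids that keep `generated` a prefix of at least one option.

    `options` are the candidates already tokenized into lists of token ids;
    `generated` is the list of ids emitted so far for this field.
    """
    n = len(generated)
    return {opt[n] for opt in options if len(opt) > n and opt[:n] == generated}


def constraint_bias(options, generated, vocab_size, penalty=-1e4):
    """Dense bias: 0.0 for allowed tokens, a large negative value elsewhere."""
    allowed = allowed_next_tokens(options, generated)
    return [0.0 if tok in allowed else penalty for tok in range(vocab_size)]


# Made-up tokenizations: "apple pie" -> [5, 9], "apple tart" -> [5, 7], "banana" -> [3]
options = [[5, 9], [5, 7], [3]]
print(sorted(allowed_next_tokens(options, [])))      # -> [3, 5]
print(sorted(allowed_next_tokens(options, [5])))     # -> [7, 9]
print(sorted(allowed_next_tokens(options, [5, 9])))  # -> []
```

An empty set means the generated ids already spell out a complete option, so the caller can stop sampling for that field.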

paolorechia, Aug 11 '23

@paolorechia Yes, I believe the latest update above fixed my cache problem; I will test more. I also forgot to mention that on the C# side I am using the actual Llama tokenizer, not a per-character approach. So hopefully I can achieve all the tricks :)

zmarty, Aug 11 '23

Stale PR, closing.

paolorechia, Jan 31 '24