outlines
Exllamav2 Integration
This fixes https://github.com/outlines-dev/outlines/issues/1009 Also fixes https://github.com/outlines-dev/outlines/issues/807
The tests I did were:
For loading:
from outlines.integrations.exllamav2 import RegexFilter, TextFilter, JSONFilter, ChoiceFilter
import json
import sys
import torch
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
from pydantic import BaseModel
from typing import Literal
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Cache_8bit,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler, ExLlamaV2DynamicJob
from transformers import AutoTokenizer
import uuid
repo_id = "../Phi-3-mini-128k-instruct-exl2"
paged = False
model_dir = repo_id
total_context = 8192
max_context = 1024
max_batch_size = 4 if paged else 1
max_chunk_size = 1024
max_new_tokens = 1024
healing = True
draft_model = None
draft_cache = None
use_ngram_draft = None
use_ngram = None
config = ExLlamaV2Config(model_dir)
config.max_input_len = max_chunk_size
config.max_attention_size = max_chunk_size ** 2
config.max_seq_len = max_context
model = ExLlamaV2(config)
cache = ExLlamaV2Cache_Q4(
    model,
    max_seq_len = total_context,
    lazy = True
)
tokenizer = ExLlamaV2Tokenizer(config)
hf_tokenizer_kwargs = {}
hf_tokenizer_kwargs.setdefault("padding_side", "left")
hf_tokenizer = AutoTokenizer.from_pretrained(model_dir, **hf_tokenizer_kwargs)
model.load_autosplit(cache, progress = True)
generator = ExLlamaV2DynamicGenerator(
    model = model,
    cache = cache,
    draft_model = draft_model,
    draft_cache = draft_cache,
    tokenizer = tokenizer,
    max_batch_size = max_batch_size,
    use_ngram_draft = use_ngram,
    max_chunk_size = max_chunk_size,
    paged = paged,
)
Choices test:
filters = [
    ChoiceFilter(["bob", "fred"], hf_tokenizer)
]
context_ids = torch.empty((1, 0), dtype = torch.long)
instruction = "Who is better bob or fred?"
print()
print("Assistant:", end = "")
instruction_ids = tokenizer.encode(f"[INST] {instruction} [/INST]", add_bos = True)
context_ids = torch.cat([context_ids, instruction_ids], dim = -1)
generator.enqueue(
    ExLlamaV2DynamicJob(
        input_ids = context_ids,
        max_new_tokens = 1024,
        stop_conditions = [],
        filters=filters
    )
)
eos = False
while not eos:
    results = generator.iterate()
    for result in results:
        if result["stage"] == "streaming":
            eos = result["eos"]
            if "text" in result:
                print(result["text"], end="")
                sys.stdout.flush()
            if "token_ids" in result:
                context_ids = torch.cat([context_ids, result["token_ids"]], dim = -1)
print()
Returns
Assistant:bob
JSON test:
class JSONResponse(BaseModel):
    response: str
    confidence: Literal["low", "medium", "high"]
    is_subjective: Literal["no", "yes", "possibly"]
filters = [
    JSONFilter(JSONResponse, hf_tokenizer)
]
context_ids = torch.empty((1, 0), dtype = torch.long)
instruction = f"Give a sample response in the format of {JSONResponse.schema()} on a movie review of love actually"
print()
print("Assistant: ", end = "")
instruction_ids = tokenizer.encode(f"[INST] {instruction} [/INST]", add_bos = True)
context_ids = torch.cat([context_ids, instruction_ids], dim = -1)
generator.enqueue(
    ExLlamaV2DynamicJob(
        input_ids = context_ids,
        max_new_tokens = 1024,
        stop_conditions = [tokenizer.eos_token_id],
        filters=filters
    )
)
eos = False
while not eos:
    results = generator.iterate()
    for result in results:
        if result["stage"] == "streaming":
            eos = result["eos"]
            if "text" in result:
                print(result["text"], end="")
                sys.stdout.flush()
            if "token_ids" in result:
                context_ids = torch.cat([context_ids, result["token_ids"]], dim = -1)
print()
Returns
Assistant: {"response": "Love Actually is a charming and heartwarming romantic comedy that delivers a delightful experience. The performances by the lead actors, especially Drew Barrymore and Gael García Bernal, are genuinely commendable. The film beautifully blends humor with heart-tugging moments, making it an ideal watch for those in search of a feel-good cinematic experience. Despite some predictable plot trends, the overall impact of the film remains largely positive. Rating: 7/10", "confidence": "medium", "is_subjective": "no"}
Some questions I had for the maintainers:
- Should we handle the prefix logic here? I noticed that some of the exllamav2 filters in their repo ignore the prefix, while one of them uses it.
- Do we want to return the stop tokens? That requires checking every allowed token to see which ones lead to a final state, which may be a bit slower.
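To make the second question concrete, here is a toy sketch (not the PR's code) of what checking every allowed token for a final state would involve. The `TRANSITIONS` table, `FINAL_STATES` set, and `stop_tokens` helper are hypothetical stand-ins for the FSM index that outlines builds; the only point is that the check is linear in the number of allowed tokens at each step.

```python
# Toy FSM for the choice between "bob" and "fred", tokenized as whole words
# plus a few sub-word pieces. All names here are illustrative assumptions.
TRANSITIONS = {
    0: {"bob": 3, "fr": 1, "b": 2},  # state -> {token: next_state}
    1: {"ed": 3},
    2: {"ob": 3},
}
FINAL_STATES = {3}


def allowed_tokens(state):
    """Tokens the filter would let through in `state`."""
    return set(TRANSITIONS.get(state, {}))


def stop_tokens(state):
    """Allowed tokens whose transition lands in a final state.

    This scans every allowed token, which is the extra cost in question.
    """
    return {
        token
        for token, next_state in TRANSITIONS.get(state, {}).items()
        if next_state in FINAL_STATES
    }


print(sorted(allowed_tokens(0)))  # ['b', 'bob', 'fr']
print(sorted(stop_tokens(0)))     # ['bob']
print(sorted(stop_tokens(1)))     # ['ed']
```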
@lapp0 makes sense! Let me try doing this tomorrow.
Thanks so much, please let me know if you have any questions!
@lapp0 sorry for the delay! Two questions.
Background: the current exllamav2 model in outlines (the ExllamaV2 class), as can be seen here, doesn't support filters (exllamav2's equivalent of a logits processor). Filters are mainly used in exllamav2's custom generators such as ExLlamaV2DynamicGenerator and ExLlamaV2DynamicGeneratorAsync.
So my questions are:
- Do we want the logits_processor in SequenceGeneratorAdapter to be converted into an exllamav2 filter, as in their library (similar to how logits_processor is used for llamacpp)? That might involve changing the logic here to use one of the generators. Another option is to use ExLlamaV2Sampler, but then we would be reimplementing the generator logic.
- This depends on the previous question but in this case, do you have a recommended generator? I have mainly used the ExLlamaV2DynamicGenerator which is mainly used for handling multiple asynchronous requests and responses which is not necessarily what I think outlines is going for ex (one request, one response at a time). But it seems like the most well-supported generator in exllamav2.
Sry for the delayed response and let me know if I'm going in the right direction!
Great questions!
Converting it to a filter is a bit hacky IMO, but may be the simplest solution and doesn't require an upstream change.
Alternatively we could apply logits processing directly. The way exllamav2's library is structured makes this tricky. ExLlamaV2Sampler.sample() is a staticmethod, and gen_settings: ExLlamaV2Sampler.Settings is a generate(...) argument; however, the sampler itself is not. I think the only clean way to handle this is an upstream PR:
- Option 1: update ExLlamaV2Sampler.Settings to accept a logits_processor argument, and update def sample() to apply the settings.logits_processor if it exists.
- Option 2: update ExLlamaV2DynamicGenerator.generate() to accept a sampler argument, defaulting to ExLlamaV2Sampler, allowing us to inject our own sampler class.
The first option makes more sense to me, it is generator-class agnostic.
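A minimal sketch of what Option 1 could look like, using toy stand-in classes rather than the real ExLlamaV2 code. `ToySettings`, `ToySampler`, and the `logits_processor(token_ids, logits)` signature are assumptions made for illustration; the actual upstream change may look different.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

# Hypothetical signature: (token ids generated so far, raw logits) -> adjusted logits.
LogitsProcessor = Callable[[List[int], List[float]], List[float]]


@dataclass
class ToySettings:
    """Stand-in for a sampler settings object with an optional hook."""
    logits_processor: Optional[LogitsProcessor] = None


class ToySampler:
    """Stand-in for a sampler whose sample() is a staticmethod."""

    @staticmethod
    def sample(logits: List[float], settings: ToySettings, past_ids: List[int]) -> int:
        # Apply the user-supplied processor first, if one was configured.
        if settings.logits_processor is not None:
            logits = settings.logits_processor(past_ids, logits)
        # Greedy pick keeps the sketch deterministic.
        return max(range(len(logits)), key=logits.__getitem__)


def only_even_tokens(past_ids: List[int], logits: List[float]) -> List[float]:
    """Example processor: forbid odd token ids by setting them to -inf."""
    return [x if i % 2 == 0 else float("-inf") for i, x in enumerate(logits)]


logits = [0.1, 2.0, 0.7, 1.5]
print(ToySampler.sample(logits, ToySettings(), []))                                   # 1
print(ToySampler.sample(logits, ToySettings(logits_processor=only_even_tokens), []))  # 2
```

Because the hook lives on the settings object, any generator that forwards its settings to the sampler would pick it up without changes, which is the sense in which this option is generator-class agnostic.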
> This depends on the previous question but in this case, do you have a recommended generator? I have mainly used the ExLlamaV2DynamicGenerator which is mainly used for handling multiple asynchronous requests and responses which is not necessarily what I think outlines is going for ex (one request, one response at a time). But it seems like the most well-supported generator in exllamav2.
Tbh, I'm not sure how well outlines.models works with asynchronous structured generation. It is a reasonable use case though, and necessary for https://github.com/outlines-dev/outlines/issues/655
@lapp0 Sounds good! I think I'll go with option 1. For this, I think the steps needed are
- [ ] Make a new class inheriting from OutlinesLogitsProcessor, like the structured one, that does not return a mask and only returns the next allowed tokens. I might even override the call method, since conversion to PyTorch might not be necessary (similar to filters). Happy to talk about this further. The main issue is that this changes the logic of the outlines logits processors to resemble filters; the alternative is passing something like torch.ones as the logits and using torch.where on the mask.
- [ ] In models/exllamav2, add a class ExllamaV2SamplerOutlines(ExLlamaV2Sampler) with an optional logits processor
- [ ] See if I can convert the current exllamav2 base model forward logic etc. to ExllamaV2DynamicGenerator
- [ ] Add to the unified dispatcher
- [ ] Test code
Let me know if it looks good. I'll try finishing this within a week or two.
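The alternative mentioned in the first item (feeding dummy logits through an existing mask-based processor and reading the allowed tokens back off the result) can be sketched in plain Python. `mask_processor` below is a hypothetical stand-in for a structured logits processor, and lists stand in for the `torch.ones` / `torch.where` tensors; this is not how either library implements it.

```python
import math
from typing import List, Set


def mask_processor(allowed: Set[int], logits: List[float]) -> List[float]:
    """Hypothetical mask-style processor: disallowed tokens get -inf."""
    return [x if i in allowed else -math.inf for i, x in enumerate(logits)]


def allowed_ids_from_mask(processed: List[float]) -> List[int]:
    """Recover the pass-list a filter would need (the torch.where step)."""
    return [i for i, x in enumerate(processed) if x != -math.inf]


vocab_size = 8
dummy_logits = [1.0] * vocab_size  # plays the role of torch.ones
processed = mask_processor({2, 5, 7}, dummy_logits)
print(allowed_ids_from_mask(processed))  # [2, 5, 7]
```

The round trip touches every vocabulary entry on each step, which is the kind of overhead a processor that returns the allowed tokens directly would avoid.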
Rather than implementing a new logits processor, I'm awaiting correspondence with the ExLlamaV2 maintainer, turboderp, regarding whether a logits_processor argument would be acceptable within their sampler.
@lapp0 interesting! The main reason I was thinking of a new logits processor is that, I think, we would be doing some steps that are redundant with exllamav2's code base. They:
- First get the pass tokens (the next allowed tokens) and apply that filter using CUDA code like ext_c.logit_filter_exclusive(logit_filter, [sorted(list(pass_tokens))])
- Then finally compute the logits in CUDA code
whereas in our case we start from the assumption that the logits have already been computed and then construct the mask, etc.
So I thought some of the steps here overlap with our current logits processor. But yeah very much happy to get advice here since this is just making the exllamav2 filter. And also happy to hear what turboderp thinks.
> So I thought some of the steps here overlap with our current logits processor. But yeah very much happy to get advice here since this is just making the exllamav2 filter. And also happy to hear what turboderp thinks.
Yes, they will have multiple methods of filtering, but given that Outlines has a single logits processor implementation, which is tested against all inference engines, it's likely better to follow the same pattern with ExLlamaV2. This will ensure that bug fixes, optimizations, enhancements, and new features present in one integration are available to all integrations!
I spoke with turboderp on their Discord server; he is open to having a logits_processor argument in ExLlamaV2Sampler.Settings.
Here are the steps I think we should take; let me know what you think:
- i. Update outlines.generate.* generators so they use the default dispatcher for ExLlamaV2 (simply delete the ExLlamaV2 dispatcher in each outlines.generate module). This will ensure the default method of SequenceGeneratorAdapter and outlines.processors is used.
- ii. Implement a turboderp/exllamav2 fork with a logits_processor argument in ExLlamaV2Sampler.Settings which is applied in ExLlamaV2Sampler.sample() (let me know if you'd like to take this over, or if you'd like me to take a shot at it).
- iii. Implement a new model outlines.models.exllamav2 which is compatible with the fork.
- iv. Test it against outlines.models.exllamav2 by adding an exllamav2 fixture to tests/generate/test_generate.py and running pytest -s tests/generate/test_generate.py -k exllamav2
Let me know if you think this is the right path.
Thanks so much for your great work on this PR. The users in the ExLlamaV2 discord were excited to hear about this PR!
@lapp0 wow, didn't know exllamav2 had a discord server! And makes perfect sense. If you can do ii that'll be awesome since I was thinking of this and I couldn't think of a clean way to do it atm. For iii. sounds good. I'll try converting it to the dynamic generator
@isamu-isozaki can you please take a look at this changeset and the provided example json_schema_outlines.py?
https://github.com/lapp0/exllamav2/pull/1
I believe it should provide a sufficient basis for implementing outlines.models.exllamav2.
Let me know if you see anything that should be changed in my implementation. If you have any questions, please do not hesitate! Good luck!
Edit: Also please add "Fixes https://github.com/outlines-dev/outlines/issues/807" to the PR description.
@lapp0 sounds good. And sorry, I got a bit sidetracked by some work. I'll try to get to this by the weekend at the latest. Sorry for the delay!
Sorry for the delay. I finally got the exllamav2 fork built, and I was able to run the current PR's code with the script below, which worked!
import sys
sys.path.append("../outlines-dev")
import outlines
from enum import Enum
from pydantic import BaseModel, constr
model = outlines.models.exl2(
    model_path="turboderp/TinyLlama-1B-32k-exl2",
    cache_q4=True,
    paged=False
)
prompt = """You are a sentiment-labelling assistant.
Is the following review positive or negative?
Review: This restaurant is just awesome!
"""
generator = outlines.generate.choice(model, ["Positive", "Negative"])
answer = generator(prompt)
print(answer)
prompt = "<s>result of 9 + 9 = 18</s><s>result of 1 + 2 = "
answer = outlines.generate.format(model, int)(prompt, max_tokens=1)
print(answer)
generator = outlines.generate.format(model, float)
answer = generator(prompt, max_tokens=10)
print(answer)
generator = outlines.generate.text(model)
unstructured = generator(prompt, max_tokens=30)
generator = outlines.generate.regex(
    model,
    r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)
structured = generator(prompt, max_tokens=30)
print(unstructured)
# What is the IP address of the Google DNS servers?
#
# Passive DNS servers are at DNS servers that are private.
# In other words, both IP servers are private. The database
# does not contain Chelsea Manning
print(structured)
class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"
class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"
class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int
# Construct structured sequence generator
generator = outlines.generate.json(model, Character)
# Draw a sample
seed = 789001
character = generator("Give me a character description", seed=seed)
print(repr(character))
# Character(name='Anderson', age=28, armor=<Armor.chainmail: 'chainmail'>, weapon=<Weapon.sword: 'sword'>, strength=8)
character = generator("Give me an interesting character description", seed=seed)
print(repr(character))
# Character(name='Vivian Thr', age=44, armor=<Armor.plate: 'plate'>, weapon=<Weapon.crossbow: 'crossbow'>, strength=125)
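As a side note, the IPv4 pattern passed to outlines.generate.regex in the script above can be sanity-checked with the standard library alone, without loading a model. This is only a sketch for inspecting the pattern; the `is_ipv4` helper is a hypothetical name and is not part of the PR.

```python
import re

# Same pattern as in the outlines.generate.regex call in the script above.
IPV4 = re.compile(r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)")


def is_ipv4(text: str) -> bool:
    """True if the whole string matches the pattern."""
    return IPV4.fullmatch(text) is not None


for candidate in ["8.8.8.8", "192.168.0.1", "256.1.1.1", "1.2.3"]:
    print(candidate, is_ipv4(candidate))
```

Structured generation constrains the model to strings that the pattern accepts in full, so `fullmatch` is the closest plain-Python analogue.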
The main issue right now is that I can't run the tests because of an error involving pyairports. @lapp0 do you have any advice on how to fix this?
pytest -s tests/generate/test_generate.py -k exllamav2
======================== test session starts =========================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /mnt/d/personal_projects/whiterabbitneo-pentestgpt/outlines-dev
configfile: pyproject.toml
plugins: anyio-3.6.2
collected 0 items / 1 error
=============================== ERRORS ===============================
__________ ERROR collecting tests/generate/test_generate.py __________
tests/generate/test_generate.py:6: in <module>
import outlines.generate as generate
outlines/__init__.py:6: in <module>
import outlines.types
outlines/types/__init__.py:1: in <module>
from . import airports, countries
outlines/types/airports.py:4: in <module>
from pyairports.airports import AIRPORT_LIST
/home/isamu/miniconda3/lib/python3.10/site-packages/pyairports/airports.py:1: in <module>
from pkg_resources import resource_string
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:3663: in <module>
def _initialize_master_working_set():
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:3646: in _call_aside
f(*args, **kwargs)
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:3687: in _initialize_master_working_set
tuple(dist.activate(replace=False) for dist in working_set)
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:3687: in <genexpr>
tuple(dist.activate(replace=False) for dist in working_set)
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:3144: in activate
declare_namespace(pkg)
/home/isamu/miniconda3/lib/python3.10/site-packages/pkg_resources/__init__.py:2542: in declare_namespace
warnings.warn(msg, DeprecationWarning, stacklevel=2)
E DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
E Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
====================== short test summary info =======================
ERROR tests/generate/test_generate.py - DeprecationWarning: Deprecated call to `pkg_resources.declare_nam...
!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!
========================= 1 error in 17.91s ==========================
@isamu-isozaki sorry for the delayed response.
pyairports is an annoying library which has caused a lot of issues for me as well. And the only thing we use the library for is loading the 3 letter airport code list from https://github.com/ozeliger/pyairports/blob/f611ee5a5a82b4e98b22641bb99693d862c802e4/pyairports/data/airport_list.json
A quick and easy hack is to remove the import and run tests again.
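The traceback above shows a `DeprecationWarning` surfacing as a collection error, which is what happens when warnings are configured to be raised as exceptions (that this is the cause here is an inference from the output, not something confirmed in the thread). The sketch below reproduces the mechanism with the standard library only; `legacy_import` is a hypothetical stand-in for the `pkg_resources` call, and ignoring the warning is shown as one possible alternative to removing the import.

```python
import warnings


def legacy_import():
    """Hypothetical stand-in for an import that emits a DeprecationWarning."""
    warnings.warn("Deprecated call to declare_namespace", DeprecationWarning, stacklevel=2)
    return ["AAA", "BBB"]  # placeholder data, not real airport codes


def load_strict():
    """Warnings escalated to errors: the load fails."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        return legacy_import()


def load_tolerant():
    """Same escalation, but DeprecationWarning is explicitly ignored."""
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        # Filters added later take precedence, so this overrides "error".
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        return legacy_import()


try:
    load_strict()
except DeprecationWarning as exc:
    print("strict load failed:", exc)

print("tolerant load:", load_tolerant())
```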
@lapp0 Thanks! I think I'm pretty much done locally: tests
outlines-dev# pytest -s tests/generate/test_generate.py -k exllamav2 -x
============================================================= test session starts ==============================================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /mnt/d/personal_projects/whiterabbitneo-pentestgpt/outlines-dev
configfile: pyproject.toml
plugins: anyio-3.6.2
collected 320 items / 288 deselected / 32 selected
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:48 0:00:00
Loading tokenizer...
Compiling FSM index for all state transitions: 100%|████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 41.80it/s]
Compiling FSM index for all state transitions: 100%|███████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 116.97it/s]
Compiling FSM index for all state transitions: 100%|████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 83.39it/s]
.s......................
========================================== 31 passed, 1 skipped, 288 deselected in 102.91s (0:01:42) ===========================================
and pre-commit:
outlines-dev> pre-commit run --all-files
check for merge conflicts................................................Passed
debug statements (python)................................................Passed
fix end of files.........................................................Passed
trim trailing whitespace.................................................Passed
isort....................................................................Passed
pyupgrade................................................................Passed
flake8...................................................................Passed
black....................................................................Passed
mypy.....................................................................Passed
@lapp0 sry 2 questions
- Should I modify the pyproject.toml to install exllamav2 or do you recommend skipping the tests(like diffusers)?
- How do I add models/is that done by the org side maybe?
> Should I modify the pyproject.toml to install exllamav2 or do you recommend skipping the tests(like diffusers)?
We should allow it in the test dependencies for any platform which supports exllamav2
https://github.com/outlines-dev/outlines/blob/main/pyproject.toml#L59-L63
> How do I add models/is that done by the org side maybe?
Sorry, could you please clarify?
@lapp0 Thanks for your reply! I think I was a bit confused on how to add models but I'm guessing blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 is already in the test environment? Also, sorry currently figuring out how to run exllamav2 on the CPU for the tests as currently the tests seem to be failing because of that
> I think I was a bit confused on how to add models but I'm guessing blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 is already in the test environment?
Yes it's already in the test environment.
If you're interested, after this change, a good upstream change would be allowing model loading via hub uri with hf_hub_download.
> Also, sorry currently figuring out how to run exllamav2 on the CPU for the tests as currently the tests seem to be failing because of that
Don't bother. The hardware isn't supported. Please just add exllamav2 to https://github.com/outlines-dev/outlines/blob/main/tests/generate/conftest.py#L23-L33
This will skip the tests on any machine without CUDA.
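The skip rule described above amounts to a small predicate applied to each parametrized model fixture. The sketch below keeps that predicate free of pytest and torch so it can be read on its own; the fixture-name prefixes and the `skip_reason` helper are assumptions for illustration, not the contents of conftest.py.

```python
from typing import Optional

# Assumed fixture-name prefixes for backends that need a CUDA device.
CUDA_ONLY_PREFIXES = ("model_exllamav2", "model_vllm")


def skip_reason(fixture_name: str, cuda_available: bool) -> Optional[str]:
    """Return a skip reason for CUDA-only fixtures on machines without CUDA."""
    if fixture_name.startswith(CUDA_ONLY_PREFIXES) and not cuda_available:
        return "Skipping test because CUDA is not available"
    return None


print(skip_reason("model_exllamav2", cuda_available=False))
print(skip_reason("model_exllamav2", cuda_available=True))
print(skip_reason("model_transformers", cuda_available=False))
```

In a real conftest the flag would typically come from `torch.cuda.is_available()`, and a non-empty reason would be turned into `pytest.mark.skip(reason=...)` inside the `pytest_collection_modifyitems` hook.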
@lapp0 Got it and thanks! I think I'm only missing coverage which I'll try making tests for once I get time
Hi, just want to pop by and see how it is going. Will this feature be released soon? If there is some dev branch i can try it as well.
> @lapp0 Got it and thanks! I think I'm only missing coverage which I'll try making tests for once I get time
Great, please let me know when you're ready for review!
> Hi, just want to pop by and see how it is going. Will this feature be released soon? If there is some dev branch i can try it as well.
You might be able to get it working with the installation commands below. Please report back with any issues or feedback, it will help with this PR!
pip install git+https://github.com/isamu-isozaki/outlines@exllamav2_filter
pip install git+https://github.com/lapp0/exllamav2@sampler-logits-processor
@remichu-ai Hi! If you have trouble building exllamav2 like I did, you can install outlines at my initial commit to this PR and use the code examples; that should work. However, I have heard of inference speed issues with that approach if you have a slow CPU. I'm not sure how much more performant the latest commit is. You can definitely use this branch to test it out, since what's left is mainly writing tests rather than functionality.
@lapp0 hi! Sorry for more questions. I wrote some tests to cover exllamav2.py, and locally the coverage for exllamav2.py is 100%. But it seems that skipped tests don't count towards coverage (which is the case for this pipeline). Do you happen to know a simple way to fix this? Other than this I think I'm ready for review!
@lapp0 Thanks for review! Let me check it out tomorrow
@lapp0 Thanks for the review. I did all the changes and all my tests passed locally(including pre-commit)
(base) outlines-dev$ pytest -s tests/generate/test_integration_exllamav2.py --cov=outlines.models
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /mnt/d/personal_projects/whiterabbitneo-pentestgpt/outlines-dev
configfile: pyproject.toml
plugins: anyio-3.6.2, cov-5.0.0
collected 19 items
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:31 0:00:00
Loading tokenizer...
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:00
Loading tokenizer...
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:01 0:00:00
Loading tokenizer...
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:00
Loading tokenizer...
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:01 0:00:00
Loading tokenizer...
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 0:00:00
Loading tokenizer...
.
---------- coverage: platform linux, python 3.10.12-final-0 ----------
Name Stmts Miss Branch BrPart Cover Missing
------------------------------------------------------------------------------------
outlines/models/__init__.py 9 0 0 0 100%
outlines/models/exllamav2.py 140 0 62 0 100%
outlines/models/llamacpp.py 154 110 60 0 21% 27-53, 56-57, 62-73, 76-84, 87-89, 92-94, 98, 107, 142, 146, 160-239, 277-293, 332-355, 358-362, 386-407
outlines/models/mlxlm.py 81 72 30 0 8% 25-27, 38-41, 70-122, 147-196, 230-247
outlines/models/openai.py 176 134 58 0 19% 97-105, 138-155, 158, 183-251, 255, 258, 261, 292-313, 318-322, 349-364, 381-388, 394-415, 420, 429-452, 461-484
outlines/models/tokenizer.py 12 0 0 0 100%
outlines/models/transformers.py 168 140 52 0 13% 28-56, 68-82, 87-90, 93-94, 97-106, 109-116, 119, 122-123, 126, 137-138, 163-184, 192-195, 225-253, 268-297, 309-340, 349-368, 371-381, 415-435, 444-452
outlines/models/transformers_vision.py 38 30 14 0 15% 12-13, 46-63, 73, 109-138
outlines/models/vllm.py 78 66 42 0 10% 24-27, 30-42, 87-149, 159, 164-169, 184-188, 208-226
------------------------------------------------------------------------------------
TOTAL 856 552 318 0 31%
================================== 18 passed, 1 skipped in 72.95s (0:01:12) ==================================
(base) outlines-dev$ pytest -s tests/generate/test_generate.py -k exllamav2
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /mnt/d/personal_projects/whiterabbitneo-pentestgpt/outlines-dev
configfile: pyproject.toml
plugins: anyio-3.6.2, cov-5.0.0
collected 320 items / 288 deselected / 32 selected
Loading: blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:37 0:00:00
Loading tokenizer...
Compiling FSM index for all state transitions: 100%|██████████████████████████| 10/10 [00:00<00:00, 45.03it/s]
Compiling FSM index for all state transitions: 100%|██████████████████████████| 25/25 [00:00<00:00, 95.85it/s]
Compiling FSM index for all state transitions: 100%|██████████████████████████| 21/21 [00:00<00:00, 95.23it/s]
Compiling FSM index for all state transitions: 100%|██████████████████████████| 10/10 [00:00<00:00, 96.69it/s]
Compiling FSM index for all state transitions: 100%|█████████████████████████| 25/25 [00:00<00:00, 139.23it/s]
Compiling FSM index for all state transitions: 100%|██████████████████████████| 21/21 [00:00<00:00, 95.51it/s]
Compiling FSM index for all state transitions: 100%|████████████████████████████| 6/6 [00:00<00:00, 73.53it/s]
Compiling FSM index for all state transitions: 100%|████████████████████████████| 8/8 [00:00<00:00, 92.24it/s]
Compiling FSM index for all state transitions: 100%|██████████████████████████| 10/10 [00:00<00:00, 92.73it/s]
...................
========================== 31 passed, 1 skipped, 288 deselected in 85.01s (0:01:25) ==========================
outlines-dev> pre-commit run --all-files
check for merge conflicts................................................Passed
debug statements (python)................................................Passed
fix end of files.........................................................Passed
trim trailing whitespace.................................................Passed
isort....................................................................Passed
pyupgrade................................................................Passed
flake8...................................................................Passed
black....................................................................Passed
mypy.....................................................................Passed
Great job @isamu-isozaki !
I've opened the EXL2 PR for logits processors
https://github.com/turboderp/exllamav2/pull/634
@lapp0 awesome!
@isamu-isozaki I'm not sure whether the ExLlamaV2 PR will be merged soon; it's been a week without comment. To get this out the door, could you please update the Outlines ExLlamaV2 documentation to make the following clear:
- ExLlamaV2 doesn't have logits processor support yet.
- There is a third-party fork which supports logits processors and is compatible with outlines.
- The install command is pip install git+https://github.com/lapp0/exllamav2@sampler-logits-processor
Could you also let me know what build issues you experienced? I didn't run into any but I'd like to ensure the install-from-git command doesn't result in additional confusion.
We can revert the documentation to reference the main ExLlamaV2 branch once the PR is merged.