
`Write` seems to select one token rather than write them all

Open · prhbrt opened this issue 1 year ago · 1 comment

Describe the issue as clearly as possible:

According to the documentation, the `Write` instruction writes several tokens at a given stage, i.e. these tokens are forced into the decoder's/LLM's input without being generated first. In practice, however, as the example below shows, only one of these tokens is selected, using the sampler.

This is consistent with the code treating `Write` and `Generate` interchangeably: only their `tokens` field is read, like here.
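If only `tokens` is read, both instruction types reduce to the same thing: a set of allowed next tokens. The following is a minimal sketch of that assumed step, not outlines code: every logit outside the allowed set is masked to -inf and the sampler draws exactly one id. `Write` and `Generate` here are stand-in dataclasses and `mask_and_sample` is a hypothetical name.

```python
import math
import random
from dataclasses import dataclass
from typing import List


# Hypothetical stand-ins for outlines' instructions: like the real ones,
# each carries its token ids in a `tokens` field.
@dataclass
class Write:
    tokens: List[int]


@dataclass
class Generate:
    tokens: List[int]


def mask_and_sample(instruction, logits: List[float], rng: random.Random) -> int:
    """Assumed generation step: reads only `instruction.tokens`,
    whether the instruction is a Write or a Generate."""
    allowed = set(instruction.tokens)
    # Tokens outside the allowed set get a logit of -inf.
    masked = [x if i in allowed else -math.inf for i, x in enumerate(logits)]
    # Unnormalised softmax weights; exp(-inf) == 0.0, so masked ids are never drawn.
    weights = [math.exp(x) for x in masked]
    # Exactly one token id is sampled per step.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
logits = [0.5, 1.0, -0.3, 2.0, 0.1, 0.0]
write = Write([1, 3, 4])
picked = mask_and_sample(write, logits, rng)
print(picked in write.tokens)  # True: one token from the list, not all three
```

Under this model, `Write([1, 3, 4])` behaves exactly like `Generate([1, 3, 4])`: a constraint on the next token rather than a request to emit all three.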

Steps/code to reproduce the bug:

device = "cuda:2"

import outlines
import torch

from outlines.fsm.guide import Guide, Write, Generate, Instruction
from outlines.generate.api import SequenceGenerator
from outlines.samplers import multinomial
from outlines.models import Transformers
from transformers import AutoTokenizer, AutoModelForCausalLM


class PedroGuide(Guide):
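  def __init__(self, **kwargs):
    # copy() calls PedroGuide() without kwargs and sets the fields itself.
    if {'text', 'tokenizer'} <= set(kwargs):
      self.text = kwargs['text']
      self.tokenizer = kwargs['tokenizer']
      # Drop the leading BOS token that the tokenizer prepends.
      self.tokens = self.tokenizer(self.text).input_ids[1:]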


  def get_next_instruction(self, state: int) -> Instruction:
    return Write(self.tokens)

  def get_next_state(self, state: int, token_id: int) -> int:
    return state + 1

  def is_final_state(self, state: int) -> bool:
    return state > 0

  def copy(self) -> 'PedroGuide':
    guide = PedroGuide()
    guide.tokenizer = self.tokenizer
    guide.tokens = self.tokens
    guide.text = self.text
    return guide


if 'model' not in globals():
  model_name = "mistralai/Mistral-7B-Instruct-v0.2"
  model = AutoModelForCausalLM.from_pretrained(model_name)

  if device is not None:
    model = model.to(device)


if 'tokenizer' not in globals():
  tokenizer = AutoTokenizer.from_pretrained(model_name)


if 'outlines_model' not in globals():
  outlines_model = Transformers(model, tokenizer)


guide = PedroGuide(tokenizer=tokenizer, text="Say Pedro! And then provide a recipe on muffins.")

sampler = multinomial(temperature=2.)
generator = SequenceGenerator(guide, outlines_model, sampler, device)


response = generator("""
Just say something:
"""[1:-1])


response

Expected result:

"Say Pedro! And then provide a recipe on muffins."

Error message:

Only one token is returned.
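Assuming the one-token-per-step behaviour described above, the guide's own state logic also ends generation after a single token: `is_final_state` returns `True` once `state > 0`, i.e. right after the first transition. The toy re-run below uses a hypothetical `ToyPedroGuide` with the same state logic as `PedroGuide`, a stand-in `Write` dataclass, and a uniform random choice in place of the real sampler.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class Write:  # hypothetical stand-in for outlines' Write instruction
    tokens: List[int]


class ToyPedroGuide:
    """Same state logic as PedroGuide above, without a tokenizer."""

    def __init__(self, tokens: List[int]):
        self.tokens = tokens

    def get_next_instruction(self, state: int) -> Write:
        return Write(self.tokens)

    def get_next_state(self, state: int, token_id: int) -> int:
        return state + 1

    def is_final_state(self, state: int) -> bool:
        return state > 0


def run(guide, rng: random.Random, max_steps: int = 50) -> List[int]:
    """Assumed loop: one sampled token per step until the guide is final."""
    state, out = 0, []
    for _ in range(max_steps):
        if guide.is_final_state(state):
            break
        allowed = guide.get_next_instruction(state).tokens
        token = rng.choice(allowed)  # uniform stand-in for the sampler
        out.append(token)
        state = guide.get_next_state(state, token)
    return out


tokens = [101, 102, 103, 104]
generated = run(ToyPedroGuide(tokens), random.Random(0))
print(len(generated))  # 1
```

With the documented semantics, where a single `Write` step emits the whole list, the same state logic would yield the full text.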

Outlines/Python version information:

Version information

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
0.0.41
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]

Context for the issue:

I'm trying to write my own guide that steers both the 'train of thought' and the syntax by forcing particular tokens to be generated at certain points.
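Until `Write` emits several tokens per step, one possible workaround for this use case is to force one token per step: return a `Write` holding only the next token of the forced text, and use the state as an index into that token list. The sketch assumes the same method names as `PedroGuide`; `ForcedTextGuide`, `run` and the `Write` dataclass are hypothetical stand-ins, not outlines code.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class Write:  # hypothetical stand-in for outlines' Write instruction
    tokens: List[int]


class ForcedTextGuide:
    """Hypothetical guide forcing `tokens` one at a time.
    The state is the index of the next token to force."""

    def __init__(self, tokens: List[int]):
        self.tokens = tokens

    def get_next_instruction(self, state: int) -> Write:
        # A single allowed token leaves the sampler no choice.
        return Write([self.tokens[state]])

    def get_next_state(self, state: int, token_id: int) -> int:
        return state + 1

    def is_final_state(self, state: int) -> bool:
        return state >= len(self.tokens)

    def copy(self) -> "ForcedTextGuide":
        return ForcedTextGuide(list(self.tokens))


def run(guide, rng: random.Random, max_steps: int = 50) -> List[int]:
    """Assumed loop: one sampled token per step until the guide is final."""
    state, out = 0, []
    for _ in range(max_steps):
        if guide.is_final_state(state):
            break
        token = rng.choice(guide.get_next_instruction(state).tokens)
        out.append(token)
        state = guide.get_next_state(state, token)
    return out


tokens = [101, 102, 103, 104]
print(run(ForcedTextGuide(tokens), random.Random(0)))  # [101, 102, 103, 104]
```

The trade-off is speed: each forced token still costs a model forward pass, whereas a multi-token `Write` as documented could in principle skip them.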

prhbrt, Jun 05 '24

Reconfirmed the issue with:

0.1.14
Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0]
aiohttp==3.9.5
aiosignal==1.3.1
airportsdata==20241001
annotated-types==0.7.0
anyio==4.4.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.15.0
beautifulsoup4==4.12.3
bleach==6.1.0
boxly==0.1.0
certifi==2024.6.2
cffi==1.16.0
charset-normalizer==3.3.2
cloudpickle==3.0.0
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
datasets==2.19.2
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
diskcache==5.6.3
exceptiongroup==1.2.1
executing==2.0.1
fastjsonschema==2.19.1
filelock==3.14.0
filetype==1.2.0
fonttools==4.55.3
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2024.3.1
ftfy==6.3.1
genson==1.3.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.23.3
idna==3.7
interegular==0.3.3
ipykernel==6.29.4
ipython==8.25.0
ipython-genutils==0.2.0
ipywidgets==7.8.1
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpointer==2.4
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.2
jupyter_core==5.7.2
jupyter_server==2.14.1
jupyter_server_terminals==0.5.3
jupyterlab==4.2.1
jupyterlab-widgets==1.1.7
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.2
kiwisolver==1.4.8
lark==1.1.9
llvmlite==0.42.0
MarkupSafe==2.1.5
matplotlib==3.10.0
matplotlib-inline==0.1.7
mistral_common==1.5.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
notebook==7.2.0
notebook_shim==0.2.4
numba==0.59.1
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
opencv-python==4.10.0.84
outlines==0.1.14
outlines_core==0.1.26
outlines_ocr_guidance==0.1.1
overrides==7.7.0
packaging==24.0
pandas==2.2.2
pandocfilters==1.5.1
parso==0.8.4
pdftext==0.4.1
pexpect==4.9.0
pillow==10.4.0
platformdirs==4.2.2
prometheus_client==0.20.0
prompt_toolkit==3.0.46
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyairports==2.1.1
pyarrow==16.1.0
pyarrow-hotfix==0.6
pycountry==24.6.1
pycparser==2.22
pydantic==2.7.3
pydantic-settings==2.7.1
pydantic_core==2.18.4
Pygments==2.18.0
pyparsing==3.2.1
pypdfium2==4.30.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.1
PyYAML==6.0.1
pyzmq==26.0.3
qtconsole==5.5.2
QtPy==2.4.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.18.1
safetensors==0.4.3
scikit-learn==1.6.0
scipy==1.15.0
Send2Trash==1.8.3
sentencepiece==0.2.0
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
stack-data==0.6.3
surya-ocr==0.8.3
sympy==1.13.1
tabulate==0.9.0
terminado==0.18.1
threadpoolctl==3.5.0
tiktoken==0.7.0
tinycss2==1.3.0
tokenizers==0.19.1
tomli==2.0.1
torch==2.5.1
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
transformers==4.41.2
triton==3.1.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.1
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.1
wcwidth==0.2.13
webcolors==24.6.0
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==3.6.6
xxhash==3.4.1
yarl==1.9.4

prhbrt, Feb 03 '25