[support Qwen models] AssertionError: Cannot intersect further patterns if '*' has already been handled.
I'm trying to load a Qwen model with LMQL, but I keep encountering the error below. I've also tried Qwen-14B-Chat and Qwen-14B and hit the same error. My code is as follows:
import lmql

@lmql.query(
    model=lmql.model(
        "Qwen/Qwen-72B-Chat",
        tokenizer="Qwen/Qwen-72B-Chat",
        trust_remote_code=True
    )
)
def prompt():
    '''lmql
    argmax
        "What is the capital of France? [RESPONSE]"
    where
        len(TOKENS(RESPONSE)) < 20
    '''

if __name__ == '__main__':
    print(prompt())
Error:
Traceback (most recent call last):
File "/home/name/user1/lmql/lmql_test.py", line 21, in <module>
print(prompt())
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/api/queries.py", line 148, in lmql_query_wrapper
return module.query(*args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 204, in __call__
return call_sync(self, *args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/loop.py", line 37, in call_sync
res = loop.run_until_complete(task)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 230, in __acall__
results = await interpreter.run(self.fct, **query_kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/tracing/tracer.py", line 240, in wrapper
return await fct(*args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 1070, in run
async for _ in decoder_fct(prompt, **decoder_args):
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/decoders.py", line 21, in argmax
h = h.extend(await model.argmax(h, noscore=True))
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 277, in argmax
return await arr.aelement_wise(op_argmax)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
return path, await op(element, *args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 249, in op_argmax
cache_entries = [await self.get_cache(s, 'top-1', user_data=True, **kwargs) for s in seqs]
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 249, in <listcomp>
cache_entries = [await self.get_cache(s, 'top-1', user_data=True, **kwargs) for s in seqs]
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 196, in get_cache
keys = await self.get_keys(s, edge_type, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 171, in get_keys
mask = (await self.get_mask(s, **kwargs)).logits_mask[0]
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 142, in get_mask
logits_mask_result = await self.delegate.compute_logits_mask(s.input_ids.reshape(1, -1), [s.user_data], constrained_seqs, [s], **kwargs, required=True)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_model.py", line 87, in compute_logits_mask
mask = await processor(seqs, additional_logits_processor_mask=is_constrained, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 671, in where_processor
results = [(mask, user_data, max_tokens_hint) for mask, user_data, max_tokens_hint in await asyncio.gather(*token_mask_tasks)]
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 487, in where_for_sequence
mask, logit_mask, state, max_tokens_hint = await self.where_step_for_sequence(s, needs_masking, seqidx, return_follow_map=return_follow_map, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 561, in where_step_for_sequence
valid, is_final, trace, follow_trace = ops.digest(where,
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/node.py", line 236, in digest
op_follow_map = follow_apply(intm, op, value, context=follow_context)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/follow_map.py", line 226, in follow_apply
result_map = result_map.intersect(pattern)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/ops/follow_map.py", line 71, in intersect
assert handled != "*", "Cannot intersect further patterns if '*' has already been handled."
AssertionError: Cannot intersect further patterns if '*' has already been handled.
Hi there, we have not tested LMQL with Qwen models yet, so this may be an issue with support for its tokenizer. I will have to investigate a bit further.
You can fix the assertion by adding if p2 == "*": break after this line:
https://github.com/eth-sri/lmql/blob/ab02526ddf9883aff4acda4e76c5a2a1cc136bf1/src/lmql/ops/follow_map.py#L241
However, I could not get the model to do inference on my machine, since it never seems to finish a forward pass. Maybe you can try running with the change above, and report back with further results?
Thank you for your reply! Adding the code at the location you mentioned still didn't fix the error for me. However, placing the same guard directly in front of the assertion that throws the error does let the program continue running. Roughly, the change looks like this (only the suggested guard and the assert line are taken from this thread; the surrounding loop structure in intersect() is my assumption):
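# src/lmql/ops/follow_map.py, inside intersect(), directly before the
# failing assertion (line 71 in the traceback above). The guard is the
# one suggested above; the enclosing loop over patterns is assumed.
if p2 == "*":
    break
assert handled != "*", "Cannot intersect further patterns if '*' has already been handled."

With this change the query proceeds past the follow-map intersection. Now, my new error message is as follows: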
[Loading Qwen/Qwen-72B-Chat with AutoModelForCausalLM.from_pretrained("Qwen/Qwen-72B-Chat", trust_remote_code=True)]
The model is automatically converting to bf16 for faster inference. If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".
Try importing flash-attention for faster inference...
Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary
Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm
Warning: import flash_attn fail, please install FlashAttention to get higher efficiency https://github.com/Dao-AILab/flash-attention
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 82/82 [00:11<00:00, 7.04it/s]
[Qwen/Qwen-72B-Chat ready on device cpu]
/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:394: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.8` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
warnings.warn(
/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:404: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_k`.
warnings.warn(
[Error during generate()] expected scalar type c10::BFloat16 but found double
Traceback (most recent call last):
File "/home/name/user1/lmql/lmql_test.py", line 19, in <module>
print(prompt())
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/api/queries.py", line 148, in lmql_query_wrapper
return module.query(*args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 204, in __call__
return call_sync(self, *args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/loop.py", line 37, in call_sync
res = loop.run_until_complete(task)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/lmql_runtime.py", line 230, in __acall__
results = await interpreter.run(self.fct, **query_kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/tracing/tracer.py", line 240, in wrapper
return await fct(*args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/interpreter.py", line 1070, in run
async for _ in decoder_fct(prompt, **decoder_args):
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/decoders.py", line 21, in argmax
h = h.extend(await model.argmax(h, noscore=True))
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 277, in argmax
return await arr.aelement_wise(op_argmax)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
return path, await op(element, *args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_cache.py", line 256, in op_argmax
non_cached_argmax = iter((await self.delegate.argmax(DataArray(non_cached), **kwargs)).items())
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 307, in argmax
return await self.sample(sequences, temperature=0.0, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 350, in sample
return await sequences.aelement_wise(op_sample)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 318, in aelement_wise
result_items = await asyncio.gather(*[op_with_path(path, seqs, *args, **kwargs) for path, seqs in self.sequences.items()])
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/runtime/dclib/dclib_array.py", line 317, in op_with_path
return path, await op(element, *args, **kwargs)
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 340, in op_sample
tokens = await asyncio.gather(*[self.stream_and_return_first(s, await self.generate(s, temperature=temperature, **kwargs), mode) for s,mode in zip(seqs, unique_sampling_mode)])
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py", line 147, in stream_and_return_first
buffer += [await anext(iterator)]
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_multiprocessing.py", line 188, in generate
async for token in self.stream_iterator(self.stream_id):
File "/home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_multiprocessing.py", line 217, in stream_iterator
raise LMTPStreamError(item["error"])
lmql.models.lmtp.errors.LMTPStreamError: failed to generate tokens 'expected scalar type c10::BFloat16 but found double'
Task was destroyed but it is pending!
task: <Task cancelling name='lmtp_inprocess_client_loop' coro=<LMTPDcModel.inprocess_client_loop() running at /home/name/miniconda3/envs/lmql/lib/python3.10/site-packages/lmql/models/lmtp/lmtp_dcmodel.py:76> wait_for=<Future finished result=True>>
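Two things I notice in the output above: the model reports ready on device cpu, and the Qwen loader auto-converts the weights to bf16 while the generation error complains about a double tensor. As a next step I may try pinning the precision and device explicitly when constructing the model. A sketch only: whether lmql.model forwards fp32=True through to AutoModelForCausalLM.from_pretrained, and whether cuda=True applies to this model, are assumptions on my part.

import lmql

# Sketch: pin precision and device at load time. fp32=True is the flag the
# Qwen loading message above says disables the automatic bf16 conversion;
# cuda=True asks LMQL to place the local model on GPU instead of cpu.
model = lmql.model(
    "Qwen/Qwen-72B-Chat",
    tokenizer="Qwen/Qwen-72B-Chat",
    trust_remote_code=True,
    fp32=True,   # keep everything float32 to avoid the bf16/double mismatch
    cuda=True,   # the log shows "ready on device cpu", far too slow for 72B
)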