text-generation-webui
Add support for custom stopping strings in RWKV models
Addresses a comment in #903
Generation stops as soon as any stopping string appears in the output. This should also work when an extension modifies the reply, whether the stopping string ends up only in the modified reply or only in the original one.
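A minimal sketch of that check. It mirrors the `next(...)` pattern used for `found_stop_string` in text_generation.py, but the helper names (`find_stop_string`, `generate`) and the `modify` callback standing in for extension hooks are hypothetical, not the PR's actual code.

```python
from typing import Callable, List, Optional


def find_stop_string(
    original_reply: str, reply: str, stopping_strings: List[str]
) -> Optional[str]:
    # Check both the raw model output and the extension-modified reply, so a
    # stop string that shows up in only one of them still ends generation.
    return next(
        (s for s in stopping_strings if s in original_reply or s in reply),
        None,
    )


def generate(
    tokens: List[str],
    stopping_strings: List[str],
    modify: Callable[[str], str] = lambda r: r,
) -> str:
    # Hypothetical token loop: `tokens` stands in for the model's streamed
    # output and `modify` for extension hooks that rewrite the reply.
    original_reply = ""
    reply = ""
    for token in tokens:
        original_reply += token
        reply = modify(original_reply)
        if find_stop_string(original_reply, reply, stopping_strings) is not None:
            break  # stop generating; the reply is returned as-is
    return reply


print(generate(["Hello", " world", "###", " more"], ["###"]))  # Hello world###
```

This sketch only breaks out of the loop; it does not trim the stop string from the reply.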
Works in all modes (e.g. --chat, --notebook). Testing it with --chat is a bit odd, because that mode already handles stopping on its own.
Tested with RWKV-4-Pile-1B5-Instruct-test2-20230209.pth (instruct-test model)
When using api-example-stream.py, I get this error printed to the log:
  File "/home/shane/Documents/LLM-UI/modules/text_generation.py", line 172, in generate_reply
    found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None)
  File "/home/shane/Documents/LLM-UI/modules/text_generation.py", line 172, in <genexpr>
    found_stop_string = next((s for s in stopping_strings + evaluated_stopping_strings if s in original_reply), None)
TypeError: 'in <string>' requires string as left operand, not list
This also results in no tokens being generated.
Otherwise, perfectly functional! Nice work.
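The traceback is what Python raises whenever a non-string ends up in the list being scanned: `x in some_str` requires `x` to be a string. A small repro, plus a hypothetical defensive variant that skips non-string entries instead of crashing (neither function is the PR's code):

```python
from typing import Any, List, Optional


def find_stop_string_unchecked(reply: str, stopping_strings: List[Any]) -> Optional[str]:
    # Same shape as the failing line: `s in reply` raises TypeError as soon
    # as it reaches an entry that is not a str.
    return next((s for s in stopping_strings if s in reply), None)


def find_stop_string_safe(reply: str, stopping_strings: List[Any]) -> Optional[str]:
    # Hypothetical guard: skip non-strings, and skip "" because the empty
    # string is contained in every reply and would stop generation at once.
    return next(
        (s for s in stopping_strings if isinstance(s, str) and s and s in reply),
        None,
    )


try:
    find_stop_string_unchecked("Hello ###", [[]])  # a list nested inside the list
except TypeError as err:
    print(err)  # 'in <string>' requires string as left operand, not list

print(find_stop_string_safe("Hello ###", [[], "", "###"]))  # ###
```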
Changing from 'custom_stopping_strings': [],
to 'custom_stopping_strings': '',
in api-example-stream.py fixes the issue for me.
custom_stopping_strings is represented as a string in shared.py.
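One way that string could become a list is by wrapping it in brackets and evaluating it as a Python literal. This is an assumption, not confirmed against text_generation.py, and `parse_stopping_strings` is a hypothetical name. If it holds, it would explain the traceback: an empty string yields an empty list, while a list sent through the API turns into a list nested inside the list.

```python
import ast
from typing import Any, List


def parse_stopping_strings(value: Any) -> List[Any]:
    # Assumed parsing: wrap the setting in brackets and evaluate it as a
    # Python literal, so '"a","b"' becomes ['a', 'b'].
    return ast.literal_eval(f"[{value}]")


print(parse_stopping_strings(""))             # [] : nothing to check, no error
print(parse_stopping_strings([]))             # [[]] : a list inside the list
print(parse_stopping_strings('"\\n","###"'))  # ['\n', '###']
```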
I haven't been able to get custom_stopping_strings to work in api-example-stream.py with other models, so I'm tempted to assume it isn't actually implemented there, but I'm probably just doing it wrong.
With a Vicuna model, none of these values for custom_stopping_strings work, but none of them throw an exception either:
[], '', '""', 'a', ['a'], ['""']
Related to #805?
I probably should've noted how I got it to work properly. You can pass it in this format: '"token","\n","###","\\n###","Alice:"'
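If you'd rather keep a normal Python list in a client script, a small helper can serialize it into that quoted, comma-separated format. This is a sketch assuming the server reads the string as the body of a Python list; `to_stopping_string_field` is a hypothetical name. `json.dumps` quotes each entry and escapes a newline as `\n`, which a Python literal parser reads back as a real newline.

```python
import ast
import json
from typing import List


def to_stopping_string_field(strings: List[str]) -> str:
    # json.dumps wraps each entry in double quotes and escapes newlines,
    # producing a comma-separated body that is also a valid Python literal.
    return ",".join(json.dumps(s) for s in strings)


field = to_stopping_string_field(["###", "\n###", "Alice:"])
print(field)  # "###","\n###","Alice:"

# Round trip: parsing it as a Python list gives back the original strings.
assert ast.literal_eval(f"[{field}]") == ["###", "\n###", "Alice:"]
```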
If you're familiar with your browser's developer tools, you can also just look at the data sent to the server and copy that:
{
  "fn_index": 39,
  "data": [
    10,
    -1,
    0.7,
    0.1,
    40,
    1,
    1.18,
    1,
    0,
    0,
    true,
    0,
    1,
    1,
    false,
    true,
    false,
    2048,
    "\"\\n\",\"###\"",
    true,
    0,
    false,
    false,
    false,
    false,
    true,
    "None",
    "None",
    "None",
    0,
    0
  ],
  "event_data": null,
  "session_hash": "0tgyi49xkqeh"
}
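To reuse a captured request from a script, you can load it as JSON and read the stopping-strings entry back out. A sketch using only the first 20 entries of `data` from the capture above; `STOP_INDEX = 18` is an assumption drawn from this one capture (it is the entry holding the stop strings here), and both it and `fn_index` may differ between versions of the UI.

```python
import ast
import json

# First 20 entries of "data" from the captured request, plus fn_index.
captured = r'''
{"fn_index": 39,
 "data": [10, -1, 0.7, 0.1, 40, 1, 1.18, 1, 0, 0, true, 0, 1, 1,
          false, true, false, 2048, "\"\\n\",\"###\"", true]}
'''

STOP_INDEX = 18  # assumption: position of custom_stopping_strings in this capture

payload = json.loads(captured)
field = payload["data"][STOP_INDEX]
print(field)                           # "\n","###"
print(ast.literal_eval(f"[{field}]"))  # ['\n', '###']
```

Note the two layers of escaping: JSON decoding turns `\\n` into the two characters `\n`, and the Python literal parse then turns those into a real newline.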