text-generation-inference
Adding Llava-Next (Llava 1.6) with full support.
What does this PR do?
- Changed all models to extract embed_tokens, so that llava can call the embeddings and the core model layers separately.
- Added VlmCausalLM, inheriting from FlashMistral in order to be maximally supported. The only added logic sits on top: it parses images into pixel values, preallocates input_ids space for the image embeddings, and passes them to the model (a rough illustrative sketch of this splice follows this list).
- Added CLIP for the vision tower.
- Didn't add flash attention for the vision tower, since there's no padding anyway.
- Added a heuristic (potentially incomplete) to calculate the number of image features before computing the CLIP patches (this allows easier reuse of the LLM logic under the hood).
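For context, here is a minimal sketch of the preallocation/splice idea. It is not the actual TGI implementation; the constant and function names are illustrative only:

```python
import torch

# Hypothetical placeholder id used to preallocate input_ids slots for image features.
IMAGE_TOKEN_ID = 32000


def merge_image_features(
    input_ids: torch.Tensor,       # (seq_len,) token ids, IMAGE_TOKEN_ID repeated once per expected feature
    text_embeds: torch.Tensor,     # (seq_len, hidden) output of language_model.embed_tokens(input_ids)
    image_features: torch.Tensor,  # (num_features, hidden) projected vision-tower patch embeddings
) -> torch.Tensor:
    """Scatter the vision-tower features into the preallocated image slots."""
    mask = input_ids == IMAGE_TOKEN_ID
    if int(mask.sum()) != image_features.shape[0]:
        # If the feature-count heuristic and the vision tower disagree,
        # the assignment below would fail with a shape-mismatch error.
        raise ValueError(
            f"Preallocated {int(mask.sum())} image slots, but the vision tower "
            f"produced {image_features.shape[0]} features"
        )
    text_embeds[mask] = image_features
    return text_embeds
```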
Still needs to be done:
- [ ] Implement the image parsing on the controller side, to avoid downloading the image n times (once per TP shard), to refuse oversized requests early, and to avoid cases where truncation actually cuts into the image.
- [ ] Make sure it works with quantization properly.
- [x] Make sure it works with TP>1
Fixes # (issue)
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the contributor guideline, Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Awesome! Documenting that this PR should fix #1689 - right?
Indeed it would! Not fully general quite yet: we need to transfer part of the token-counting logic to Rust, which means more code, but we'll do the transfer slowly but surely. "Not fully general" means it can still trigger OOMs, because the Rust side (in charge of scheduling) can make wrong assumptions about how much memory a query needs, and therefore schedule more than the hardware can withstand. (If we make the scheduler too strict, we might reject some legitimate requests, so we really need to transfer the whole logic in due time.)
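To make the failure mode concrete, here is a hedged sketch (written in Python for brevity, not the router's actual Rust code; all names are illustrative) of the admission check the scheduler would need: counting only text tokens under-estimates what a multimodal request costs.

```python
def fits_in_budget(
    prompt_tokens: int,
    image_feature_tokens: int,
    max_new_tokens: int,
    token_budget: int,
) -> bool:
    """Conservative admission check for a single request.

    If the scheduler ignores image_feature_tokens (which, for now, only the
    Python server knows how to compute), it admits requests that need far more
    KV-cache space than budgeted, and the shard can OOM at prefill time.
    """
    return prompt_tokens + image_feature_tokens + max_new_tokens <= token_budget


# A short text prompt fits easily...
assert fits_in_budget(prompt_tokens=50, image_feature_tokens=0,
                      max_new_tokens=512, token_budget=2048)
# ...but the same prompt with a large image (thousands of feature tokens) does not.
assert not fits_in_budget(prompt_tokens=50, image_feature_tokens=2400,
                          max_new_tokens=512, token_budget=2048)
```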
Hmm, Docker is tighter on RAM and is hitting OOMs; we probably need to fix the scheduling before merging, then.
I encountered a bug where most image inputs cause the model to crash with the following error:
RuntimeError: shape mismatch: value tensor of shape [2352, 7168] cannot be broadcast to indexing result of shape [3712, 7168]
What are the expected image input dimensions for the llava model? Do the dimensions [2352, 7168] and [3712, 7168] have any special meaning?
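For reference, an error of exactly this shape can be reproduced as below, assuming (as the PR description suggests) that the first dimensions are the number of image features actually produced versus the number of slots preallocated for them, and that 7168 is the hidden size. The two sizes are taken from the error message; everything else is illustrative:

```python
import torch

hidden = 7168                      # assumed hidden size of the language model
slots = torch.zeros(3712, hidden)  # embeddings preallocated for image tokens
feats = torch.zeros(2352, hidden)  # features actually produced by the vision tower

mask = torch.ones(3712, dtype=torch.bool)
# RuntimeError: shape mismatch: value tensor of shape [2352, 7168] cannot be
# broadcast to indexing result of shape [3712, 7168]
slots[mask] = feats
```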
QQ: the docs specifically point to the vicuna 13b 1.6 model, but what about all the other LLaVA-Next models, including the latest ones, e.g. liuhaotian/llava-v1.6-34b, lmms-lab/llama3-llava-next-8b, lmms-lab/llava-next-72b, lmms-lab/llava-next-110b?
Also, when I try sharded, I get:
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 220, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 857, in get_model
raise NotImplementedError("sharded is not supported for AutoModel")
NotImplementedError: sharded is not supported for AutoModel
But I have to use sharded for the 72b model.
Also, for liuhaotian/llava-v1.6-34b I get:
docker run -d --restart=always --gpus '"device=7"' \
--shm-size 12g \
-v $HOME/.cache/huggingface/hub/:/data \
-p 30030:80 \
--name next34b \
ghcr.io/huggingface/text-generation-inference:2.0.4 \
--model-id liuhaotian/llava-v1.6-34b --trust-remote-code --max-stop-sequences=10 \
--max-batch-prefill-tokens=32768 --max-input-length 4096 --max-total-tokens 8192
2024-05-28T02:53:53.846667Z ERROR text_generation_launcher: Error when initializing model
Traceback (most recent call last):
File "/opt/conda/bin/text-generation-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
return get_command(self)(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
return self.main(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
return _main(
File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
rv = self.invoke(ctx)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
return _process_result(sub_ctx.command.invoke(sub_ctx))
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
return __callback(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
return callback(**use_params) # type: ignore
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 257, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
self.run_forever()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
self._run_once()
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
handle._run()
File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 220, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 908, in get_model
raise ValueError(f"Unsupported model type {model_type}")
ValueError: Unsupported model type llava
2024-05-28T02:53:54.271514Z ERROR shard-manager: text_generation_launcher: Shard complete standard error output:
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
/opt/conda/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Traceback (most recent call last):
File "/opt/conda/bin/text-generation-server", line 8, in <module>
sys.exit(app())
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 90, in serve
server.serve(
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 257, in serve
asyncio.run(
File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 220, in serve_inner
model = get_model(
File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/__init__.py", line 908, in get_model
raise ValueError(f"Unsupported model type {model_type}")
ValueError: Unsupported model type llava
rank=0
2024-05-28T02:53:54.370653Z ERROR text_generation_launcher: Shard 0 failed to start
2024-05-28T02:53:54.370671Z INFO text_generation_launcher: Shutting down shards
Error: ShardCannotStart
Even lmms-lab/llama3-llava-next-8b fails the same way.
This pull request comments out the truncate function and raises a MaxNewTokens error when prompt_length + max_new_tokens is larger than max_total_tokens. Any plans to fix this?
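A minimal sketch of the check being described (the names here are illustrative, not the actual TGI validation code): reject the request outright instead of truncating the prompt when it cannot fit.

```python
class MaxNewTokensError(ValueError):
    """Raised when a request cannot fit without truncating the prompt."""


def validate_request(prompt_length: int, max_new_tokens: int, max_total_tokens: int) -> None:
    # Instead of silently truncating the prompt (which could cut an image's
    # placeholder tokens in half), refuse the request outright.
    if prompt_length + max_new_tokens > max_total_tokens:
        raise MaxNewTokensError(
            f"prompt_length ({prompt_length}) + max_new_tokens ({max_new_tokens}) "
            f"must be <= max_total_tokens ({max_total_tokens})"
        )
```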