
[Bug]: prefix-caching: inconsistent completions

Open hibukipanim opened this issue 1 year ago • 22 comments

Your current environment

vLLM version 0.5.0.post1

šŸ› Describe the bug

Hi,

It seems there is a dirty-cache issue with --enable-prefix-caching. We noticed it because internal eval scores degraded significantly when running with --enable-prefix-caching; here I'll show how to reproduce it with a short snippet.

Running 2 vLLM servers with:

without prefix caching:

python -m vllm.entrypoints.openai.api_server --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --port 8001

and another with prefix caching:

python -m vllm.entrypoints.openai.api_server --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --port 8002 --enable-prefix-caching

Then running this snippet:

import string
import random

import openai

vllms = {
    "no-prefix-caching": "http://localhost:8001/v1",
    "with-prefix-caching": "http://localhost:8002/v1",
}

random.seed(0)
prompts = []
for i in range(16):
    prompts.append(''.join(random.choices(string.ascii_lowercase + string.digits, k=512)))

runs = []
for run in range(2):
    print(f"\n🏃 run #{run+1}")

    completions = {k: [] for k in vllms.keys()}
    runs.append(completions)
    for name, endpoint in vllms.items():
        print(f"vLLM {name=}, {endpoint=}")
        client = openai.OpenAI(
            base_url=endpoint,
            api_key="foo"
        )

        for prompt in prompts:
            response = client.completions.create(
                    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                    prompt=prompt,
                    temperature=0,
                    max_tokens=4,
            )
            completion = response.choices[0].text
            completions[name].append(completion)

        print(f"completions: {completions[name]}")

        if run > 0 and runs[run][name] != runs[run-1][name]:
            print(f"❌ completions for vLLM {name=} differs from previous run!")

    if completions["with-prefix-caching"] != completions["no-prefix-caching"]:
        print("🛑 completions differ between with & without prefix")

prints:

šŸƒ run #1
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwg9v', 'xjuwf', 'hu5qw', 'jg0m', '1tzkb', '4w0q', '5zx5', 'zxqj', '7v16', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwg9v', 'xjuwf', 'hu5qw', 'jg0m', '1tzkb', '4w0q', '5zx5', 'zxqj', '7v16', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']

šŸƒ run #2
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwg9v', 'xjuwf', 'hu5qw', 'jg0m', '1tzkb', '4w0q', '5zx5', 'zxqj', '7v16', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7a', '5zq7', 'zxqj', '7k4n', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
āŒ completions for vLLM name='with-prefix-caching' differs from previous run!
šŸ›‘ completions differ between with & without prefix

This also happens with 0.4.3. With 0.4.2, this snippet crashes the server that has prefix caching enabled.

Hopefully one of these PRs resolves the issue 🤞:

  • https://github.com/vllm-project/vllm/pull/5188
  • https://github.com/vllm-project/vllm/pull/5364 (I'll only be able to build these branches and try to reproduce in a few days; hopefully tagging the PRs helps until then.) Edit: I built and tried both PRs, and they don't resolve the issue.

hibukipanim (Jun 14 '24)
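The script's final check only reports that the two lists differ. Below is a minimal sketch of a per-prompt comparison; diff_indices is a hypothetical helper, not part of vLLM or the openai client. Applied to the run #2 output above, it shows that 5 of the 16 prompts (indices 1, 2, 6, 7 and 9) changed, and only on the server with prefix caching, i.e. once its cache had been warmed by run #1.

```python
"""Compare two lists of completions prompt by prompt (hypothetical helper)."""
from typing import List, Tuple


def diff_indices(a: List[str], b: List[str]) -> List[Tuple[int, str, str]]:
    """Return (index, a_text, b_text) for every position where a and b differ."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return [(i, x, y) for i, (x, y) in enumerate(zip(a, b)) if x != y]


# Run #2 output copied from the issue above.
no_cache = ['6x2w', 'zwg9v', 'xjuwf', 'hu5qw', 'jg0m', '1tzkb', '4w0q', '5zx5',
            'zxqj', '7v16', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
with_cache = ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7a', '5zq7',
              'zxqj', '7k4n', '0ty57', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']

if __name__ == "__main__":
    for i, x, y in diff_indices(no_cache, with_cache):
        print(f"prompt {i}: no-cache={x!r} with-cache={y!r}")
```

Inside the original script, the same call could replace the != comparisons, e.g. diff_indices(completions["no-prefix-caching"], completions["with-prefix-caching"]).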

We have an improved block manager that has better test coverage for prefix caching. We have tests that compare the outputs with and without prefix caching for equality, so this case shouldn't happen; if it is happening, we can diagnose the failure more easily. Note that the v2 block manager is not yet optimized for performance.

Can you check whether it occurs with --use-v2-block-manager?

cadedaniel (Jun 15 '24)
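To make the cache-hit path concrete, here is a simplified sketch of block-level prefix caching. It is an illustration under stated assumptions, not vLLM's block manager: token ids are split into fixed-size blocks (16 tokens assumed), each full block is keyed by a hash chained from its parent block, and a request reuses the longest run of leading blocks already in the cache. The point it shows is that a repeated prompt is prefilled differently the second time: only the uncached tail is computed, attending over cached KV blocks.

```python
from typing import Dict, List, Optional, Tuple

BLOCK_SIZE = 16  # assumed number of tokens per KV-cache block


def block_keys(token_ids: List[int], block_size: int = BLOCK_SIZE) -> List[int]:
    """Chain-hash each full block; a trailing partial block gets no key."""
    keys: List[int] = []
    parent: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        parent = hash((parent, block))  # key depends on the whole prefix so far
        keys.append(parent)
    return keys


class PrefixCache:
    """Toy cache mapping block keys to fake physical block ids (no eviction)."""

    def __init__(self, block_size: int = BLOCK_SIZE) -> None:
        self.block_size = block_size
        self.blocks: Dict[int, int] = {}

    def admit(self, token_ids: List[int]) -> Tuple[int, int]:
        """Return (cached_tokens, tokens_to_compute), then cache new full blocks."""
        keys = block_keys(token_ids, self.block_size)
        hits = 0
        while hits < len(keys) and keys[hits] in self.blocks:
            hits += 1
        for key in keys[hits:]:
            self.blocks[key] = len(self.blocks)
        cached = hits * self.block_size
        return cached, len(token_ids) - cached


if __name__ == "__main__":
    cache = PrefixCache()
    prompt = list(range(40))    # 2 full blocks + 8 leftover token ids
    print(cache.admit(prompt))  # (0, 40): cold cache, everything is prefilled
    print(cache.admit(prompt))  # (32, 8): only the tail is computed
```

In this toy model, the first request for a 40-token prompt computes all 40 tokens, while the repeat computes 8 tokens against 32 cached ones, so the attention kernel receives different inputs on a hit even though the prompt is identical.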

Thanks for the reply @cadedaniel. I have now tried with --use-v2-block-manager (version 0.5.0.post1), and unfortunately it still happens.

Edit: I also tried building the current main branch (commit e2b85cf86a522e734a38b1d0314cfe9625003ef9), where https://github.com/vllm-project/vllm/pull/5364 is already merged, and the issue still happens (also with --use-v2-block-manager).

hibukipanim (Jun 17 '24)

I also built the branch of https://github.com/vllm-project/vllm/pull/5188, and it doesn't resolve the issue.

hibukipanim (Jun 17 '24)

Possible workaround: https://github.com/vllm-project/vllm/issues/5376#issuecomment-2179257676

colefranks (Jun 19 '24)

Thanks @colefranks. I tried it, and the workaround doesn't seem to help, although it does change the behavior. I tried several combinations (all with version 0.5.0.post1).

On the first iteration, the outputs differ between running with VLLM_ATTENTION_BACKEND=XFORMERS and without it. Even if we assume that difference is acceptable, when --enable-prefix-caching is used the second iteration still differs from the first one.

hibukipanim (Jun 23 '24)

Is this issue solved? I'm hitting the same problem: inconsistent completions.

kuangdao (Jul 11 '24)

The same thing happened when I replaced the model with OPT-125m and ran inference offline. However, when I inserted torch.manual_seed() (not random.seed) before generate, the result was correct.

SaltFish11 (Jul 12 '24)
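An offline variant of the reproduction, along the lines of the OPT-125m experiment described above, might look like the sketch below. Assumptions: vLLM is installed on a machine with a GPU, and facebook/opt-125m is the Hugging Face model id to use; the prompts are generated the same way as in the original snippet. Unlike the server script, generate() submits all 16 prompts at once, and the second call runs against a cache warmed by the first. Because temperature=0 means greedy decoding, the sketch sets no seed.

```python
import random
import string
from typing import List


def make_prompts(n: int = 16, k: int = 512, seed: int = 0) -> List[str]:
    """Random prompts built the same way as in the original snippet."""
    rng = random.Random(seed)  # same sequence as random.seed(seed) + random.choices
    alphabet = string.ascii_lowercase + string.digits
    return ["".join(rng.choices(alphabet, k=k)) for _ in range(n)]


def run_twice(model: str = "facebook/opt-125m",
              enable_prefix_caching: bool = True) -> List[List[str]]:
    """Generate completions for the same prompts twice with one engine."""
    # Imported lazily so make_prompts() works without vLLM installed.
    from vllm import LLM, SamplingParams

    llm = LLM(model=model, enable_prefix_caching=enable_prefix_caching)
    params = SamplingParams(temperature=0, max_tokens=4)  # greedy, as in the issue
    prompts = make_prompts()
    runs = []
    for _ in range(2):  # the second pass can hit blocks cached by the first
        outputs = llm.generate(prompts, params)
        runs.append([o.outputs[0].text for o in outputs])
    return runs
```

Usage: first, second = run_twice(), then compare the two lists prompt by prompt; running it in a separate process with enable_prefix_caching=False gives the baseline.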

@hibukipanim @kuangdao @SaltFish11 I solved the problem by changing the Triton code. In the file ../triton/common/build.py, find these lines:

    cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
    cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",

and add "-std=c99", after them, like this:

    if is_hip():
        ret = subprocess.check_call([
            cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-std=c99",
            f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
        ])
    else:
        cc_cmd = [
            cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-std=c99",
            "-o", so
        ]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
        ret = subprocess.check_call(cc_cmd)

bsll (Jul 12 '24)

Thanks @bsll, but I'm struggling to understand which Triton you mean. There is no such folder in vLLM. Do you mean https://github.com/triton-lang/triton or https://github.com/triton-inference-server/server? I don't see a common/build.py in either.

hibukipanim (Jul 14 '24)


Thanks for the workaround, @bsll. @hibukipanim, the location is something like /path/to/miniconda3/envs/vllm/lib/python3.9/site-packages/triton/common/build.py

LLouice (Jul 15 '24)
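To find that file in a given environment without guessing the site-packages path, a small sketch using importlib.util.find_spec could help. Assumption: only some (older) Triton releases ship triton/common/build.py, so the helper returns None when it is absent. It only reads the file and does not apply the patch.

```python
import importlib.util
import os
from typing import Optional


def triton_build_py() -> Optional[str]:
    """Path of triton/common/build.py in the active environment, or None."""
    spec = importlib.util.find_spec("triton")  # locates the package without importing it
    if spec is None or not spec.submodule_search_locations:
        return None  # Triton is not installed
    pkg_dir = list(spec.submodule_search_locations)[0]
    path = os.path.join(pkg_dir, "common", "build.py")
    return path if os.path.isfile(path) else None


def has_c99_flag(path: str) -> bool:
    """True if the file already passes "-std=c99" to the compiler."""
    with open(path, encoding="utf-8") as f:
        return '"-std=c99"' in f.read()


if __name__ == "__main__":
    path = triton_build_py()
    if path is None:
        print("triton/common/build.py not found")
    else:
        print(path, "patched" if has_c99_flag(path) else "not patched")
```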

Thanks @bsll & @LLouice. I tried making the change you suggested in Triton, but unfortunately the issue still reproduces for me (with 0.5.2), using the exact snippet from the first message.

In more detail, here is what I did: I'm running vLLM in a virtualenv. Inside it, I edited the file .venv/lib/python3.10/site-packages/triton/common/build.py and changed these lines:

    if is_hip():
        ret = subprocess.check_call([
            cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
            f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
        ])
    else:
        cc_cmd = [
            cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
            "-o", so
        ]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
        ret = subprocess.check_call(cc_cmd)

to these lines:

    if is_hip():
        ret = subprocess.check_call([
            cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC","-std=c99",
            f"-L{hip_lib_dir}", "-lamdhip64", "-o", so
         ])
    else:
        cc_cmd = [
            cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda","-std=c99",
            "-o", so
        ]
        cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
        ret = subprocess.check_call(cc_cmd)

i.e.:

97c97
<             cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC",
---
>             cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC","-std=c99",
99c99
<         ])
---
>          ])
102c102
<             cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda",
---
>             cc, src, "-O3", f"-I{

I also deleted ~/.triton, which contained a cache (it wasn't recreated after running, so maybe it isn't actually used in this flow?). Then I re-ran the server.

I must say it's quite surprising that switching the C standard to C99 (-std=c99) would change the behavior (but cool if it does...).

hibukipanim (Jul 19 '24)

@SaltFish11 thanks for the comment. However, I tried adding:

import torch
torch.manual_seed(42)

at the top of vllm/entrypoints/openai/api_server.py, and the issue still reproduces.

hibukipanim (Jul 19 '24)

Same here: with "-std=c99" added, I still get strange repetitive output (repeating with 3 to 4 spaces), e.g. "1. I need to use 2 use 2 use 2 use 2 use 2".

Looking forward to further solutions.

roger0426 (Jul 19 '24)

@zachzzc Thanks for https://github.com/vllm-project/vllm/pull/7018

FYI: I have now tried it with the commit that merged your PR (fb2c1c86c196aa1531435d0c445fbea4c9dd4aa5) and with the current HEAD (9fadc7b7a03f798036d0e8710587870e13bae759), and unfortunately the snippet from the issue still fails with both. I double-checked by running the same versions without --enable-prefix-caching, and that was OK, so prefix caching still has correctness issues.

hibukipanim (Aug 04 '24)

Have you tried meaningful inputs (like real sentences) instead of random strings? I wonder whether it is just caused by minor differences in kernel execution after the cache is hit.

zachzzc (Aug 05 '24)

@zachzzc Originally I opened this issue after seeing degradation in some internal evals that used real inputs (although it happened more often under concurrent requests). Why would kernel execution differ on a cache hit?

hibukipanim (Aug 07 '24)

If you still see the degradation in the real evals, then it would be a true bug. It calls the same kernel with different inputs here https://github.com/vllm-project/vllm/blob/5223199e03ac3729eb60043a1ef57156c8af1bc9/vllm/attention/backends/flash_attn.py#L532, depending on whether the cache is hit or not. I will update here if I find anything.

zachzzc (Aug 07 '24)
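For intuition on how the same kernel can produce slightly different numbers on a cache hit: attention over a cached prefix plus a new suffix can be computed in pieces and merged with the online-softmax (log-sum-exp rescaling) trick. The sketch below is plain Python with made-up scores and values; it is not vLLM's or FlashAttention's kernel. The merge is mathematically exact, but floating-point addition is not associative, so a different split or reduction order can change the last bits of the result.

```python
import math
from typing import List, Tuple

State = Tuple[float, float, float]  # (max_score, sum_of_exp, weighted_sum)


def attend(scores: List[float], values: List[float]) -> State:
    """Softmax-weighted sum over one chunk, kept in mergeable form."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return m, sum(exps), sum(e * v for e, v in zip(exps, values))


def merge(a: State, b: State) -> State:
    """Combine two partial results by rescaling both to a common max."""
    m = max(a[0], b[0])
    sa, sb = math.exp(a[0] - m), math.exp(b[0] - m)
    return m, a[1] * sa + b[1] * sb, a[2] * sa + b[2] * sb


def finalize(state: State) -> float:
    return state[2] / state[1]


scores = [0.3, 2.1, -1.2, 0.7, 1.9, -0.4]  # made-up query-key scores
values = [1.0, -2.0, 0.5, 3.0, -1.5, 2.5]   # one made-up value dimension

full = finalize(attend(scores, values))                  # whole sequence at once
split = finalize(merge(attend(scores[:4], values[:4]),  # "cached prefix" chunk
                       attend(scores[4:], values[4:])))  # "new suffix" chunk
print(full, split, abs(full - split))

# Reduction order matters in floating point:
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```

With greedy decoding (temperature=0), a last-bit difference only changes the output when the top two logits are nearly tied, which may be more likely with the near-random prompts in the repro, as suggested above; whether this explains the eval degradation reported in this thread is not established.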

+1 in 0.5.2

hiyforever (Sep 19 '24)

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot] (Dec 19 '24)

Re the github-actions bot: the issue still reproduces with the code snippet from the original issue message, also with vLLM 0.6.6.post1.

Maybe the issue is less severe now, though: when I first opened it, we saw serious degradation in eval scores with prefix caching, whereas in more recent versions the scores are not exactly the same but close enough.

I think this issue is still worth keeping open: AFAIU prefix caching shouldn't affect the predictions, so there might still be a bug hiding.

hibukipanim (Dec 27 '24)

An update: the behavior is different with 0.7.0; not perfect, but it seems better.

When running without V1 enabled, I still get the

🛑 completions differ between with & without prefix

error from the script above, but I no longer get "❌ completions for vLLM name='with-prefix-caching' differs from previous run!". So there is still a difference between running with and without prefix caching, but at least each output is consistent with previous generations.

And when running with VLLM_USE_V1=1, the outputs changed between the first and second runs on both servers (including the one without prefix caching), but from then on they were stable, and the prefix-caching server matched the non-prefix-caching one. Here is the output of running the test loop for 4 iterations:


šŸƒ run #1
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzdv', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzdv', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']

šŸƒ run #2
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
āŒ completions for vLLM name='no-prefix-caching' differs from previous run!
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
āŒ completions for vLLM name='with-prefix-caching' differs from previous run!

šŸƒ run #3
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']

šŸƒ run #4
vLLM name='no-prefix-caching', endpoint='http://localhost:8001/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
vLLM name='with-prefix-caching', endpoint='http://localhost:8002/v1'
completions: ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7', 'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']

hibukipanim (Jan 31 '25)
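Reading four runs by eye is error-prone. Below is a minimal sketch of a summary (report and changed are hypothetical helpers, not part of the original script). Fed with the VLLM_USE_V1=1 output above, it shows that the two servers agree in every run and that the only run-to-run change is prompt index 5 ('1tzdv' to '1tzkb') between runs #1 and #2, on both servers.

```python
from typing import Dict, List

Run = Dict[str, List[str]]  # server name -> completions, one per prompt


def changed(a: List[str], b: List[str]) -> List[int]:
    """Prompt indices whose completion differs between two lists."""
    return [i for i, (x, y) in enumerate(zip(a, b)) if x != y]


def report(runs: List[Run]) -> List[dict]:
    rows = []
    for r, run in enumerate(runs):
        names = sorted(run)
        row = {"run": r + 1,
               "servers_agree": all(run[n] == run[names[0]] for n in names)}
        if r > 0:
            row["changed_vs_prev"] = {n: changed(runs[r - 1][n], run[n]) for n in names}
        rows.append(row)
    return rows


# VLLM_USE_V1=1 output copied from above: runs #2-#4 are identical,
# run #1 differs only at index 5.
after = ['6x2w', 'zwma71', '37wk', 'hu5qw', 'jg0m', '1tzkb', '4h7h', '5zq7',
         'zxqj', '7k4n', '0rexw', 'vk0j', 'jjnj', 'xw95', 'vxjj', 't6x7']
first = after[:5] + ['1tzdv'] + after[6:]
runs = [{"no-prefix-caching": first, "with-prefix-caching": first}] + \
       [{"no-prefix-caching": after, "with-prefix-caching": after} for _ in range(3)]

for row in report(runs):
    print(row)
```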

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot] (May 02 '25)

It seems this issue is still not resolved with V1 + APC (automatic prefix caching) + FA (FlashAttention).

valtab (May 17 '25)