[Core][Distributed] add fast broadcast for tensor dict
Part of the ongoing effort in https://github.com/vllm-project/vllm/issues/4440 .
This PR reduces the number of broadcasts from 2 to 1.
Broadcast time (before): 0.38772106170654297 ms; Broadcast time (after): 0.128173828125 ms
TODO:
- [ ] improve the broadcast for prepare input in the same way.
The existing TensorMetadata class is not efficient to serialize:
```python
from vllm.distributed.communication_op import TensorMetadata
import torch
d = TensorMetadata("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s)  # 120
import pickletools
pickletools.dis(s)
```
output:
```
0: \x80 PROTO 4
2: \x95 FRAME 109
11: \x8c SHORT_BINUNICODE 'vllm.distributed.communication_op'
46: \x94 MEMOIZE (as 0)
47: \x8c SHORT_BINUNICODE 'TensorMetadata'
63: \x94 MEMOIZE (as 1)
64: \x93 STACK_GLOBAL
65: \x94 MEMOIZE (as 2)
66: \x8c SHORT_BINUNICODE 'cuda'
72: \x94 MEMOIZE (as 3)
73: \x8c SHORT_BINUNICODE 'torch'
80: \x94 MEMOIZE (as 4)
81: \x8c SHORT_BINUNICODE 'float32'
90: \x94 MEMOIZE (as 5)
91: \x93 STACK_GLOBAL
92: \x94 MEMOIZE (as 6)
93: \x8c SHORT_BINUNICODE 'torch'
100: \x94 MEMOIZE (as 7)
101: \x8c SHORT_BINUNICODE 'Size'
107: \x94 MEMOIZE (as 8)
108: \x93 STACK_GLOBAL
109: \x94 MEMOIZE (as 9)
110: ) EMPTY_TUPLE
111: \x85 TUPLE1
112: \x94 MEMOIZE (as 10)
113: R REDUCE
114: \x94 MEMOIZE (as 11)
115: \x87 TUPLE3
116: \x94 MEMOIZE (as 12)
117: \x81 NEWOBJ
118: \x94 MEMOIZE (as 13)
119: . STOP
highest protocol among opcodes = 4
```
Each pickled TensorMetadata takes 120 bytes.
After a8d1d3a, the serialization size is reduced by more than half (from 120 bytes to 52 bytes):
```python
from vllm import TensorMeta
import torch
d = TensorMeta("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s)  # 52
import pickletools
pickletools.dis(s)
```
output:
```
0: \x80 PROTO 4
2: \x95 FRAME 41
11: \x8c SHORT_BINUNICODE 'vllm'
17: \x94 MEMOIZE (as 0)
18: \x8c SHORT_BINUNICODE 'TensorMeta'
30: \x94 MEMOIZE (as 1)
31: \x93 STACK_GLOBAL
32: \x94 MEMOIZE (as 2)
33: ) EMPTY_TUPLE
34: \x81 NEWOBJ
35: \x94 MEMOIZE (as 3)
36: ] EMPTY_LIST
37: \x94 MEMOIZE (as 4)
38: ( MARK
39: \x8c SHORT_BINUNICODE 'cuda'
45: \x94 MEMOIZE (as 5)
46: K BININT1 17
48: ) EMPTY_TUPLE
49: e APPENDS (MARK at 38)
50: b BUILD
51: . STOP
```
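For illustration, here is a minimal sketch of how such a class can be made cheap to pickle (the class name, dtype table, and helpers below are made up for this sketch, not vllm's actual implementation): keep the state as a flat list of primitives, so the pickled payload avoids the STACK_GLOBAL references to torch.float32 and torch.Size seen in the first disassembly.

```python
import pickle

import torch

# Hypothetical dtype table: a dtype pickles as a small integer index
# instead of a global reference such as torch.float32.
_DTYPE_TABLE = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64]


class CompactTensorMeta:
    """Sketch of a pickle-friendly tensor metadata holder."""

    def __init__(self, device: str, dtype: torch.dtype, size: torch.Size):
        self.device = device
        self.dtype = dtype
        self.size = size

    def __getstate__(self):
        # State is a flat list of primitives: a short string, a small int, a plain tuple.
        return [self.device, _DTYPE_TABLE.index(self.dtype), tuple(self.size)]

    def __setstate__(self, state):
        device, dtype_idx, shape = state
        self.device = device
        self.dtype = _DTYPE_TABLE[dtype_idx]
        self.size = torch.Size(shape)


meta = CompactTensorMeta("cuda", torch.float32, torch.Size([]))
payload = pickle.dumps(meta)
restored = pickle.loads(payload)
assert restored.dtype is torch.float32 and restored.size == torch.Size([])
```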
With all of the above optimizations, the payload for broadcasting BlockMetaData shrinks from 260 bytes to 107 bytes.
The benefit will become more significant once we apply the same technique to the prepare-input data structures.
"improve the broadcast for prepare input in the same way."
It will require another PR.
Also, can you tell me the perf improvement from it?
For broadcasting blocks to swap/copy, the benefit is:
Broadcast time (before): 0.38772106170654297 ms; Broadcast time (after): 0.128173828125 ms
I don't have end-to-end benchmark numbers.
> update test_swap
That requires quite a large modification to the test procedure (splitting the test into distributed tests). Meanwhile, correctness is already checked in https://github.com/vllm-project/vllm/pull/4757/files#diff-cba46ef2b8ff23834781fa3b43794a3f19ffc6b4f1ec2353a8d13d1cdc2d0588R110 .
@rkooo567 can you help take a look at https://buildkite.com/vllm/ci/builds/7258#018f732a-46ad-4e69-a35b-25f5200d0e19 ? The failure looks like a Ray issue: the function cannot access the name List, even though it is imported at the top of the file.
@youkaichao it would be good to check whether there's a non-negligible performance difference in end-to-end tests before introducing the additional complexity; it's not always easy to infer this from a microbenchmark. A simple before/after test generating a decent number of tokens with a TP deployment would be sufficient, I think?
Do you know how much of the latency benefit comes from compressing the number of bytes with the new TensorMeta class vs eliminating one of the broadcasts?
The two broadcast_tensor_dict calls (this one and the one in prepare_input_tensors) happen immediately after each other, and it does not look like the second depends on the first; could we combine them?
Especially if they're combined, I'm wondering whether we can avoid the quite convoluted abstractions for what is just a single case.
Implementation-wise, instead of requiring the custom classes, what do you think about this:
- Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of per class). Its initial size can be like 16 bytes.
- Use the first byte of the broadcast tensor to indicate the kind of message that follows, either
- a regular pickled object
- a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.
Then there's no need to maintain special classes. I also don't think there's any need for special handling of the keys; we could just pass lists instead of dicts?
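To make the proposal concrete, here is a rough sketch of that single-buffer protocol (the names, byte layout, and helper below are assumptions for illustration, not code from this PR or from vllm): one persistent uint8 buffer whose first byte marks the payload as either a pickled object or a resize request, in which case the broadcast is repeated with the larger buffer.

```python
import pickle

import torch
import torch.distributed as dist

MSG_PICKLED = 0  # buffer[1:5] = payload length, buffer[5:] = pickled bytes
MSG_RESIZE = 1   # buffer[1:9] = new buffer size; the broadcast is then repeated

# Single static buffer shared by all calls. Assumes the default process group is
# initialized and uses a GPU buffer so the NCCL backend can broadcast it.
_buffer = torch.zeros(16, dtype=torch.uint8, device="cuda")


def _resize(new_size: int) -> None:
    global _buffer
    _buffer = torch.zeros(new_size, dtype=torch.uint8, device=_buffer.device)


def broadcast_pickled(obj=None, src: int = 0):
    global _buffer
    if dist.get_rank() == src:
        payload = pickle.dumps(obj)
        needed = 5 + len(payload)
        if needed > _buffer.numel():
            # Announce the resize in the current buffer, grow, then repeat the broadcast.
            _buffer[0] = MSG_RESIZE
            _buffer[1:9] = torch.tensor(list(needed.to_bytes(8, "little")), dtype=torch.uint8)
            dist.broadcast(_buffer, src=src)
            _resize(needed)
        _buffer[0] = MSG_PICKLED
        _buffer[1:5] = torch.tensor(list(len(payload).to_bytes(4, "little")), dtype=torch.uint8)
        _buffer[5:needed] = torch.tensor(list(payload), dtype=torch.uint8)
        dist.broadcast(_buffer, src=src)
        return obj
    dist.broadcast(_buffer, src=src)
    if int(_buffer[0]) == MSG_RESIZE:
        _resize(int.from_bytes(bytes(_buffer[1:9].tolist()), "little"))
        dist.broadcast(_buffer, src=src)
    length = int.from_bytes(bytes(_buffer[1:5].tolist()), "little")
    return pickle.loads(bytes(_buffer[5:5 + length].tolist()))
```

In this sketch the metadata broadcast needs a single collective in the steady state, with an extra one only on the calls where the buffer has to grow; any tensor payloads would still be broadcast separately.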
@youkaichao another reason the above approach might be better - IIUC the get_example_metadata_list approach won't work if the size varies much at runtime (not sure whether that might be the case for prepare_input_tensors)?
First of all, this PR is the first step toward later optimizations. On its own it is a pure benefit, because it reduces the broadcast from twice to once.
The followup for applying the optimization in prepare input needs to come after the refactor https://github.com/vllm-project/vllm/pull/4681 .
> Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of per class). Its initial size can be like 16 bytes. Use the first byte of the broadcast tensor to indicate the kind of message that follows, either a regular pickled object or a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.
This does not reduce the number of broadcasts. It still requires two broadcasts even if we don't have any tensor data to broadcast.
I think the nice benchmarks to back this up would be:
- How much does the input broadcast overhead contribute to e2e latency? In our internal benchmark we found this overhead is "very big" (https://docs.google.com/spreadsheets/d/1GMyebF9XwlLJzpkpRxZrzUNcQTSibHlQ7zifldaDPtI/edit#gid=0), almost as big as the model forward pass at high TP.
- I think there are 2 parts we can optimize: 1. reduce the overhead of broadcast_object_list; 2. reduce the number of tensor broadcasts (we do one broadcast per tensor; a rough sketch of what 2 could look like follows below). I think this PR tackles 1, and it'd be great to know how much 1 contributes to the e2e broadcasting overhead (I believe @youkaichao already has the number).
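A rough sketch of what direction 2 could look like (the helper below is hypothetical, not from vllm): flatten all tensors into one contiguous buffer, issue a single broadcast, and split the buffer back into views on every rank. It assumes the tensors share a dtype and device and that their shapes are already known on the non-source ranks, e.g. from the metadata broadcast.

```python
import torch
import torch.distributed as dist


def broadcast_tensors_batched(tensors: list[torch.Tensor], src: int = 0) -> list[torch.Tensor]:
    """Broadcast a list of same-dtype, same-device tensors with one collective call."""
    total = sum(t.numel() for t in tensors)
    if dist.get_rank() == src:
        flat = torch.cat([t.reshape(-1) for t in tensors])
    else:
        # Non-source ranks allocate a receive buffer; shapes and dtype come from metadata.
        flat = torch.empty(total, dtype=tensors[0].dtype, device=tensors[0].device)
    dist.broadcast(flat, src=src)
    # Split the flat buffer back into views with the original shapes.
    out, offset = [], 0
    for t in tensors:
        out.append(flat[offset:offset + t.numel()].view(t.shape))
        offset += t.numel()
    return out
```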
@njhill has a proposal to cache the max length of the metadata based on the callsite; I will wait and see how it works.
@youkaichao I've opened #4844 to show the idea, PTAL!
Closing, as https://github.com/vllm-project/vllm/pull/5399 will be a better solution.