
[Core][Distributed] add fast broadcast for tensor dict

Open · youkaichao opened this pull request 1 year ago · 11 comments

Part of the ongoing effort in https://github.com/vllm-project/vllm/issues/4440 .

Reduce the number of broadcasts from 2 to 1.

Broadcast time (before): 0.38772106170654297 ms
Broadcast time (after): 0.128173828125 ms
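For reference, a rough sketch of how such per-call timings might be measured (a hypothetical helper, not part of this PR; it assumes torch.distributed is already initialized and is called collectively on every rank):

import time
import torch
import torch.distributed as dist

def time_broadcast(obj, src=0, iters=100):
    # Warm up once so one-time initialization does not skew the numbers.
    dist.broadcast_object_list([obj], src=src)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        dist.broadcast_object_list([obj], src=src)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000  # average ms per call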

TODO:

  • [ ] improve the broadcast for prepare input in the same way.

youkaichao avatar May 11 '24 07:05 youkaichao

The TensorMetadata class does not serialize compactly:

from vllm.distributed.communication_op import TensorMetadata
import torch
d = TensorMetadata("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s) # 120
import pickletools
pickletools.dis(s)

output:

    0: \x80 PROTO      4
    2: \x95 FRAME      109
   11: \x8c SHORT_BINUNICODE 'vllm.distributed.communication_op'
   46: \x94 MEMOIZE    (as 0)
   47: \x8c SHORT_BINUNICODE 'TensorMetadata'
   63: \x94 MEMOIZE    (as 1)
   64: \x93 STACK_GLOBAL
   65: \x94 MEMOIZE    (as 2)
   66: \x8c SHORT_BINUNICODE 'cuda'
   72: \x94 MEMOIZE    (as 3)
   73: \x8c SHORT_BINUNICODE 'torch'
   80: \x94 MEMOIZE    (as 4)
   81: \x8c SHORT_BINUNICODE 'float32'
   90: \x94 MEMOIZE    (as 5)
   91: \x93 STACK_GLOBAL
   92: \x94 MEMOIZE    (as 6)
   93: \x8c SHORT_BINUNICODE 'torch'
  100: \x94 MEMOIZE    (as 7)
  101: \x8c SHORT_BINUNICODE 'Size'
  107: \x94 MEMOIZE    (as 8)
  108: \x93 STACK_GLOBAL
  109: \x94 MEMOIZE    (as 9)
  110: )    EMPTY_TUPLE
  111: \x85 TUPLE1
  112: \x94 MEMOIZE    (as 10)
  113: R    REDUCE
  114: \x94 MEMOIZE    (as 11)
  115: \x87 TUPLE3
  116: \x94 MEMOIZE    (as 12)
  117: \x81 NEWOBJ
  118: \x94 MEMOIZE    (as 13)
  119: .    STOP
highest protocol among opcodes = 4

Each TensorMetadata instance takes 120 bytes when pickled.
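One way to shrink this, sketched below with hypothetical names and dtype codes (the actual change in a8d1d3a may differ), is to expose the class from a short module path and reduce its pickled state to primitives, so the pickle carries a string, a small int, and a tuple instead of global references to torch.float32 and torch.Size:

import pickle
import torch

# Hypothetical dtype <-> small-int mapping; the real table may differ.
_DTYPE_TO_CODE = {dt: i for i, dt in enumerate(
    [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64])}
_CODE_TO_DTYPE = {i: dt for dt, i in _DTYPE_TO_CODE.items()}

class TensorMeta:
    def __init__(self, device, dtype, size):
        self.device = device
        self.dtype = dtype
        self.size = size

    def __getstate__(self):
        # Primitives only: a string, a small int, and a plain tuple.
        return [self.device, _DTYPE_TO_CODE[self.dtype], tuple(self.size)]

    def __setstate__(self, state):
        self.device = state[0]
        self.dtype = _CODE_TO_DTYPE[state[1]]
        self.size = torch.Size(state[2])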

youkaichao avatar May 11 '24 22:05 youkaichao

After a8d1d3a, the serialization size is reduced by more than half (from 120 bytes to 52 bytes):

from vllm import TensorMeta
import torch
d = TensorMeta("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s) # 52
import pickletools
pickletools.dis(s)

output:

    0: \x80 PROTO      4
    2: \x95 FRAME      41
   11: \x8c SHORT_BINUNICODE 'vllm'
   17: \x94 MEMOIZE    (as 0)
   18: \x8c SHORT_BINUNICODE 'TensorMeta'
   30: \x94 MEMOIZE    (as 1)
   31: \x93 STACK_GLOBAL
   32: \x94 MEMOIZE    (as 2)
   33: )    EMPTY_TUPLE
   34: \x81 NEWOBJ
   35: \x94 MEMOIZE    (as 3)
   36: ]    EMPTY_LIST
   37: \x94 MEMOIZE    (as 4)
   38: (    MARK
   39: \x8c     SHORT_BINUNICODE 'cuda'
   45: \x94     MEMOIZE    (as 5)
   46: K        BININT1    17
   48: )        EMPTY_TUPLE
   49: e        APPENDS    (MARK at 38)
   50: b    BUILD
   51: .    STOP

youkaichao avatar May 11 '24 23:05 youkaichao

With all of the above optimizations, the bytes needed to broadcast BlockMetaData can be reduced from 260 to 107.

This benefit will become more significant when we apply the technique to the prepare-input data structures.

youkaichao avatar May 11 '24 23:05 youkaichao

"improve the broadcast for prepare input in the same way."

It will require another PR.

"Also, can you tell me the perf improvement from it?"

For broadcasting blocks to swap/copy, the benefit is:

Broadcast time (before): 0.38772106170654297 ms
Broadcast time (after): 0.128173828125 ms

I don't have an end-to-end benchmarking.

"update test_swap"

It requires a fairly large modification to the test procedure (separating the test into distributed tests). Meanwhile, the correctness is already checked in https://github.com/vllm-project/vllm/pull/4757/files#diff-cba46ef2b8ff23834781fa3b43794a3f19ffc6b4f1ec2353a8d13d1cdc2d0588R110 .

youkaichao avatar May 13 '24 17:05 youkaichao

@rkooo567 can you help take a look at https://buildkite.com/vllm/ci/builds/7258#018f732a-46ad-4e69-a35b-25f5200d0e19 ? The failure looks like a Ray issue: the function cannot access the name List even though it is imported at the top of the file.

youkaichao avatar May 13 '24 19:05 youkaichao

@youkaichao it would be good to check whether there's a non-negligible performance difference in end-to-end tests before introducing the additional complexity; it's not always easy to infer this from a microbenchmark. A simple before/after test generating a decent number of tokens with a TP deployment would be sufficient, I think?

Do you know how much of the latency benefit comes from shrinking the number of bytes with the new TensorMeta class vs. eliminating one of the broadcasts?

The two broadcast_tensor_dicts (this one and prepare_input_tensors) are done immediately after each other, and it does not look like the second depends on the first; could we combine them?

Especially if they're combined, I'm wondering whether we can avoid the quite convoluted abstractions for what is just a single case.

Implementation-wise, instead of requiring the custom classes, what do you think about this:

  • Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of one per class). Its initial size could be, say, 16 bytes.
  • Use the first byte of the broadcast tensor to indicate the kind of message that follows, either
    1. a regular pickled object
    2. a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.

Then there's no need to maintain special classes. I also don't think there's any need for special handling of the keys; we could just pass lists instead of dicts?
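A rough sketch of the buffer protocol described above (hypothetical code, not what vLLM does; it assumes a gloo process group so a CPU uint8 buffer can be broadcast, and that every rank calls it collectively):

import pickle
import struct
import torch
import torch.distributed as dist

MSG_PICKLE = 0  # the rest of the buffer holds a pickled object
MSG_RESIZE = 1  # bytes 1..8 hold the new buffer size; the broadcast repeats

_buffer = torch.zeros(16, dtype=torch.uint8)

def broadcast_obj(obj=None, src=0):
    global _buffer
    if dist.get_rank() == src:
        payload = pickle.dumps(obj)
        if len(payload) + 1 > _buffer.numel():
            # Announce the new size, grow the buffer, then broadcast again.
            _buffer[0] = MSG_RESIZE
            _buffer[1:9] = torch.frombuffer(
                bytearray(struct.pack("<Q", len(payload) + 1)), dtype=torch.uint8)
            dist.broadcast(_buffer, src=src)
            _buffer = torch.zeros(len(payload) + 1, dtype=torch.uint8)
        _buffer[0] = MSG_PICKLE
        _buffer[1:1 + len(payload)] = torch.frombuffer(bytearray(payload),
                                                       dtype=torch.uint8)
        dist.broadcast(_buffer, src=src)
        return obj
    dist.broadcast(_buffer, src=src)
    if _buffer[0].item() == MSG_RESIZE:
        new_size = struct.unpack("<Q", bytes(_buffer[1:9].tolist()))[0]
        _buffer = torch.zeros(new_size, dtype=torch.uint8)
        dist.broadcast(_buffer, src=src)
    # Trailing bytes after the pickle's STOP opcode are ignored by pickle.loads.
    return pickle.loads(bytes(_buffer[1:].tolist()))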

njhill avatar May 13 '24 20:05 njhill

@youkaichao another reason the above approach might be better - IIUC the get_example_metadata_list approach won't work if the size varies much at runtime (not sure whether that might be the case for prepare_input_tensors)?

njhill avatar May 13 '24 20:05 njhill

First of all, this PR is the first step toward later optimizations. By itself it is a pure benefit because it reduces the number of broadcasts from two to one.

The follow-up that applies the optimization to prepare input needs to come after the refactor in https://github.com/vllm-project/vllm/pull/4681 .

  • Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of one per class). Its initial size could be, say, 16 bytes.
  • Use the first byte of the broadcast tensor to indicate the kind of message that follows, either
    1. a regular pickled object
    2. a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.

This does not reduce the number of broadcasts. It still requires two broadcasts even when we don't have any tensor data to broadcast.

youkaichao avatar May 13 '24 21:05 youkaichao

I think the benchmarks needed to back this up are:

  • How much is the input-broadcast overhead for e2e latency? In our internal benchmark (https://docs.google.com/spreadsheets/d/1GMyebF9XwlLJzpkpRxZrzUNcQTSibHlQ7zifldaDPtI/edit#gid=0), we found this overhead to be "very big", almost as big as the model forward pass at high TP.
  • I think there are 2 parts we can optimize: 1. reduce the overhead of broadcast_object_list; 2. reduce the number of tensor broadcasts (we currently do one per tensor; a sketch of batching them follows below). I think this PR tackles 1, and it'd be great to know how much 1 contributes to the e2e broadcasting overhead (I believe @youkaichao already has the number).
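A rough sketch of one way to address point 2, batching the per-tensor broadcasts into a single collective (a hypothetical helper, not vLLM code; it assumes every rank already knows the shapes and dtype from the metadata broadcast, and that all tensors share a dtype and device):

import math
import torch
import torch.distributed as dist

def broadcast_flat(tensors, shapes, dtype, device, src=0):
    # Flatten everything into one contiguous buffer so a single broadcast
    # replaces one broadcast per tensor.
    sizes = [math.prod(s) for s in shapes]
    if dist.get_rank() == src:
        buf = torch.cat([t.reshape(-1) for t in tensors])
    else:
        buf = torch.empty(sum(sizes), dtype=dtype, device=device)
    dist.broadcast(buf, src=src)
    # Split the buffer back into individual tensors with the known shapes.
    chunks = torch.split(buf, sizes)
    return [chunk.view(shape) for chunk, shape in zip(chunks, shapes)]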

rkooo567 avatar May 14 '24 01:05 rkooo567

@njhill has a proposal to cache the max length of metadata based on the callsite; I will wait and see how it works.

youkaichao avatar May 14 '24 04:05 youkaichao

@youkaichao I've opened #4844 to show the idea, PTAL!

njhill avatar May 16 '24 19:05 njhill

Closing, as https://github.com/vllm-project/vllm/pull/5399 will be a better solution.

youkaichao avatar Jun 15 '24 18:06 youkaichao