Batched inference on a single node with multiple GPUs
How can I run inference on a batch of encoded tensors (shape = (B, T)) across 4 GPUs and get 3~4x the tokens/s throughput of a single GPU? (This is for a small model that fits into a single GPU's memory.)
I've tried launching Fabric with strategy='dp', 'ddp', and 'fsdp' as in commit 7130a36 (2023/12/14, #818), but each failed for various reasons.
Meanwhile, generate/sequentially.py is slower than a single GPU, and tp.py doesn't work for batched inputs out of the box.
Hi there,
thanks for the suggestions. Batched inference is not supported in LitGPT yet (batching is currently only supported in training), but it's one of the things on our list that we want to add.
Thank you for the quick response.
I'd like to try implementing batched inference. Could you provide some guidance and suggest a starting point?
For example, should I start from the commit before 13fa12c (which dropped FSDP support, #813), or should I copy and trim the batched training code?
Thanks
Thanks for your interest and offering help to contribute. I would honestly start with the most recent code because it's changed quite a bit over time. I would probably start with the generate base function in https://github.com/Lightning-AI/litgpt/tree/main/litgpt/generate (and then maybe later the same for the chat function: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/chat/base.py)
There is a PR, #886, that started work on implementing batched inference. If you want, you can continue that work.
Got it. To make batched inference work on multiple GPUs, would it be recommended to begin with DDP instead of FSDP?
It would be a great help if you could point me to any relevant documentation or code examples. Thanks.
I would even start with single GPU, and then we could think about implementing data or model parallelism later.
Single device --> DDP --> FSDP.
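For the data-parallel step in that progression, the core idea is that every GPU holds a full copy of the model and receives its own slice of the batch. The following is only a minimal sketch of the slicing part, run on CPU; `shard_batch` is a hypothetical helper, not part of LitGPT or Fabric, and the batch contents are made up for illustration.

```python
from typing import List

import torch


def shard_batch(tokens: torch.Tensor, num_replicas: int) -> List[torch.Tensor]:
    """Split a (B, T) batch along the batch dimension, one shard per model replica.

    Hypothetical helper for illustration only.
    """
    # tensor_split tolerates B not being divisible by num_replicas:
    # shard sizes then differ by at most one row.
    return list(torch.tensor_split(tokens, num_replicas, dim=0))


# Made-up batch: 10 encoded prompts of 16 tokens each
tokens = torch.randint(0, 1000, (10, 16))
shards = shard_batch(tokens, num_replicas=4)
print([tuple(s.shape) for s in shards])
```

In a multi-process setup, each rank would pick `shards[rank]`, move it to its own device, and run generation independently, so no communication between GPUs is needed during decoding. That independence is why a small model that fits on one GPU is a natural fit for data parallelism rather than FSDP.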
Unfortunately I'm not familiar with the problem, so I cannot provide any docs. But I've planned to do this anyway. Hopefully in a couple of weeks I'll be back at my computer and able to assist. In the meantime, try to do as much as you can on your own. The task should be interesting. Have fun 😊
Just wanted to chime in with some support! (Sadly I have been absolutely swamped and haven't had the time to return to my original PR, which is probably only useful for starting hints at this point due to how much has changed.)
Helpful tip: Ignore my build_mask_cache function
The implementation I used originally was mostly correct, but something is fishy about the mask cache. I was using my edited version to mask out padding tokens by building a custom (B, 1, T, T) sized cache, as opposed to the default (1, 1, T, T) sized triangular mask cache that just gets broadcast onto all the inputs. Apparently this screws with something in the scaled_dot_product_attention function (still sorting that out). At present I'm just using the original build_mask_cache function and letting the unmasked padding tokens do their thing (which Llama 3 very much prefers, it seems).
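To make the shapes in this discussion concrete, here is a small sketch that right-pads two prompts into a (B, T) batch, records which positions are real tokens, and builds the shared triangular mask. The prompt contents and `pad_id = 0` are assumptions for illustration, not values taken from LitGPT.

```python
import torch

# Made-up encoded prompts of different lengths
prompts = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
pad_id = 0  # assumed padding token id

B = len(prompts)
T = max(p.size(0) for p in prompts)

batch = torch.full((B, T), pad_id)                    # (B, T), filled with padding
padding_mask = torch.zeros((B, T), dtype=torch.bool)  # True marks a real token
for i, p in enumerate(prompts):
    batch[i, : p.size(0)] = p            # right-pad: real tokens come first
    padding_mask[i, : p.size(0)] = True

# Shared causal mask of shape (1, 1, T, T): one triangle for every sequence
causal = torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

# Broadcasting makes it behave like (B, 1, T, T) without copying memory
print(batch.tolist())
print(causal.expand(B, 1, T, T).shape)
```

The shared mask knows nothing about padding, so every sequence gets the same triangle. A per-sequence mask needs a real batch dimension, which is where the (B, 1, T, T) shape comes from.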
Batched inference is more and more desirable as synthetic data generation becomes more commonplace; looking forward to official support!
Edit: Realized my mistake while taking a shower:
I was inserting the padding mask into the triangular mask in a way that overwrote False values in the upper right triangle (which should all be False), so the attention values came out really messed up. What needs to happen is inserting False values into the lower part of the triangle appropriately, while not inserting any True values into the upper part... Doh!
It's not pretty, but it works:
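from typing import Optional

import torch


def build_mask_cache(
    max_seq_length: int,
    device: Optional[torch.device] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # (max_seq_length, max_seq_length) sized tensor of True's
    ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    if padding_mask is None:
        # no padding information: fall back to the shared (1, 1, T, T) triangular mask
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)
    # one copy per batch item: (B, max_seq_length, max_seq_length)
    ones = ones.unsqueeze(0).repeat(padding_mask.size(0), 1, 1)
    # insert the padding mask (True = real token) into the key columns of the ones tensor
    ones[:, :, : padding_mask.size(1)] = padding_mask.unsqueeze(1)
    # insert False/0 into the upper triangle of the tensor, and add a dimension for the head
    mask = torch.tril(ones).unsqueeze(1)
    return mask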
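As a sanity check on the fix described in the edit above, the sketch below contrasts the two ways of combining the masks. The "buggy" variant is only a reconstruction of the kind of mistake described (overwriting the triangle with the padding mask), not the original code. The "fixed" variant ANDs the triangle with the padding mask, which should give the same result as filling the columns first and applying torch.tril afterwards. The tensor sizes and random inputs are made up.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, n_head, T, head_size = 2, 2, 4, 8

# True marks real tokens; the second sequence ends with two padding tokens
padding_mask = torch.tensor([[True, True, True, True],
                             [True, True, False, False]])
tril = torch.tril(torch.ones(T, T, dtype=torch.bool))

# Buggy reconstruction: the padding mask overwrites the triangle, so True
# values land in the upper triangle and tokens can attend to the future.
buggy = tril.repeat(B, 1, 1)
buggy[:, :, :T] = padding_mask.unsqueeze(1)

# Fixed: padding may only turn allowed positions off, never turn masked ones on.
fixed = (tril & padding_mask.unsqueeze(1)).unsqueeze(1)  # (B, 1, T, T)

print(torch.triu(buggy[0], diagonal=1).any().item())     # future leak present?
print(torch.triu(fixed[0, 0], diagonal=1).any().item())  # future leak present?

# The boolean mask broadcasts over the head dimension in attention
q, k, v = (torch.randn(B, n_head, T, head_size) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=fixed)
print(out.shape)
```

This example uses right padding, so every query row keeps at least one allowed key. With left padding, a query at a padded position can end up with no allowed keys at all, which may produce NaN outputs from the softmax, so those rows would likely need separate handling.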
Thanks @FlimFlamm for the info.