
Batched inference on a single node with multiple GPUs

Open antareson opened this issue 1 year ago • 9 comments

How can I run inference on a batch of encoded tensors (shape = (B, T)) across 4 GPUs and get roughly 3~4x the tokens/s throughput of a single GPU? (It's for a small model that fits into a single GPU's memory.)

I've tried launching Fabric with strategy='dp', 'ddp', and 'fsdp' as in commit 7130a36 (2023/12/14, #818), but each attempt failed for various reasons.

Meanwhile, generate/sequentially.py is slower than a single GPU, and tp.py doesn't work for batched inputs out of the box.
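
For reference, the kind of data-parallel batching I have in mind looks roughly like the sketch below (the DummyLM module, batch shape, and vocab size are placeholders I made up for illustration; this is not LitGPT code, and it assumes 4 visible GPUs):

import torch
import torch.nn as nn
from lightning.fabric import Fabric


class DummyLM(nn.Module):
    """Stand-in for a real causal LM: maps (B, T) token ids to (B, T, vocab) logits."""

    def __init__(self, vocab_size: int = 32000, dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.head(self.embed(idx))


def main():
    fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp")
    fabric.launch()

    model = fabric.setup(DummyLM())
    model.eval()

    # Full (B, T) batch of token ids; each rank handles only its own slice.
    encoded = torch.randint(0, 32000, (16, 128))
    shard = encoded.chunk(fabric.world_size, dim=0)[fabric.global_rank]
    shard = shard.to(fabric.device)

    with torch.inference_mode():
        logits = model(shard)                       # (B / world_size, T, vocab)
        next_tokens = logits[:, -1].argmax(dim=-1)  # greedy next token per sequence

    # Collect every rank's results: shape (world_size, B / world_size).
    gathered = fabric.all_gather(next_tokens)
    fabric.print(gathered.shape)


if __name__ == "__main__":
    main()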

antareson avatar Jun 09 '24 03:06 antareson

Hi there,

thanks for the suggestions. Batched inference is not supported in LitGPT yet (batching is currently only supported in training), but it's one of the things on our list that we want to add.

rasbt avatar Jun 09 '24 11:06 rasbt

Thank you for the quick response.

I'd like to try implementing batched inference. Could you provide some guidance and suggest a starting point?

For example, should I start from the commit just before 13fa12c (which dropped FSDP support, #813), or should I copy and trim the batched training code?

Thanks

antareson avatar Jun 09 '24 13:06 antareson

Thanks for your interest and for offering to contribute. I would honestly start with the most recent code because it has changed quite a bit over time. I would probably start with the generate base function in https://github.com/Lightning-AI/litgpt/tree/main/litgpt/generate (and then maybe later do the same for the chat function: https://github.com/Lightning-AI/litgpt/blob/main/litgpt/chat/base.py).
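
To give a rough idea of the shape of the change, a batched greedy decoding loop could look something like the sketch below. This is not LitGPT's actual generate() implementation: model stands for any causal LM that returns (B, T, vocab) logits for (B, T) token ids, and a real version would reuse the KV cache instead of re-running the full sequence every step.

import torch


@torch.inference_mode()
def generate_batched(model, idx: torch.Tensor, max_new_tokens: int, eos_id: int) -> torch.Tensor:
    # idx: (B, T) batch of prompts, already padded to a common length.
    finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
    for _ in range(max_new_tokens):
        logits = model(idx)                        # (B, T, vocab)
        next_token = logits[:, -1].argmax(dim=-1)  # greedy pick per sequence, shape (B,)
        # Sequences that already emitted EOS keep getting padded with EOS.
        next_token = torch.where(finished, torch.full_like(next_token, eos_id), next_token)
        finished |= next_token == eos_id
        idx = torch.cat([idx, next_token.unsqueeze(1)], dim=1)
        if finished.all():
            break
    return idx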

rasbt avatar Jun 09 '24 13:06 rasbt

There is a PR, #886, that started work on implementing batched inference. If you want, you can continue that work.

Andrei-Aksionov avatar Jun 09 '24 14:06 Andrei-Aksionov

Got it. To make batched inference work on multiple GPUs, would it be recommended to begin with DDP instead of FSDP?

It would be a great help if you could point me to any relevant documentation or code examples. Thanks.

antareson avatar Jun 10 '24 13:06 antareson

I would even start with single GPU, and then we could think about implementing data or model parallelism later.

rasbt avatar Jun 10 '24 13:06 rasbt

Single device --> DDP --> FSDP.

Unfortunately, I'm not familiar with the problem, so I can't point you to any docs. But I've been planning to do this anyway. Hopefully I'll be back at my computer in a couple of weeks and will be able to assist. In the meantime, try to do as much as you can on your own. The task should be interesting. Have fun 😊

Andrei-Aksionov avatar Jun 10 '24 13:06 Andrei-Aksionov

Just wanted to chime in with some support! (Sadly, I have been absolutely swamped and haven't had the time to return to my original PR, which at this point is probably only useful for starting hints, given how much has changed.)

Helpful tip: Ignore my build_mask_cache function

The implementation I used originally was mostly correct, but something is fishy about the mask cache. I was using my edited version to mask out padding tokens by building a custom (B, 1, T, T) sized cache, as opposed to the default (1, 1, T, T) sized triangular mask cache that just gets broadcast onto all the inputs. Apparently this interferes with something in the scaled_dot_product_attention function (still sorting that out). At present I'm just using the original build_mask_cache function and letting the unmasked padding tokens do their thing (which Llama 3 very much seems to prefer).

Batched inference is more and more desirable as synthetic data generation becomes more commonplace; looking forward to official support!

Edit: Realized my mistake while taking a shower:

I was inserting the padding mask into the triangular mask in a way that overwrote the False values in the upper-right triangle (which should all stay False), so the attention values coming out were really messed up. What needs to happen is to insert False values into the lower part of the triangle appropriately, without inserting True's into the upper part... Doh!

It's not pretty, but it works:

from typing import Optional

import torch


def build_mask_cache(
    max_seq_length: int,
    device: Optional[torch.device] = None,
    padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Without a padding mask, fall back to the usual (1, 1, T, T) causal mask.
    if padding_mask is None:
        ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
        return torch.tril(ones).unsqueeze(0).unsqueeze(0)

    # (B, max_seq_length, max_seq_length) sized tensor of True's
    ones = torch.ones(
        (max_seq_length, max_seq_length),
        device=device,
        dtype=torch.bool,
    ).unsqueeze(0).repeat(padding_mask.size(0), 1, 1)

    # Broadcast the (B, T) padding mask over the query dimension so that
    # padded key positions are False for every query
    ones[:, :, :padding_mask.size(1)] = padding_mask[:, :].unsqueeze(1)

    # Zero out the upper triangle (future positions) and add a head dimension,
    # giving a (B, 1, T, T) boolean mask
    mask = torch.tril(ones).unsqueeze(1)

    return mask
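
And here's roughly how I exercise it (the pad_id and shapes below are made up for illustration): a (B, T) boolean key-padding mask goes in, and the resulting (B, 1, T, T) mask broadcasts over the head dimension inside F.scaled_dot_product_attention.

import torch
import torch.nn.functional as F

B, T, n_heads, head_dim, pad_id = 2, 8, 4, 16, 0
tokens = torch.randint(1, 100, (B, T))
tokens[0, 5:] = pad_id                      # pretend sequence 0 is padded at the end

padding_mask = tokens != pad_id             # (B, T): True where attention is allowed
mask = build_mask_cache(T, padding_mask=padding_mask)   # (B, 1, T, T) boolean mask

q = k = v = torch.randn(B, n_heads, T, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # mask broadcasts over heads
print(out.shape)                            # torch.Size([2, 4, 8, 16])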

FlimFlamm avatar Jun 14 '24 01:06 FlimFlamm

Thanks @FlimFlamm for the info.

Andrei-Aksionov avatar Jun 14 '24 04:06 Andrei-Aksionov