x-transformers

Support for a mask during autoregressive generation with Key-Value Caching

Open Oufattole opened this issue 1 year ago • 10 comments

Why isn't a mask supported when key-value caching is enabled here?

Oufattole avatar Nov 12 '24 19:11 Oufattole

@Oufattole in what scenario would you need masking when doing autoregressive decoding?

lucidrains avatar Nov 12 '24 19:11 lucidrains

I'm trying to do sliding window inference, but the lengths of the initial prompts differ across my batch, so I think I should mask out the padding, as that's what we do during autoregressive pretraining.

I'm applying transformers to medical trajectories as part of this open source project, which provides ML tooling for modeling patient time-series data (you tokenize a patient's irregularly sampled time-series observations, such as medications, diagnoses, and procedures). I'm interested in generating future trajectories and evaluating them. Here is the relevant code I am currently using to generate trajectories. At the moment I simply don't cache key-value pairs, so that I can apply masks, but that is prohibitively slow.

Oufattole avatar Nov 12 '24 20:11 Oufattole

@Oufattole yes I see, so you are off the beaten path

sliding windows aren't supported here yet

lucidrains avatar Nov 12 '24 20:11 lucidrains

@Oufattole you can do away with masking by slicing the cached key values before passing them back in

lucidrains avatar Nov 12 '24 20:11 lucidrains
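A minimal sketch of the slicing idea, for anyone landing here later. It assumes a cache laid out as a list of per-layer `(k, v)` tensors of shape `(batch, heads, seq, dim_head)` and a batch whose sequences share the same number of leading pad tokens; that layout and the helper name are assumptions for illustration, not x-transformers' actual internals.

```python
import torch

def drop_padding_from_cache(cache, keep_from):
    """Drop left-padded positions from a KV cache instead of masking them.

    Assumes `cache` is a list of per-layer (k, v) pairs, each of shape
    (batch, heads, seq, dim_head), and that every sequence in the batch has
    `keep_from` leading pad tokens. (Hypothetical layout for illustration.)
    """
    return [(k[:, :, keep_from:], v[:, :, keep_from:]) for k, v in cache]

# toy example: one layer, batch 1, 2 heads, 5 cached positions, head dim 4
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)
trimmed = drop_padding_from_cache([(k, v)], keep_from=2)
```

With ragged prompts, per-sequence slicing would mean re-batching, which is why this works most cleanly when the pad counts in a batch agree.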

Ahhh I see, thank you, I'll try that! With medical data, unlike in NLP and CV, many patient trajectories are very short, and you don't need a long sequence length at all. For example, in my dataset 80% of patients fall below the 512 max sequence length, but a small subset exceed 30k (and that's after extreme reductions in the vocabulary, i.e. in which time-series variables we model; before that, some of these patients hit over 300k).

I'm naively trying to use sliding windows, but if there is a better approach you'd recommend for handling such extreme sequence length variation, I'd be happy to try it.

Oufattole avatar Nov 12 '24 21:11 Oufattole

Wait, actually, I think you do support masking the left-padded tokens via the seq_start_pos arg here, @lucidrains.

Oufattole avatar Nov 13 '24 03:11 Oufattole
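For context, my reading of what a `seq_start_pos`-style argument does for a left-padded batch: position j of sequence b is real iff j >= seq_start_pos[b], which is equivalent to an attend-mask. A self-contained sketch of that concept only (the helper name is hypothetical; this is not x-transformers' actual implementation):

```python
import torch

def mask_from_seq_start_pos(seq_start_pos, seq_len):
    """Build a boolean attend-mask from per-sequence start positions.

    seq_start_pos: (batch,) index of the first real (non-pad) token per
    left-padded sequence. Returns (batch, seq_len), True where attendable.
    (Conceptual sketch, not x-transformers' code.)
    """
    pos = torch.arange(seq_len, device=seq_start_pos.device)
    return pos[None, :] >= seq_start_pos[:, None]

mask = mask_from_seq_start_pos(torch.tensor([0, 2]), 4)
# first sequence has no padding; second has 2 leading pad tokens
```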

@Oufattole so that hyperparameter was actually built for variable prompt lengths iirc. i'll have to take a closer look to really know if it can be repurposed for what you are doing

during sliding window decoding, you'll have to slice the cached key values as you decode past the window length

lucidrains avatar Nov 13 '24 14:11 lucidrains
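A toy sketch of that sliding-window cache trimming, under the same assumed cache layout as above (list of per-layer `(k, v)` tensors shaped `(batch, heads, seq, dim_head)`; the helper name and loop are hypothetical, not x-transformers' generate code):

```python
import torch

def trim_cache_to_window(cache, window):
    """Keep only the most recent `window` positions of each layer's KV cache.

    Call once per decode step after appending the new key/value, so the cache
    never grows past the attention window. (Assumed layout, see above.)
    """
    return [(k[:, :, -window:], v[:, :, -window:]) for k, v in cache]

window = 4
cache = [(torch.zeros(1, 1, 0, 8), torch.zeros(1, 1, 0, 8))]  # one empty layer
for step in range(6):  # pretend we decode 6 tokens
    new_k, new_v = torch.randn(1, 1, 1, 8), torch.randn(1, 1, 1, 8)
    cache = [(torch.cat([k, new_k], dim=2), torch.cat([v, new_v], dim=2))
             for k, v in cache]
    cache = trim_cache_to_window(cache, window)
```

After the loop the cache holds only the last 4 of the 6 decoded positions, which is what makes masking the evicted tokens unnecessary.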

@Oufattole what specialty is this and what exactly are you trending in the EMR that hits 300k in length?

lucidrains avatar Nov 13 '24 14:11 lucidrains

Yes, I think you already do this kv-cache slicing during generation here, when restricting to the max sequence length (i.e. in the sliding window setting). Am I correct about this?

I'll send you an email regarding the broader EHR modeling question, which I realize may be out of scope for this GitHub issue.

Oufattole avatar Nov 13 '24 16:11 Oufattole

@Oufattole it has been a while, let me review it tomorrow morning and see if it can be made to work for your issue

lucidrains avatar Nov 13 '24 23:11 lucidrains