
Fix generation for bsz > 1

Open joecummings opened this issue 1 year ago • 3 comments

Our modules only work with generation under two conditions: batch_size = 1 or every single sample in a batch has the same length. The main culprit is this line of code: https://github.com/pytorch/torchtune/blob/288ff4435b0cf17325b5c3b112f6859a6cdf0ea2/torchtune/modules/transformer.py#L167

For a batch that looks like the following:

My, name, is, Joe
Hello, world, <PAD>, <PAD>
Bye, <PAD>, <PAD>, <PAD>

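A proper mask would look like this (one s x s mask per sample):

My, name, is, Joe:
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

Hello, world, <PAD>, <PAD>:
1 0 0 0
1 1 0 0
0 0 0 0
0 0 0 0

Bye, <PAD>, <PAD>, <PAD>:
1 0 0 0
0 0 0 0
0 0 0 0
0 0 0 0

The stacked mask has size [b x s x s], which here is [3 x 4 x 4].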
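Below is a minimal pure-Python sketch of how this [b x s x s] mask could be built from a right-padded batch: query i may attend to key j only when j <= i and both positions hold real tokens. The helper `right_padded_causal_mask` and the literal `<PAD>` string are illustrative assumptions, not torchtune APIs; a real implementation would more likely produce a boolean torch tensor than nested lists.

```python
from typing import List

PAD = "<PAD>"  # illustrative pad token, not torchtune's


def right_padded_causal_mask(batch: List[List[str]], pad: str = PAD) -> List[List[List[int]]]:
    """Hypothetical helper: build a [b x s x s] mask (1 = query i may attend to key j).

    A query may attend to a key only if the key is not in its future (j <= i)
    and both positions hold real, non-pad tokens.
    """
    masks = []
    for seq in batch:
        is_real = [tok != pad for tok in seq]
        s = len(seq)
        masks.append([
            [int(j <= i and is_real[i] and is_real[j]) for j in range(s)]
            for i in range(s)
        ])
    return masks


batch = [
    ["My", "name", "is", "Joe"],
    ["Hello", "world", PAD, PAD],
    ["Bye", PAD, PAD, PAD],
]
mask = right_padded_causal_mask(batch)
for sample in mask:
    print("\n".join(" ".join(map(str, row)) for row in sample), end="\n\n")
```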

This will be a fairly involved change that touches several utils and modules. The general changes needed will be:

  • Remove the causal mask from the KV cache; instead, the mask should come in through the mask param
  • Update the generation utils to pass a mask in their call to model.forward()
  • Modify the eleuther eval recipe to construct a proper causal mask and pass it in
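To make the first two bullets concrete, here is a toy, single-sample attention function that applies whatever mask the caller passes in and builds no causal mask of its own. `masked_attention` is a hypothetical helper written with plain lists, not torchtune's attention module; it only sketches where a caller-supplied mask would be consumed.

```python
import math
from typing import List

Matrix = List[List[float]]


def masked_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[int]]) -> Matrix:
    """Hypothetical single-head scaled dot-product attention for ONE sample.

    q, k, v are [s x d]; mask is [s x s] with 1 = may attend, 0 = blocked.
    The caller owns the mask: no causal mask is added here.
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        best = max(scores)
        if best == -math.inf:
            # Every key is blocked (e.g. a right-padded pad row): a plain softmax
            # would be 0/0 here, so this sketch emits a zero vector instead.
            out.append([0.0] * len(v[0]))
            continue
        weights = [math.exp(score - best) for score in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


# Right-padded sample "Hello, world, <PAD>, <PAD>" with its mask from the issue; toy vectors.
q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]
mask = [[1, 0, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
out = masked_attention(q, k, v, mask)
print(out[0], out[2])
```

The two pad rows in this mask block every key, so the sketch returns zero vectors for them; a plain softmax over a row of -inf scores would otherwise produce NaN.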

This was originally found and reported by @iankur

joecummings, Aug 02 '24 02:08

Was this the kind of thing you had in mind? https://github.com/pytorch/torchtune/blob/1129f9e3a246628c991c246d81dbead62d3437a3/torchtune/modules/rlhf/_generation.py

Granted, there are a couple of changes I've been meaning to make (generating the full mask only once, then extending it for each new token in the batch), and you'll probably have a more intelligent way of generating the masks themselves.
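One possible reading of the idea of building the mask once and extending it for each new token, sketched for a single sample in pure Python and assuming left padding, so generated tokens follow the prompt directly. `extend_mask` is a hypothetical helper, not part of the linked `_generation.py`: each step appends a 0 column to the existing rows (earlier tokens never see the new one) and a new row that lets the generated token see every real earlier token plus itself.

```python
from typing import List


def extend_mask(mask: List[List[int]], is_real: List[bool]) -> List[List[int]]:
    """Hypothetical helper: grow one sample's [T x T] mask to [(T+1) x (T+1)].

    is_real[j] says whether existing position j is a real (non-pad) token.
    """
    grown = [row + [0] for row in mask]  # no earlier token may attend to the new one
    grown.append([int(flag) for flag in is_real] + [1])  # new token: real history + itself
    return grown


# Left-padded prompt "<PAD>, <PAD>, Hello, world" and its mask.
is_real = [False, False, True, True]
mask = [
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 1, 1],
]
for _ in range(2):  # two decoding steps
    mask = extend_mask(mask, is_real)
    is_real.append(True)  # generated tokens are never padding
    print(mask[-1])
```

If keys for earlier positions come from a KV cache, only the last row of the grown mask is needed at each step.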

salmanmohammadi, Aug 02 '24 10:08

> Was this the kind of thing you had in mind? 1129f9e/torchtune/modules/rlhf/_generation.py

Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

joecummings, Aug 02 '24 13:08

> Yep, this is pretty much it! I take it that you're not utilizing the KV Cache for this generation though, right?

Nah. It was also on my TODO list of possible optimizations, and I briefly spoke to Rafi about it, but we agreed it would be kind of a pain in the ass to set up caching for custom masks.

salmanmohammadi, Aug 02 '24 13:08

Left padded:

My, name, is, Joe
<PAD>, <PAD>, Hello, world
<PAD>, <PAD>, <PAD>, Bye

Left padded mask:

My, name, is, Joe:
1 0 0 0
1 1 0 0
1 1 1 0
1 1 1 1

<PAD>, <PAD>, Hello, world:
1 0 0 0
0 1 0 0
0 0 1 0
0 0 1 1

<PAD>, <PAD>, <PAD>, Bye:
1 0 0 0
0 1 0 0
0 0 1 0
0 0 0 1
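The left-padded masks above follow a simple rule, sketched here in pure Python: a real query attends causally to real keys, and a pad query attends only to itself. `left_padded_causal_mask` and the `<PAD>` string are illustrative assumptions, not torchtune code.

```python
from typing import List

PAD = "<PAD>"  # illustrative pad token, not torchtune's


def left_padded_causal_mask(batch: List[List[str]], pad: str = PAD) -> List[List[List[int]]]:
    """Hypothetical helper: [b x s x s] mask for left-padded sequences (1 = may attend).

    Real queries attend causally to real keys only; a pad query attends only to
    itself, so no row of the mask is all zeros.
    """
    masks = []
    for seq in batch:
        is_real = [tok != pad for tok in seq]
        s = len(seq)
        masks.append([
            [int(i == j or (j <= i and is_real[i] and is_real[j])) for j in range(s)]
            for i in range(s)
        ])
    return masks


batch = [
    ["My", "name", "is", "Joe"],
    [PAD, PAD, "Hello", "world"],
    [PAD, PAD, PAD, "Bye"],
]
mask = left_padded_causal_mask(batch)
for sample in mask:
    print("\n".join(" ".join(map(str, row)) for row in sample), end="\n\n")
```

One plausible reason for the diagonal 1s on pad rows (unlike the all-zero pad rows in the right-padded example) is that every row keeps at least one visible key, so a softmax over the masked scores never has to normalize a row that is entirely -inf, which yields NaN.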

joecummings, Aug 21 '24 23:08

> Our modules only work with generation under two conditions: batch_size = 1 or every single sample in a batch has the same length. The main culprit is this line of code:

I assume batched generation in the eleuther eval recipe satisfies the latter? I've just got iterative decoding + KV caching working for my batched RLHF generation utils, and I'm seeing > 10x speedups w/o compile (PPO go brrrr). Can chat about it later today if it's of interest.
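A toy sketch of why iterative decoding with a KV cache can still honor a custom mask: at each step the new query only needs its own mask row, restricted to the keys cached so far. Everything here is an assumption for illustration: `attend` and `ToyKVCache` are hypothetical (one sample, one head, plain lists, q = k = v), not torchtune's KV cache or the RLHF generation utils, and every position is fed one at a time instead of using a separate prefill pass.

```python
import math
from typing import List

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector], mask_row: List[int]) -> Vector:
    """Hypothetical single-query scaled dot-product attention.

    mask_row[j] == 1 means key j is visible; assumes at least one visible key.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [
        sum(a * b for a, b in zip(q, k)) * scale if m else -math.inf
        for k, m in zip(keys, mask_row)
    ]
    best = max(scores)
    weights = [math.exp(score - best) for score in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total for c in range(len(values[0]))]


class ToyKVCache:
    """Hypothetical cache for one sample and one head; grows by one key/value per step."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def step(self, q: Vector, k: Vector, v: Vector, mask_row: List[int]) -> Vector:
        # Store the new token's key/value, then attend over everything cached so far.
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values, mask_row)


# Left-padded toy sample "<PAD>, a, b, c": the pad attends to itself, real tokens are causal.
is_real = [False, True, True, True]
n = len(is_real)
mask = [[int(i == j or (j <= i and is_real[i] and is_real[j])) for j in range(n)] for i in range(n)]
x = [[0.1, 0.2], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # toy vectors used as q, k and v

# Reference: every query attends over all positions using its full mask row.
full = [attend(x[i], x, x, mask[i]) for i in range(n)]

# Incremental: one token per step, mask row truncated to the keys cached so far.
cache = ToyKVCache()
incremental = [cache.step(x[i], x[i], x[i], mask[i][: i + 1]) for i in range(n)]

matches = all(abs(a - b) < 1e-12 for fr, ir in zip(full, incremental) for a, b in zip(fr, ir))
print(matches)
```

The two paths agree because mask[i][j] is 0 for every j > i, so keys the cache has not seen yet never contribute to the reference output either.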

salmanmohammadi, Aug 22 '24 11:08