
Parallel forward

Open • neverix opened this issue 3 years ago • 4 comments

The model's decoder currently only supports sequential decoding because of the way attn_state is implemented. A parallel ~~generation~~ forward pass could be implemented by setting attn_state to None and handling all cases inside the generation code.

This would help solve #58

neverix · Jul 14 '22 19:07

I'm not sure what you mean. Are you saying a parallel forward pass over the 256 image tokens? That wouldn't work, because each token depends on the tokens before it. And if you meant parallel over the layers, that wouldn't work either, since each layer depends on the previous layer's output. Maybe you meant a parallel backward pass?

kuprel · Jul 14 '22 19:07

Right now the code can't run a single forward pass over all tokens because of the caching implementation. It has to step through the tokens one at a time instead of just applying a causal mask to the attention.
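A minimal pure-Python sketch of the idea, assuming a single attention head and a hypothetical `self_attention(qs, ks, vs, attn_state=None)` function (this is not min-dalle's actual API). With `attn_state=None` it computes every position in one call using a causal mask; with a cache dict it decodes one token per call. The script checks that both paths give the same outputs.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax; entries equal to -inf get zero weight."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values, allowed):
    """Scaled dot-product attention for one query; allowed[j] is False for masked keys."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [
        sum(a * b for a, b in zip(q, k)) * scale if ok else -math.inf
        for k, ok in zip(keys, allowed)
    ]
    weights = softmax(scores)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


def self_attention(qs, ks, vs, attn_state=None):
    """Hypothetical single-head decoder self-attention.

    attn_state is None: one parallel forward pass over all positions, where a
    causal mask stops position t from attending to any position after t.
    attn_state is a dict {"k": [...], "v": [...]}: incremental decoding, where
    each new key/value is appended to the cache and its query attends to the
    whole cache.
    """
    if attn_state is None:
        n = len(qs)
        return [attend(qs[t], ks, vs, [j <= t for j in range(n)]) for t in range(n)]
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        attn_state["k"].append(k)
        attn_state["v"].append(v)
        cached = len(attn_state["k"])
        outputs.append(attend(q, attn_state["k"], attn_state["v"], [True] * cached))
    return outputs


def random_vectors(n, dim, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(n)]


rng = random.Random(0)
n_tokens, dim = 6, 4
qs = random_vectors(n_tokens, dim, rng)
ks = random_vectors(n_tokens, dim, rng)
vs = random_vectors(n_tokens, dim, rng)

# Parallel: every position in one call, relying only on the causal mask.
parallel = self_attention(qs, ks, vs)

# Sequential: one token per call, relying on the key/value cache.
state = {"k": [], "v": []}
sequential = [self_attention([qs[t]], [ks[t]], [vs[t]], state)[0] for t in range(n_tokens)]

max_diff = max(abs(a - b) for p, s in zip(parallel, sequential) for a, b in zip(p, s))
print(f"max |parallel - sequential| = {max_diff:.2e}")
```

The mask gives position t exactly the keys the cache would hold at step t, so the two paths agree. The cache only saves recomputation while sampling, where tokens still have to be produced one after another.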

neverix · Jul 14 '22 20:07

Oh, I see: this would be for when you want to do a forward pass over all tokens at once, instead of sampling them one after the other.

kuprel · Jul 14 '22 20:07

#80 solves this

neverix · Jul 19 '22 05:07