How do SD's text_tokenizer and Unet work when the input prompt is too long?
Description: Hello, esteemed expert! I have a question. When I use AUTOMATIC1111/stable-diffusion-webui, I found that I can input prompts longer than 77 tokens, and these longer prompts still affect the generated images. I don't understand how this works. For example:
prompt = "a photograph of an astronaut riding a horse"
text_input_ids = text_tokenizer(
    prompt,
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
).input_ids
text_embeddings = text_encoder(text_input_ids.to("cuda"))[0]
The output shape is torch.Size([1, 77, 768]). I don't understand how text_tokenizer supports such long prompt inputs, or how these over-long prompts work in the U-Net's cross-attention. I have looked at the code in your repository, but I still haven't found the answer. Forgive my ignorance; I humbly ask for your guidance.
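Before answering, it helps to see why the tokenizer output is always 77 ids regardless of prompt length. Here is a minimal sketch of the padding/truncation behavior, using a hypothetical `to_fixed_length` helper; the ids 49406/49407 are CLIP's actual begin/end tokens (the CLIP tokenizer in SD also pads with the end token), but the function itself is a simplified illustration, not the real tokenizer.

```python
BOS, EOS, PAD = 49406, 49407, 49407  # CLIP's <bos>/<eos> ids; pad == eos in SD's tokenizer

def to_fixed_length(token_ids, max_length=77):
    """Mimic padding='max_length', truncation=True for one prompt.

    token_ids: prompt tokens WITHOUT special tokens.
    Returns exactly max_length ids: <bos> + up to 75 tokens + <eos> + padding.
    """
    body = list(token_ids[: max_length - 2])   # keep room for <bos>/<eos>
    ids = [BOS] + body + [EOS]
    ids += [PAD] * (max_length - len(ids))     # pad short prompts up to 77
    return ids

short = to_fixed_length(list(range(10)))    # 10-token prompt
long = to_fixed_length(list(range(200)))    # 200-token prompt, truncated
print(len(short), len(long))  # 77 77 -- the encoder always sees 77 ids
```

So with the snippet above, anything past the first 75 prompt tokens is silently cut off; web UIs work around this with the chunking trick described below.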
First, the U-Net does not require the text-encoder output to have exactly 77 tokens: its cross-attention can consume a context of any sequence length, e.g. [n, 77, 768] chunks joined together. Training scripts exploit this property to extend the effective token length to 75, 150, 225, and so on.
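The reason the U-Net tolerates a longer context is that attention output shape depends only on the queries (the latent tokens), not on the key/value length. A minimal NumPy sketch of single-head cross-attention (no learned projections or multi-head logic, shapes only) shows this:

```python
import numpy as np

def cross_attention(latent, context):
    """Single-head cross-attention sketch (no learned weights, for shape only).
    latent:  (n_latent, d) -- U-Net spatial tokens (queries)
    context: (n_text, d)   -- text embeddings (keys/values)
    """
    d = latent.shape[-1]
    scores = latent @ context.T / np.sqrt(d)      # (n_latent, n_text)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ context                      # (n_latent, d)

rng = np.random.default_rng(0)
latent = rng.standard_normal((64, 768))           # e.g. an 8x8 latent, flattened
out77  = cross_attention(latent, rng.standard_normal((77, 768)))
out231 = cross_attention(latent, rng.standard_normal((231, 768)))
print(out77.shape, out231.shape)  # (64, 768) (64, 768)
```

Both outputs have the same shape: a 231-token context (three 77-token chunks) is just as valid as a 77-token one.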
Why 75 rather than 77? Because the first id of the encoder input is the begin token <bos> and the last is the end token <eos>. Only the 75 positions in the middle carry the actual prompt text, so the scripts focus on those.
Combining both properties: for a token length of 225, the prompt is split into three 77-token chunks, so the pseudo-output of the tokenizer has shape (3, 77), and after encoding it has shape (3, 77, 768). Then,
input[0] = <bos> + first 75 tokens of the prompt + <eos>
input[1] = <bos> + second 75 tokens of the prompt + <eos>
input[2] = <bos> + third 75 tokens of the prompt + <eos>
Each chunk is passed through the text encoder, and the three (77, 768) outputs are concatenated along the token axis into a (1, 231, 768) context for the U-Net's cross-attention. By this manipulation, A1111 and sd-scripts can accept prompts longer than 75 tokens.
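The chunking step above can be sketched in a few lines. This is a hypothetical, simplified version of what A1111/sd-scripts do (no attention weighting, no comma-aware splitting, last chunk simply padded); the ids 49406/49407 are CLIP's real begin/end tokens, used here for illustration.

```python
BOS, EOS, PAD = 49406, 49407, 49407  # CLIP ids; pad == eos in SD's tokenizer

def chunk_prompt(token_ids, chunk=75):
    """Split a long prompt into <bos> + 75 tokens + <eos> chunks of 77 ids."""
    chunks = []
    for i in range(0, len(token_ids), chunk):
        body = list(token_ids[i : i + chunk])
        body += [PAD] * (chunk - len(body))     # pad only the final chunk
        chunks.append([BOS] + body + [EOS])
    return chunks

long_prompt = list(range(225))                  # pretend 225-token prompt
chunks = chunk_prompt(long_prompt)
print(len(chunks), [len(c) for c in chunks])    # 3 [77, 77, 77]
```

Each of the three 77-id chunks then goes through the text encoder independently, and the resulting (77, 768) embeddings are concatenated along the token axis into the (1, 231, 768) context fed to cross-attention.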