
What is the purpose of `bool_matrix1024` and `bool_matrix4096` in the return value of the `cal_attn_mask_xl` function?

parryppp opened this issue · 6 comments

Visualizations of `bool_matrix1024` and `bool_matrix4096` are shown below. [figures: visualizations of bool_matrix1024 and bool_matrix4096]

parryppp · May 16 '24

Our method is based on sampling tokens so that images interact with one another through attention operations. The mask identifies which tokens are to be sampled. To reduce memory consumption, we switched from using the mask to using indices.
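For illustration, here is a minimal sketch (not the repository's actual code; token count and feature dimension are assumed) of the mask-to-indices switch: a boolean mask over the tokens versus just the positions of the kept tokens, which downstream attention can gather directly.

```python
import torch

tokens = torch.randn(4096, 320)     # hypothetical per-image tokens (shape assumed)
bool_mask = torch.rand(4096) < 0.5  # True = token is sampled

# Equivalent index form: stores only the kept positions, so attention can
# gather a reduced key/value set instead of carrying a full boolean mask.
indices = bool_mask.nonzero(as_tuple=True)[0]
assert torch.equal(tokens[bool_mask], tokens.index_select(0, indices))
```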

Z-YuPeng · May 16 '24

[figure: reshaped attention mask] The reshaped attention mask is shown above. Do you mean that, for example, if I want to generate 4 consistent images, the yellow zone in the attention map is never masked? If so, what does 'randsample' in the paper mean? [figure: attention map with highlighted yellow zone]

parryppp · May 16 '24

The generated random mask is exactly how random sampling is implemented. Once a mask has been randomized, only the tokens where mask = 1 are considered in the attention operation.

Z-YuPeng · May 16 '24

The yellow zone corresponds to the concatenation operation below. We found that we cannot drop an image's own tokens, as doing so leads to a significant decline in image quality.
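To make this concrete, here is a minimal sketch (assumed shapes and sampling rate; not the repository's `cal_attn_mask_xl`) of a random sampling mask whose diagonal blocks are forced to 1, so every image always keeps its own tokens:

```python
import torch

num_images, tokens_per_image = 4, 1024  # assumed sizes
sample_rate = 0.5                       # assumed sampling rate
total = num_images * tokens_per_image

# Randomly decide, per query image, which of the total tokens it may attend to.
rand_mask = torch.rand(num_images, total) < sample_rate
bool_matrix = rand_mask.repeat_interleave(tokens_per_image, dim=0)  # (total, total)

# Force the diagonal blocks to True (the yellow zone): an image never drops
# its own tokens, since doing so degrades image quality.
for i in range(num_images):
    s = i * tokens_per_image
    bool_matrix[s:s + tokens_per_image, s:s + tokens_per_image] = True
```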

Z-YuPeng · May 16 '24

The mask operation combines random sampling and concatenation into a single step. We initially adopted it because it was faster and equivalent to random sampling, but it also led to greater memory usage. We later reverted to the original approach because of the memory-consumption concerns raised in the issues. https://github.com/HVision-NKU/StoryDiffusion/blob/main/utils/gradio_utils.py#L258
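As a rough illustration of the two equivalent formulations (assumed shapes; not the repository code):

```python
import torch
import torch.nn.functional as F

num_images, n, dim = 4, 1024, 64  # assumed sizes
total = num_images * n
q, k, v = (torch.randn(total, dim) for _ in range(3))

# (a) Single-step variant: one big masked attention over all images' tokens.
#     Sampling and concatenation are folded into the boolean mask, but the
#     (total x total) score/mask matrix costs more memory.
bool_matrix = torch.rand(total, total) < 0.5
scores = (q @ k.T) / dim ** 0.5
out_masked = torch.softmax(scores.masked_fill(~bool_matrix, float("-inf")), dim=-1) @ v

# (b) Two-step variant (the approach reverted to): gather the sampled tokens
#     by index, concatenate them with the first image's own tokens, and attend
#     over that smaller set.
idx = torch.randperm(total)[: total // 2]
k_cat = torch.cat([k[:n], k.index_select(0, idx)])  # own tokens + sampled tokens
v_cat = torch.cat([v[:n], v.index_select(0, idx)])
out_gather = F.scaled_dot_product_attention(
    q[:n].unsqueeze(0), k_cat.unsqueeze(0), v_cat.unsqueeze(0)
).squeeze(0)
```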

Z-YuPeng · May 16 '24

Thank you for your explanation; I understand much more clearly now. But I still have a question about the shape of the attention mask. Why does the attention mask ensure that the squares on the diagonal remain set to 1, as shown in the figure below? Is that the purpose of the code in L249-L252?

[figure: attention mask with diagonal squares set to 1]
https://github.com/HVision-NKU/StoryDiffusion/blob/b67d205784a3acba6194f13e941f9b5a4dbdc34c/utils/gradio_utils.py#L249

Why not simply generate a fully random attention mask, like the figure below?

https://github.com/HVision-NKU/StoryDiffusion/blob/b67d205784a3acba6194f13e941f9b5a4dbdc34c/utils/gradio_utils.py#L261
[figure: fully random attention mask]

parryppp · May 17 '24