TrailBlazer
Cross Attention maps
Hello,
Thank you so much for your great work and codebase!
I would appreciate your clarifications on a few items.
- From within TextToVideoSDPipelineCall.py, at this line, the attention maps from the temporal layers appear to be empty. I checked this approximately with the following code block:
for name, module in self.unet.named_modules():
    module_name = type(module).__name__
    # only inspect cross-attention ("attn2") modules
    if module_name == "Attention" and "attn2" in name:
        # --- First set: temporal layers
        if "temp_attentions" in name:
            print(name)  # names below shown with ".0" rewritten as "[0]"
            extracted_attention_map = module.processor.cross_attention_map
            if extracted_attention_map is not None:
                print(extracted_attention_map.shape)
        else:
            # --- Second set: spatial layers and transformer_in
            ...
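For context, my mental model of how cross_attention_map gets populated is a custom attention processor roughly like the sketch below. This is purely my own assumption, not the actual TrailBlazer implementation: the class name CrossAttnMapCachingProcessor is made up, it skips group-norm/norm_cross handling, and the real processor presumably also reshapes the cached map into the 4D layout listed in the second set further down.

```python
import torch
from diffusers.models.attention_processor import Attention


class CrossAttnMapCachingProcessor:
    """Hypothetical sketch: vanilla attention that keeps the last
    cross-attention probabilities around for inspection."""

    def __init__(self):
        self.cross_attention_map = None  # filled on each cross-attention forward pass

    def __call__(self, attn: Attention, hidden_states,
                 encoder_hidden_states=None, attention_mask=None, **kwargs):
        query = attn.to_q(hidden_states)
        # attn2 is cross-attention, so encoder_hidden_states (text embeddings) is set
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # (batch*heads, query_len, key_len) attention probabilities
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        if encoder_hidden_states is not None:
            self.cross_attention_map = attention_probs.detach()

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states
```

With something along these lines attached via unet.set_attn_processor(...), I would have expected every attn2 module, temporal ones included, to end up with a non-None cross_attention_map, which is why the empty first set surprises me.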
- First set (temp_attentions layers; names only, since the cached maps were empty):
  down_blocks[0].temp_attentions[0].transformer_blocks[0].attn2
  down_blocks[0].temp_attentions[1].transformer_blocks[0].attn2
  down_blocks[1].temp_attentions[0].transformer_blocks[0].attn2
  down_blocks[1].temp_attentions[1].transformer_blocks[0].attn2
  down_blocks[2].temp_attentions[0].transformer_blocks[0].attn2
  down_blocks[2].temp_attentions[1].transformer_blocks[0].attn2
  up_blocks[1].temp_attentions[0].transformer_blocks[0].attn2
  up_blocks[1].temp_attentions[1].transformer_blocks[0].attn2
  up_blocks[1].temp_attentions[2].transformer_blocks[0].attn2
  up_blocks[2].temp_attentions[0].transformer_blocks[0].attn2
  up_blocks[2].temp_attentions[1].transformer_blocks[0].attn2
  up_blocks[2].temp_attentions[2].transformer_blocks[0].attn2
  up_blocks[3].temp_attentions[0].transformer_blocks[0].attn2
  up_blocks[3].temp_attentions[1].transformer_blocks[0].attn2
  up_blocks[3].temp_attentions[2].transformer_blocks[0].attn2
  mid_block.temp_attentions[0].transformer_blocks[0].attn2
Only the .attentions layers and the transformer_in layer in the second set have populated cross-attention maps.
- Second set (non-temporal layers; names with cached map shapes):
  transformer_in.transformer_blocks[0].attn2 torch.Size([64, 64, 24, 24])
  down_blocks[0].attentions[0].transformer_blocks[0].attn2 torch.Size([120, 64, 64, 77])
  down_blocks[0].attentions[1].transformer_blocks[0].attn2 torch.Size([120, 64, 64, 77])
  down_blocks[1].attentions[0].transformer_blocks[0].attn2 torch.Size([240, 32, 32, 77])
  down_blocks[1].attentions[1].transformer_blocks[0].attn2 torch.Size([240, 32, 32, 77])
  down_blocks[2].attentions[0].transformer_blocks[0].attn2 torch.Size([480, 16, 16, 77])
  down_blocks[2].attentions[1].transformer_blocks[0].attn2 torch.Size([480, 16, 16, 77])
  up_blocks[1].attentions[0].transformer_blocks[0].attn2 torch.Size([480, 16, 16, 77])
  up_blocks[1].attentions[1].transformer_blocks[0].attn2 torch.Size([480, 16, 16, 77])
  up_blocks[1].attentions[2].transformer_blocks[0].attn2 torch.Size([480, 16, 16, 77])
  up_blocks[2].attentions[0].transformer_blocks[0].attn2 torch.Size([240, 32, 32, 77])
  up_blocks[2].attentions[1].transformer_blocks[0].attn2 torch.Size([240, 32, 32, 77])
  up_blocks[2].attentions[2].transformer_blocks[0].attn2 torch.Size([240, 32, 32, 77])
  up_blocks[3].attentions[0].transformer_blocks[0].attn2 torch.Size([120, 64, 64, 77])
  up_blocks[3].attentions[1].transformer_blocks[0].attn2 torch.Size([120, 64, 64, 77])
  up_blocks[3].attentions[2].transformer_blocks[0].attn2 torch.Size([120, 64, 64, 77])
  mid_block.attentions[0].transformer_blocks[0].attn2 torch.Size([480, 8, 8, 77])
- If one assumes the second set holds the spatial attention maps, that does not align with the modules listed in the supplemental document (page 1, screenshot included). In particular, transformer_in.transformer_blocks[0].attn2 has shape [64, 64, 24, 24], which suggests it is temporal (not spatial, as stated in the supplemental) with 24 frames, while mid_block.attentions[0].transformer_blocks[0].attn2 has shape [480, 8, 8, 77], which suggests it is a spatial cross-attention map (not temporal) with 77 text tokens; see the small shape check below.
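To spell out the reasoning, this is the shape-based heuristic behind my reading; the constants are just what I observe in my run (77 CLIP text tokens, 24 frames) rather than anything taken from the TrailBlazer code:

```python
# Assumed constants from my run: 77 = CLIP token length, 24 = frames per video.
NUM_TEXT_TOKENS = 77
NUM_FRAMES = 24


def classify_attention_map(shape) -> str:
    """Guess whether a cached attention map is spatial or temporal from its last dims."""
    if shape[-1] == NUM_TEXT_TOKENS:
        return "spatial cross-attention (pixels attend to text tokens)"
    if shape[-2] == NUM_FRAMES and shape[-1] == NUM_FRAMES:
        return "temporal attention (frames attend to frames)"
    return "unclear"


print(classify_attention_map((64, 64, 24, 24)))  # transformer_in -> looks temporal
print(classify_attention_map((480, 8, 8, 77)))   # mid_block.attentions[0] -> looks spatial
```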
Your kind clarification would be very helpful. Thanks