TokenFlow
A question about the code, thanks
It seems that the first loop already replaces the attn1 forward of every BasicTransformerBlock in the down_blocks, mid_block, and up_blocks. Why does the code then patch the up_blocks of the UNet again?
def register_extended_attention(model):
    # Patch the self-attention (attn1) of every BasicTransformerBlock
    # in the UNet: down_blocks, mid_block, and up_blocks alike.
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder,
    # so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
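
For what it's worth, here is a quick check of what the question observes. This is a sketch, not TokenFlow code: it assumes the standard diffusers UNet2DConditionModel (the "runwayml/stable-diffusion-v1-5" checkpoint is my assumption) and mimics isinstance_str with a class-name comparison. It tests whether the attn1 modules targeted by the second loop are the same objects the named_modules() sweep already visits:

    from diffusers import UNet2DConditionModel

    # Assumed checkpoint; TokenFlow builds on a Stable Diffusion UNet.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )

    # attn1 modules the first loop would patch (class-name check stands
    # in for the repo's isinstance_str helper).
    swept = {
        id(m.attn1)
        for m in unet.modules()
        if m.__class__.__name__ == "BasicTransformerBlock"
    }

    # attn1 modules the second loop targets.
    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    targeted = {
        id(unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1)
        for res, blocks in res_dict.items()
        for block in blocks
    }

    # True means every module in the second loop was already swept.
    print(targeted <= swept)

If this prints True, the second loop re-wraps modules the first loop already covered, which would make it redundant unless sa_forward behaves differently when applied twice.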