
A question about the code, thanks

Open aosong01 opened this issue 9 months ago • 3 comments

It seems that the first loop already replaces attn1.forward in every BasicTransformerBlock, in down_blocks, mid_blocks and up_blocks alike. Why does the second loop then patch the up_blocks in the unet again?

def register_extended_attention(model):
    for _, module in model.unet.named_modules():
        if isinstance_str(module, "BasicTransformerBlock"):
            module.attn1.forward = sa_forward(module.attn1)

    res_dict = {1: [1, 2], 2: [0, 1, 2], 3: [0, 1, 2]}
    # we are injecting attention in blocks 4 - 11 of the decoder, so not in the first block of the lowest resolution
    for res in res_dict:
        for block in res_dict[res]:
            module = model.unet.up_blocks[res].attentions[block].transformer_blocks[0].attn1
            module.forward = sa_forward(module)
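
To make the question concrete, here is a minimal toy sketch of the two passes. It does not use the real UNet: `Attn`, `BasicTransformerBlock`, and this `sa_forward` are hypothetical stand-ins, and the sketch assumes that `sa_forward` builds a new forward from the module itself without calling whatever forward was installed before. Under that assumption, the second pass only re-assigns an equivalent function on a subset of blocks.

```python
class Attn:
    """Hypothetical stand-in for an attention layer (not the diffusers class)."""

    def __init__(self, name):
        self.name = name

    def forward(self, x):
        return f"orig({x})"


class BasicTransformerBlock:
    """Hypothetical stand-in that only carries an attn1 attribute."""

    def __init__(self, name):
        self.attn1 = Attn(name)


def sa_forward(attn):
    # Assumption: the replacement forward depends only on the module,
    # not on the forward that was installed before it.
    def forward(x):
        return f"extended[{attn.name}]({x})"

    return forward


blocks = [BasicTransformerBlock(f"block{i}") for i in range(4)]

# First pass: patch every block, like the loop over named_modules().
for b in blocks:
    b.attn1.forward = sa_forward(b.attn1)
first_pass = [b.attn1.forward for b in blocks]

# Second pass: patch a subset again, like the loop over res_dict.
for i in (1, 2):
    blocks[i].attn1.forward = sa_forward(blocks[i].attn1)

# Which blocks got a new function object in the second pass?
repatched = [blocks[i].attn1.forward is not first_pass[i] for i in range(4)]

# Does every block still behave like a single application of sa_forward?
outputs_same = all(
    b.attn1.forward("x") == f"extended[{b.attn1.name}](x)" for b in blocks
)
```

In this toy version the re-patched blocks receive a new function object, but their behavior is unchanged, which is why the second loop looks redundant to me. If the real `sa_forward` did wrap the previous forward, the two passes would stack instead, so the answer depends on how it is implemented.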

aosong01 • Sep 26 '23 07:09