DiffSynth-Studio

Attention Map Visualization

Open guaini opened this issue 1 year ago • 5 comments

I read the EliGen paper and I'm interested in the attention map visualization results in Fig. 13. Which part of the attention map do the results correspond to? And how can I find the attention map corresponding to a specific prompt?

guaini avatar Feb 28 '25 07:02 guaini

The visualization results are derived from the regional attention (RA) layer in the last double-stream transformer block, showing the first ten denoising timesteps. Referring to the source code, we temporarily changed the forward function of the FluxJointAttention class in diffsynth/models/flux_dit.py, where `selected = attention_map[:, :, 0:512, 1024:5120]` is the code that selects the attention for the first prompt; 0:512 refers to the local prompt tokens, and 1024:5120 refers to the latent tokens. The full code is:

    # Modified forward of FluxJointAttention in diffsynth/models/flux_dit.py.
    # Assumes the module-level imports already present in that file,
    # e.g. torch and `from einops import rearrange`.
    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None, i=None):
        batch_size = hidden_states_a.shape[0]

        # Part A
        qkv_a = self.a_to_qkv(hidden_states_a)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        # Part B
        qkv_b = self.b_to_qkv(hidden_states_b)
        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)

        q = torch.concat([q_b, q_a], dim=2)
        k = torch.concat([k_b, k_a], dim=2)
        v = torch.concat([v_b, v_a], dim=2)

        q, k = self.apply_rope(q, k, image_rotary_emb)
        if attn_mask is not None and i is not None:
            def compute_attention_map(q, k, mask):
                d_k = q.size(-1)
                scores = torch.matmul(q, k.transpose(-2, -1))  # [1, 24, 5120, 5120]
                scores = scores / (d_k ** 0.5)
                attention_map = torch.nn.functional.softmax(scores, dim=-1)  # [1, 24, 5120, 5120]

                return attention_map
            attention_map = compute_attention_map(q, k, attn_mask)

            # Queries 0:512 are local prompt tokens; keys 1024:5120 are latent tokens.
            selected = attention_map[:, :, 0:512, 1024:5120]
            # Take head 0 and average over the 512 local prompt query tokens.
            selected = selected[:, 0, :, :].mean(dim=1)
            # Fold the 4096 latent positions back into the 64x64 latent grid.
            vis = rearrange(selected, "B (H W) -> B H W", H=64, W=64)
            vis_map = vis.float().squeeze(0).detach().cpu().numpy()
            vis_map_normalized = (vis_map - vis_map.min()) / (vis_map.max() - vis_map.min())
            import seaborn as sns
            import matplotlib.pyplot as plt
            plt.figure()  # new figure per timestep, so heatmaps do not stack
            plt.axis('off')
            sns.heatmap(vis_map_normalized, cmap="viridis", annot=False, fmt=".2f", cbar=False, square=True)
            import os
            os.makedirs('workdirs/paper_app/attention_nt', exist_ok=True)
            plt.savefig(f'workdirs/paper_app/attention_nt/attention_{i}.png', bbox_inches='tight', pad_inches=0, dpi=300)
            plt.close()

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
        if ipadapter_kwargs_list is not None:
            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b
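To sanity-check the shapes involved, here is a minimal standalone sketch of the attention-map math above on dummy tensors (random data, one head instead of the model's 24, and an arbitrary head_dim; none of the real model weights are used):

```python
import torch

# Dummy sizes matching the thread: 5120 total tokens =
# 512 local prompt + 512 global prompt + 4096 latent tokens.
num_heads, head_dim, seq_len = 1, 8, 5120  # real model uses 24 heads
q = torch.randn(1, num_heads, seq_len, head_dim)
k = torch.randn(1, num_heads, seq_len, head_dim)

# Same math as compute_attention_map in the snippet above.
scores = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
attention_map = torch.softmax(scores, dim=-1)  # [1, 1, 5120, 5120]

# Local-prompt queries attending to latent-token keys.
selected = attention_map[:, :, 0:512, 1024:5120]
print(selected.shape)  # torch.Size([1, 1, 512, 4096])
```

Each row of `attention_map` sums to 1 over all 5120 keys, so the sliced block is the (unrenormalized) mass the local prompt places on the latent tokens.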

mi804 avatar Feb 28 '25 08:02 mi804

@mi804 Thanks for your reply. I will try the code later.

guaini avatar Feb 28 '25 08:02 guaini

@mi804 `selected = attention_map[:, :, 0:512, 1024:5120]` `selected = selected[:,0,:,:].mean(dim=1)`

guaini avatar Feb 28 '25 11:02 guaini

How does the above code show the attention map corresponding to the word "person"?

guaini avatar Feb 28 '25 11:02 guaini

In the implementation of Regional Attention, local prompt tokens, global prompt tokens, and latent tokens are concatenated together. So among the 5120 tokens, 0:512 are tokens from the local prompt "person", 512:1024 are tokens from the global prompt "a person standing by the river", and 1024:5120 are the latent embedding tokens. Therefore, attention_map[:, :, 0:512, 1024:5120] is the attention map between "person" and the latent embeddings.
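The token layout can be checked with a small sketch using dummy data (shapes only; the `reshape` call mirrors the `rearrange(..., "B (H W) -> B H W")` step from the snippet above without requiring einops):

```python
import torch

n_local, n_global, n_latent = 512, 512, 4096  # 4096 = 64 * 64 latent tokens
total = n_local + n_global + n_latent         # 5120 concatenated tokens

# Dummy single-head attention map over the concatenated sequence.
attention_map = torch.rand(1, 1, total, total)

# "person" (local prompt) queries attending to latent-token keys.
person_to_latent = attention_map[:, :, 0:512, 1024:5120]

# Head 0, averaged over the 512 local prompt query tokens, then
# the 4096 latent positions folded back into a 64x64 grid.
vis = person_to_latent[:, 0, :, :].mean(dim=1)
vis = vis.reshape(1, 64, 64)
print(vis.shape)  # torch.Size([1, 64, 64])
```

The resulting 64x64 map is what gets normalized and drawn as a heatmap in the code above, one image per denoising timestep.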

mi804 avatar Feb 28 '25 11:02 mi804