Differential Transformer loss spikes while training.

Open · fasil-saidalavi opened this issue 7 months ago • 5 comments

I have trained a 1.3B model using both the Differential Transformer and the standard Transformer. I observed a slight improvement in LLM evaluation scores for the Differential Transformer variant, and its loss was consistently lower. However, when I tried the 7B model comparison, training showed loss spikes, and the loss started increasing after a certain point. I also noticed that the gradient norm increased at the same points where the loss spiked. In terms of implementation details, I only used the differential attention part from this repository; I did not include the SwiGLU layer. Instead, I used a standard FFN layer. @YTianZHU
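(For reference, a minimal sketch of the kind of per-step monitoring behind this observation, assuming a plain PyTorch training loop; grad_global_norm, spike_threshold, and step are illustrative names, not part of the training framework used here.)

import torch

def grad_global_norm(model: torch.nn.Module) -> float:
    # Global L2 norm over all parameter gradients, as clip_grad_norm_ would report it.
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    return torch.norm(torch.stack(norms), 2).item() if norms else 0.0

# After loss.backward() and before optimizer.step():
#     gnorm = grad_global_norm(model)
#     if gnorm > spike_threshold:  # threshold chosen from the normal grad-norm range
#         print(f"step {step}: loss={loss.item():.4f} grad_norm={gnorm:.2f}")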

fasil-saidalavi avatar May 20 '25 12:05 fasil-saidalavi

Hi @fasil-saidalavi , sorry for the late response. Do you use the same FFN implementation for both Diff and the baseline Transformer? Could you post a code snippet of the attention implementation for both the Diff and baseline Transformer?

YTianZHU avatar May 30 '25 03:05 YTianZHU

Hi @YTianZHU,

  • I used the same FFN implementation for both Diff and Transformer.
  • Attention implementation of the Diff Transformer:

class TEDotProductDiffAttention(te.pytorch.DotProductAttention):

cp_stream: torch.cuda.Stream = None

def __init__(
    self,
    config: TransformerConfig,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    attention_dropout: float = None,
):
    self.config = config
    self.te_forward_mask_type = False
    self.qkv_format: str = 'sbhd'

    if self.config.apply_query_key_layer_scaling != bool(
        int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
    ):
        raise ValueError(
            f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
            f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
            f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
            f"setting query key layer scaling via argument, so these two must match."
        )

    extra_kwargs = {}
    if _te_version >= packaging.version.Version("0.11.0"):
        extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
    elif self.config.num_query_groups != self.config.num_attention_heads:
        raise ValueError(
            f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
            f"use a newer version of Transformer Engine. "
            f"(num_query_groups ({self.config.num_query_groups}) != "
            f"num_attention_heads ({self.config.num_attention_heads}))"
        )

    if _te_version >= packaging.version.Version("0.10.0"):
        extra_kwargs["attention_type"] = attention_type
        # older versions don't need attention_type

    if _te_version > packaging.version.Version("0.12.0"):
        self.te_forward_mask_type = True

    # Only Transformer-Engine version >= 1.0.0 supports context parallelism
    if _te_version >= packaging.version.Version("1.0.0"):
        if getattr(TEDotProductDiffAttention, "cp_stream") is None:
            TEDotProductDiffAttention.cp_stream = torch.cuda.Stream()
        extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
        extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
            check_initialized=False
        )
        extra_kwargs["cp_stream"] = TEDotProductDiffAttention.cp_stream
    else:
        assert (
            self.config.context_parallel_size == 1
        ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

    if config.window_size is not None:
        # Check version
        assert _te_version >= packaging.version.Version(
            "1.2.0"
        ), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
        extra_kwargs['window_size'] = config.window_size
    self.diff_kv_channels = divide(self.config.kv_channels, 2)
    super().__init__(
        num_attention_heads=self.config.num_attention_heads,
        kv_channels=self.diff_kv_channels,
        attention_dropout=self.config.attention_dropout
        if attention_dropout is None
        else attention_dropout,
        attn_mask_type=attn_mask_type.name,
        sequence_parallel=self.config.sequence_parallel,
        tp_size=self.config.tensor_model_parallel_size,
        get_rng_state_tracker=get_cuda_rng_tracker
        if get_cuda_rng_tracker().is_initialized()
        else None,
        tp_group=get_tensor_model_parallel_group(check_initialized=False),
        layer_number=layer_number,
        **extra_kwargs,
    )
    # torch.manual_seed(40)
    self.lambda_init = lambda_init_fn(layer_number)
    self.lambda_q1 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_k1 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_q2 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_k2 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))

    # torch.manual_seed(42)
    self.subln = RMSNorm(2 * self.diff_kv_channels, eps=1e-5, elementwise_affine=True)


def forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Tensor,
    attn_mask_type: AttnMaskType,
    packed_seq_params: PackedSeqParams = None,
):
    packed_seq_kwargs = (
        dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
    )
    # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
    if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
        self.qkv_format = 'bshd'

    qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

    if _te_version < packaging.version.Version("1.3.0"):
        # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
        # These two arguments did not exist prior to 1.3.0
        packed_seq_kwargs.pop("max_seqlen_q", None)
        packed_seq_kwargs.pop("max_seqlen_kv", None)

    if self.config.apply_rope_fusion and qkv_format == 'bshd':
        query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
        # In PyTorch, the following two tensors are in fact the same:
        #   Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
        #   Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
        # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
        # can have same shape but different strides.
        # We unify them to the first one to pass the stride check in TE
        if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
            value = value.as_strided(value.shape, key.stride())
            
    src_len, tgt_len, bsz = key.size(0), query.size(0), query.size(1)
    q_num_heads, kv_num_heads = query.size(2), key.size(2)
    # query = query.view(tgt_len, bsz, 2 * q_num_heads, self.diff_kv_channels)
    query = query.reshape(tgt_len, bsz, q_num_heads, 2, self.diff_kv_channels)
    
    # key = key.view(tgt_len, bsz, 2 * kv_num_heads, self.diff_kv_channels)
    key = key.reshape(src_len, bsz, kv_num_heads, 2, self.diff_kv_channels)
    
    # value = value.view(tgt_len, bsz, 2 * kv_num_heads, self.diff_kv_channels)
    value = value.reshape(src_len, bsz, kv_num_heads, 2, self.diff_kv_channels)
    
    q1, q2 = query[:, :, :, 0], query[:, :, :, 1]
    k1, k2 = key[:, :, :, 0], key[:, :, :, 1]
    v1, v2 = value[:, :, :, 0], value[:, :, :, 1]
    
    if self.te_forward_mask_type:
        core_attn_out11 = super().forward(
            q1,
            k1,
            v1,
            attention_mask,
            attn_mask_type=attn_mask_type.name,
            **packed_seq_kwargs,
        )
    else:
        core_attn_out11 = super().forward(q1, k1, v1, attention_mask, **packed_seq_kwargs,)
    core_attn_out11 = core_attn_out11.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
    
    if self.te_forward_mask_type:
        core_attn_out12 = super().forward(
            q1,
            k1,
            v2,
            attention_mask,
            attn_mask_type=attn_mask_type.name,
            **packed_seq_kwargs,
        )
    else:
        core_attn_out12 = super().forward(q1, k1, v2, attention_mask, **packed_seq_kwargs,)
    core_attn_out12 = core_attn_out12.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)

    attn1 = torch.cat([core_attn_out11, core_attn_out12], dim=-1)
    
    if self.te_forward_mask_type:
        core_attn_out21 = super().forward(
            q2,
            k2,
            v1,
            attention_mask,
            attn_mask_type=attn_mask_type.name,
            **packed_seq_kwargs,
        )
    else:
        core_attn_out21 = super().forward(q2, k2, v1, attention_mask, **packed_seq_kwargs,)
    core_attn_out21 = core_attn_out21.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
    
    if self.te_forward_mask_type:
        core_attn_out22 = super().forward(
            q2,
            k2,
            v2,
            attention_mask,
            attn_mask_type=attn_mask_type.name,
            **packed_seq_kwargs,
        )
    else:
        core_attn_out22 = super().forward(q2, k2, v2, attention_mask, **packed_seq_kwargs,)
    core_attn_out22 = core_attn_out22.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
    
    attn2 = torch.cat([core_attn_out21, core_attn_out22], dim=-1)


    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(query)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(query)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init
    # lambda_full = 1
    core_attn_out = attn1 - lambda_full * attn2
    
    core_attn_out = self.subln(core_attn_out)
    
    core_attn_out = core_attn_out * (1 - self.lambda_init)
    core_attn_out = core_attn_out.view(tgt_len, bsz , q_num_heads * self.config.kv_channels) 

    if self.config.apply_rope_fusion and qkv_format == 'bshd':
        return core_attn_out.transpose(0, 1)
    else:
        return core_attn_out

  • Attention implementation of the baseline Transformer:

class TEDotProductAttention(te.pytorch.DotProductAttention):

cp_stream: torch.cuda.Stream = None

def __init__(
    self,
    config: TransformerConfig,
    layer_number: int,
    attn_mask_type: AttnMaskType,
    attention_type: str,
    attention_dropout: float = None,
):
    self.config = config
    self.te_forward_mask_type = False
    self.qkv_format: str = 'sbhd'

    if self.config.apply_query_key_layer_scaling != bool(
        int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
    ):
        raise ValueError(
            f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
            f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
            f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
            f"setting query key layer scaling via argument, so these two must match."
        )

    extra_kwargs = {}
    if _te_version >= packaging.version.Version("0.11.0"):
        extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
    elif self.config.num_query_groups != self.config.num_attention_heads:
        raise ValueError(
            f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
            f"use a newer version of Transformer Engine. "
            f"(num_query_groups ({self.config.num_query_groups}) != "
            f"num_attention_heads ({self.config.num_attention_heads}))"
        )

    if _te_version >= packaging.version.Version("0.10.0"):
        extra_kwargs["attention_type"] = attention_type
        # older versions don't need attention_type

    if _te_version > packaging.version.Version("0.12.0"):
        self.te_forward_mask_type = True

    # Only Transformer-Engine version >= 1.0.0 supports context parallelism
    if _te_version >= packaging.version.Version("1.0.0"):
        if getattr(TEDotProductAttention, "cp_stream") is None:
            TEDotProductAttention.cp_stream = torch.cuda.Stream()
        extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
        extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
            check_initialized=False
        )
        extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
    else:
        assert (
            self.config.context_parallel_size == 1
        ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

    if config.window_size is not None:
        # Check version
        assert _te_version >= packaging.version.Version(
            "1.2.0"
        ), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
        extra_kwargs['window_size'] = config.window_size

    super().__init__(
        num_attention_heads=self.config.num_attention_heads,
        kv_channels=self.config.kv_channels,
        attention_dropout=self.config.attention_dropout
        if attention_dropout is None
        else attention_dropout,
        attn_mask_type=attn_mask_type.name,
        sequence_parallel=self.config.sequence_parallel,
        tp_size=self.config.tensor_model_parallel_size,
        get_rng_state_tracker=get_cuda_rng_tracker
        if get_cuda_rng_tracker().is_initialized()
        else None,
        tp_group=get_tensor_model_parallel_group(check_initialized=False),
        layer_number=layer_number,
        **extra_kwargs,
    )

def forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_mask: Tensor,
    attn_mask_type: AttnMaskType,
    packed_seq_params: PackedSeqParams = None,
):
    packed_seq_kwargs = (
        dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
    )
    # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
    if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
        self.qkv_format = 'bshd'

    qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

    if _te_version < packaging.version.Version("1.3.0"):
        # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
        # These two arguments did not exist prior to 1.3.0
        packed_seq_kwargs.pop("max_seqlen_q", None)
        packed_seq_kwargs.pop("max_seqlen_kv", None)

    if self.config.apply_rope_fusion and qkv_format == 'bshd':
        query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
        # In PyTorch, the following two tensors are in fact the same:
        #   Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
        #   Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
        # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
        # can have same shape but different strides.
        # We unify them to the first one to pass the stride check in TE
        if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
            value = value.as_strided(value.shape, key.stride())

    if self.te_forward_mask_type:
        core_attn_out = super().forward(
            query,
            key,
            value,
            attention_mask,
            attn_mask_type=attn_mask_type.name,
            **packed_seq_kwargs,
        )
    else:
        core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs,)

    if self.config.apply_rope_fusion and qkv_format == 'bshd':
        return core_attn_out.transpose(0, 1)
    else:
        return core_attn_out

fasil-saidalavi avatar Jun 04 '25 08:06 fasil-saidalavi

@fasil-saidalavi Hi, in DIFF, q1k1 and q2k2 share the same value within a DIFF attention head. For example, if q1, q2, k1, and k2 have a head dim of 64, then v should have a head dim of 128. The output is attn(q1, k1, v) - lambda * attn(q2, k2, v). The calculation in your code seems different.
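(A minimal single-head sketch of the computation described above, using torch.nn.functional.scaled_dot_product_attention as a stand-in for the attention kernel; shapes and the lambda value are illustrative only.)

import torch
import torch.nn.functional as F

bsz, seqlen, head_dim = 2, 16, 64

# q1/q2 and k1/k2 have head_dim 64; the value is a single tensor of 2 * head_dim = 128,
# shared by both attention maps.
q1, q2 = torch.randn(2, bsz, 1, seqlen, head_dim).unbind(0)
k1, k2 = torch.randn(2, bsz, 1, seqlen, head_dim).unbind(0)
v = torch.randn(bsz, 1, seqlen, 2 * head_dim)

lambda_full = 0.8  # stands in for lambda_1 - lambda_2 + lambda_init

a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True)
a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True)
out = a1 - lambda_full * a2  # followed by RMSNorm and the (1 - lambda_init) scale

(In the MultiheadFlashDiff2 snippet quoted in the next reply, v is split in half only because flash-attention requires equal q/k and v head dims; concatenating attn(q1, k1, v1) and attn(q1, k1, v2) along the last dimension recovers attn(q1, k1, v).)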

YTianZHU avatar Jun 12 '25 06:06 YTianZHU

@YTianZHU Hi, I just took this code from the official implementation and adapted it so that diff attention works in my training framework. Before training, I verified that both implementations produce the same outputs for the same inputs; a sketch of that kind of check is included after the code below.

class MultiheadFlashDiff2(nn.Module):
"""
DiffAttn implemented with FlashAttention, for packages that do not support different qk/v dimensions
e.g., flash-attention (https://github.com/Dao-AILab/flash-attention)
"""
def __init__(
    self,
    embed_dim,
    depth,  # current layer index
    num_heads,
    num_kv_heads=None,
):
    super().__init__()
    self.embed_dim = embed_dim

    # arg num_heads set to half of baseline Transformer's num_heads
    # for e.g., to compare with a baseline Transformer with 16 heads, pass in num_heads=8 for DIFF Transformer
    self.num_heads = num_heads
    
    # arg num_kv_heads set to half of baseline Transformer's num_kv_heads if use GQA
    # for e.g., to compare with a baseline Transformer with 16 heads and 8 kv_heads, 
    # pass in num_heads=8, num_kv_heads=4 for DIFF Transformer
    # if use MHA, pass in num_kv_heads=None
    self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    self.n_rep = self.num_heads // self.num_kv_heads
    
    self.head_dim = embed_dim // num_heads // 2
    self.scaling = self.head_dim ** -0.5
    
    self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
    self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
    self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    # depth means current layer index
    self.lambda_init = lambda_init_fn(depth)
    self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
    self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

    self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)

def forward(
    self,
    x,
    rel_pos,
    attn_mask=None,
):
    bsz, tgt_len, embed_dim = x.size()
    src_len = tgt_len

    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)

    q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
    k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
    v = v.view(bsz, src_len, self.num_kv_heads, 2, self.head_dim)

    q = apply_rotary_emb(q, *rel_pos, interleaved=True)
    k = apply_rotary_emb(k, *rel_pos, interleaved=True)

    offset = src_len - tgt_len
    q = q.reshape(bsz, tgt_len, self.num_heads, 2, self.head_dim)
    k = k.reshape(bsz, src_len, self.num_kv_heads, 2, self.head_dim)
    q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
    k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
    v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

    attn11 = flash_attn_func(q1, k1, v1, causal=True)
    attn12 = flash_attn_func(q1, k1, v2, causal=True)
    attn1 = torch.cat([attn11, attn12], dim=-1)
    
    attn21 = flash_attn_func(q2, k2, v1, causal=True)
    attn22 = flash_attn_func(q2, k2, v2, causal=True)
    attn2 = torch.cat([attn21, attn22], dim=-1)
    
    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init
    attn = attn1 - lambda_full * attn2

    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    attn = attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)
    
    attn = self.out_proj(attn)
    return attn
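
(For what it's worth, a minimal sketch of the input/output check mentioned above, assuming two modules with the forward signature shown here, i.e. (x, rel_pos); the function name and tolerances are placeholders.)

import torch

@torch.no_grad()
def check_equivalence(attn_a, attn_b, x, rel_pos, atol=1e-5, rtol=1e-5):
    # Run both attention modules on identical inputs and compare outputs elementwise.
    out_a = attn_a(x, rel_pos)
    out_b = attn_b(x, rel_pos)
    print(f"max abs diff: {(out_a - out_b).abs().max().item():.3e}")
    return torch.allclose(out_a, out_b, atol=atol, rtol=rtol)

(The TEDotProductDiffAttention class earlier in the thread takes query/key/value and an attention mask rather than (x, rel_pos), so comparing it against MultiheadFlashDiff2 also requires mapping the projection weights and inputs between the two interfaces.)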

fasil-saidalavi avatar Jun 13 '25 05:06 fasil-saidalavi

@YTianZHU Hi, what precision did you use for training, especially for the 3B model with 1T tokens?

fasil-saidalavi avatar Jun 26 '25 08:06 fasil-saidalavi