Differential Transformer loss spikes while training.
I have trained a 1.3B model using both the Differential Transformer and the standard Transformer. The Differential Transformer variant showed a slight improvement in LLM evaluation scores, and its loss was consistently lower. However, when I ran the same comparison at 7B scale, training showed loss spikes, and the loss started increasing after a certain point. I also noticed that the gradient norm increased at the same steps where the loss spiked. Regarding implementation details, I only used the differential attention part from this repository; I did not include the SwiGLU layer and used a standard FFN layer instead. @YTianZHU
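(For illustration only: by "standard FFN" I mean the usual two-layer MLP rather than the gated SwiGLU block; the sketch below is schematic, with a placeholder hidden size and activation rather than the exact layer from my training config.)

import torch.nn as nn

class StandardFFN(nn.Module):
    # Illustrative plain two-layer MLP used in place of SwiGLU.
    # ffn_dim and the GELU activation are placeholders, not the exact training config.
    def __init__(self, embed_dim, ffn_dim):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, ffn_dim, bias=False)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(ffn_dim, embed_dim, bias=False)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))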
Hi @fasil-saidalavi, sorry for the late response. Do you use the same FFN implementation for both Diff and the baseline Transformer? Could you post a code snippet of the attention implementation for both Diff and the baseline Transformer?
Hi @YTianZHU,
- I used the same FFN implementation for both Diff and the baseline Transformer.
- Attention implementation of the Diff Transformer:
class TEDotProductDiffAttention(te.pytorch.DotProductAttention):
    cp_stream: torch.cuda.Stream = None

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: float = None,
    ):
        self.config = config
        self.te_forward_mask_type = False
        self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if _te_version >= packaging.version.Version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
if _te_version > packaging.version.Version("0.12.0"):
self.te_forward_mask_type = True
# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if getattr(TEDotProductDiffAttention, "cp_stream") is None:
TEDotProductDiffAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductDiffAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version(
"1.2.0"
), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
extra_kwargs['window_size'] = config.window_size
self.diff_kv_channels = divide(self.config.kv_channels, 2)
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.diff_kv_channels,
attention_dropout=self.config.attention_dropout
if attention_dropout is None
else attention_dropout,
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
# torch.manual_seed(40)
self.lambda_init = lambda_init_fn(layer_number)
self.lambda_q1 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.diff_kv_channels, dtype=torch.float32).normal_(mean=0,std=0.1))
# torch.manual_seed(42)
self.subln = RMSNorm(2 * self.diff_kv_channels, eps=1e-5, elementwise_affine=True)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
packed_seq_params: PackedSeqParams = None,
):
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if _te_version < packaging.version.Version("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
src_len, tgt_len, bsz, q_num_heads , kv_num_heads = key.size(0), query.size(0), query.size(1), query.size(2), key.size(2)
# query = query.view(tgt_len, bsz, 2 * q_num_heads, self.diff_kv_channels)
query = query.reshape(tgt_len, bsz, q_num_heads, 2, self.diff_kv_channels)
# key = key.view(tgt_len, bsz, 2 * kv_num_heads, self.diff_kv_channels)
key = key.reshape(src_len, bsz, kv_num_heads, 2, self.diff_kv_channels)
# value = value.view(tgt_len, bsz, 2 * kv_num_heads, self.diff_kv_channels)
value = value.reshape(src_len, bsz, kv_num_heads, 2, self.diff_kv_channels)
q1, q2 = query[:, :, :, 0], query[:, :, :, 1]
k1, k2 = key[:, :, :, 0], key[:, :, :, 1]
v1, v2 = value[:, :, :, 0], value[:, :, :, 1]
if self.te_forward_mask_type:
core_attn_out11 = super().forward(
q1,
k1,
v1,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out11 = super().forward(q1, k1, v1, attention_mask, **packed_seq_kwargs,)
core_attn_out11 = core_attn_out11.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
if self.te_forward_mask_type:
core_attn_out12 = super().forward(
q1,
k1,
v2,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out12 = super().forward(q1, k1, v2, attention_mask, **packed_seq_kwargs,)
core_attn_out12 = core_attn_out12.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
attn1 = torch.cat([core_attn_out11, core_attn_out12], dim=-1)
if self.te_forward_mask_type:
core_attn_out21 = super().forward(
q2,
k2,
v1,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out21 = super().forward(q2, k2, v1, attention_mask, **packed_seq_kwargs,)
core_attn_out21 = core_attn_out21.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
if self.te_forward_mask_type:
core_attn_out22 = super().forward(
q2,
k2,
v2,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out22 = super().forward(q2, k2, v2, attention_mask, **packed_seq_kwargs,)
core_attn_out22 = core_attn_out22.view(tgt_len, bsz, kv_num_heads , self.diff_kv_channels)
attn2 = torch.cat([core_attn_out21, core_attn_out22], dim=-1)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(query)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(query)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
# lambda_full = 1
core_attn_out = attn1 - lambda_full * attn2
core_attn_out = self.subln(core_attn_out)
core_attn_out = core_attn_out * (1 - self.lambda_init)
core_attn_out = core_attn_out.view(tgt_len, bsz , q_num_heads * self.config.kv_channels)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
- Attention implementation of the baseline Transformer:
class TEDotProductAttention(te.pytorch.DotProductAttention):
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if _te_version >= packaging.version.Version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type
if _te_version > packaging.version.Version("0.12.0"):
self.te_forward_mask_type = True
# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version(
"1.2.0"
), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
extra_kwargs['window_size'] = config.window_size
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=self.config.attention_dropout
if attention_dropout is None
else attention_dropout,
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
packed_seq_params: PackedSeqParams = None,
):
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if _te_version < packaging.version.Version("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
if self.te_forward_mask_type:
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs,)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
@fasil-saidalavi Hi, in DIFF, q1k1 and q2k2 share the same value within a DIFF attention head. For example, if q1, q2, k1, k2 have a head dim of 64, then v should have a head dim of 128. The output is attn(q1, k1, v) - lambda * attn(q2, k2, v). It seems the calculation in your code is different.
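For reference, a minimal sketch of the computation described above, assuming PyTorch's scaled_dot_product_attention (which allows the value head dim to differ from the query/key head dim); RoPE, GQA, dropout, and the RMSNorm / (1 - lambda_init) scaling are omitted:

import torch.nn.functional as F

def diff_attn_reference(q1, q2, k1, k2, v, lambda_full):
    # q1, q2, k1, k2: (bsz, num_heads, seq_len, d), e.g. d = 64
    # v:              (bsz, num_heads, seq_len, 2 * d), e.g. 128, shared by both attention maps
    a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True)
    a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True)
    # differential attention: subtract the second map scaled by lambda
    return a1 - lambda_full * a2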
@YTianZHU Hi, I just took this code from the official implementation and made changes so that the diff attention works in my training framework. Before training, I verified that both implementations produce the same outputs for the same inputs.
class MultiheadFlashDiff2(nn.Module):
    """
    DiffAttn implemented with FlashAttention, for packages that does not support different qk/v dimensions
    e.g., flash-attention (https://github.com/Dao-AILab/flash-attention)
    """
    def __init__(
        self,
        embed_dim,
        depth,  # current layer index
        num_heads,
        num_kv_heads=None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
# arg num_heads set to half of baseline Transformer's num_heads
# for e.g., to compare with a baseline Transformer with 16 heads, pass in num_heads=8 for DIFF Transformer
self.num_heads = num_heads
# arg num_kv_heads set to half of baseline Transformer's num_kv_heads if use GQA
# for e.g., to compare with a baseline Transformer with 16 heads and 8 kv_heads,
# pass in num_heads=8, num_kv_heads=4 for DIFF Transformer
# if use MHA, pass in num_kv_heads=None
self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
self.n_rep = self.num_heads // self.num_kv_heads
self.head_dim = embed_dim // num_heads // 2
self.scaling = self.head_dim ** -0.5
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
# depth means current layer index
self.lambda_init = lambda_init_fn(depth)
self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)
def forward(
self,
x,
rel_pos,
attn_mask=None,
):
bsz, tgt_len, embed_dim = x.size()
src_len = tgt_len
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
v = v.view(bsz, src_len, self.num_kv_heads, 2, self.head_dim)
q = apply_rotary_emb(q, *rel_pos, interleaved=True)
k = apply_rotary_emb(k, *rel_pos, interleaved=True)
offset = src_len - tgt_len
q = q.reshape(bsz, tgt_len, self.num_heads, 2, self.head_dim)
k = k.reshape(bsz, src_len, self.num_kv_heads, 2, self.head_dim)
q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
v1, v2 = v[:, :, :, 0], v[:, :, :, 1]
attn11 = flash_attn_func(q1, k1, v1, causal=True)
attn12 = flash_attn_func(q1, k1, v2, causal=True)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(q2, k2, v1, causal=True)
attn22 = flash_attn_func(q2, k2, v2, causal=True)
attn2 = torch.cat([attn21, attn22], dim=-1)
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
attn = attn1 - lambda_full * attn2
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)
attn = self.out_proj(attn)
return attn
@YTianZHU Hi, what precision did you use for training, especially for the 3B model with 1T tokens?