训练过慢,且梯度为NaN
您好,感谢您的工作!我把您的clex layer部分插到我的模型中,实现方式如下:
# Excerpt of the poster's Encoder showing where the CLEX layer is wired in.
# The '''省略''' placeholders mark code that was omitted from the post.
class Encoder(nn.Module):
def __init__(self, config):
'''省略'''
# Branch taken when the "algorithm" entry of config.my_info_dict equals "clex".
elif config.my_info_dict.get("algorithm",False)=="clex":
from .clex_layer import CLEXScalingRotaryEmbedding
# NOTE(review): CLEXScalingRotaryEmbedding (defined later in this post) only
# reads "max_factor", "param_factor", "time_dt" and "act" from this dict.
# NOTE(review): with max_factor=64 and time_dt=0.01 the training path can
# integrate from t=1 up to t=64 with fixed-step rk4 -- likely the main source
# of the reported slowdown; confirm by profiling.
rope_scaling={"factor": 1,"max_factor": 64,"param_factor": 1,"time_dt": 0.01,"type": "clex","act": "tanh"}
# Positional arguments map to (dim, max_position_embeddings, rope_scaling).
self.clex_layer = CLEXScalingRotaryEmbedding(config.attention_key_size, self.config.my_info_dict["train_len"], rope_scaling)
'''省略'''
def forward(
self,
'''省略'''
):
'''省略'''
if self.config.my_info_dict.get("algorithm", False)=="clex":
# clex_layer returns a (cos, sin) pair, each sliced to seqlen positions.
sinusoidal_pos = self.clex_layer(seqlen, do_train)
'''省略'''
其中CLEXScalingRotaryEmbedding类仅进行了与核心操作无关的修改:
class ODELinear(nn.Module):
    """Learned right-hand side of the frequency-scaling ODE.

    The module is called as ``f(t, x)`` by the ODE solver, where ``x`` holds
    the log inverse frequencies of the rotary embedding (``dim // 2`` values).
    It returns ``d x / d t`` as the sum of a two-layer projection of ``x`` and
    a closed-form time embedding.

    Args:
        dim: Rotary embedding dimension; the state has ``dim // 2`` entries.
            Must be greater than 2 because the time embedding divides by
            ``dim - 2``.
        factor: Width multiplier of the hidden projection (``factor * dim``).
        act: Name of the hidden activation, ``"tanh"`` or ``"silu"``.
        base: Base of the rotary frequencies.
        **kwargs: Accepted for signature compatibility and ignored.

    Raises:
        ValueError: If ``act`` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        factor,
        act,
        base=10000,
        **kwargs
    ):
        super().__init__()
        self.ode_up_proj = nn.Parameter(torch.empty(dim // 2, factor * dim))
        self.ode_down_proj = nn.Parameter(torch.empty(factor * dim, dim // 2))
        self.dim = dim
        self.base = base
        if act == "tanh":
            self.act = torch.nn.Tanh()
        elif act == "silu":
            self.act = torch.nn.SiLU()
        else:
            raise ValueError(f"act must be one of ['tanh', 'silu'], got {act}")
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the projections.

        The down projection starts at zero, so the learned term contributes
        nothing at initialisation and the output equals the time embedding.
        """
        nn.init.kaiming_uniform_(self.ode_up_proj, a=math.sqrt(5))
        nn.init.zeros_(self.ode_down_proj)

    def get_time_embedding(self, t, base=10000, device='cuda', dtype=torch.float32):
        """Return the closed-form frequency terms at time ``t``.

        Args:
            t: Scalar time (number or 0-dim tensor). Values below 1 are
                treated as "no scaling" (``alpha = 1``).
            base: Base of the rotary frequencies.
            device: Device on which the tensors are created.
            dtype: dtype of the returned tensors.

        Returns:
            ``(delta_ntk_freq, ntk_inv_freq)``: the inverse frequencies for
            the rescaled base ``base * alpha ** (dim / (dim - 2))`` with
            ``alpha = 2 * t - 1``, and the accompanying delta term.
        """
        alpha = 1 if t < 1 else 2 * t - 1
        # Built once and shared by both expressions below.
        index = torch.arange(0, self.dim, 2, dtype=torch.float32).to(device)
        ntk_base = base * alpha ** (self.dim / (self.dim - 2))
        ntk_inv_freq = 1.0 / (ntk_base ** (index / self.dim))
        delta_ntk_freq = (
            -2 * index / (self.dim - 2)
            * 1 / (base ** (index / self.dim) * (alpha ** (index / (self.dim - 2) + 1)))
        )
        return delta_ntk_freq.to(device, dtype=dtype), ntk_inv_freq.to(device, dtype=dtype)

    def forward(self, t, x: torch.Tensor):
        """Evaluate the vector field at time ``t`` and state ``x``.

        Args:
            t: 0-dim tensor holding the current integration time.
            x: Current state (log inverse frequencies), ``dim // 2`` values.

        Returns:
            Tensor with the same shape as ``x``: the time derivative of the state.
        """
        device = x.device
        # Use the base this module was configured with; previously the
        # default of 10000 was always used, ignoring ``self.base``.
        delta_time, time = self.get_time_embedding(
            t.to(device), base=self.base, device=device, dtype=x.dtype
        )
        x = x + torch.log(time)
        time_embed = delta_time / time
        # Parameters are cast to float32 so the ODE is integrated in full
        # precision even when the module is held in a lower precision.
        hidden = self.act(x @ self.ode_up_proj.float())
        delta_inv_freq = hidden @ self.ode_down_proj.float()
        return delta_inv_freq + time_embed
class CLEXScalingRotaryEmbedding(nn.Module):
    """Rotary position embedding whose inverse frequencies follow a learned ODE.

    The log inverse frequencies are integrated from ``t = 1`` to a scaling
    factor ``t`` with ``odeint`` using ``ODELinear`` as the vector field.
    During training a random integer ``t`` in ``[1, max_factor]`` and random
    positions are sampled on every call; during inference the frequencies for
    all integer factors are solved once and cached.

    Args:
        dim: Rotary embedding dimension.
        max_position_embeddings: Training sequence length; one unit of ``t``
            corresponds to this many positions.
        rope_scaling: Required mapping (despite the ``None`` default) with the
            keys ``"max_factor"``, ``"param_factor"``, ``"act"`` and ``"time_dt"``.
        base: Base of the rotary frequencies.
        device: Device on which the initial ``inv_freq`` buffer is created.

    NOTE(review): the solver is configured as fixed-step ``rk4`` with step
    ``time_dt``, so a training call integrating up to ``t`` presumably costs on
    the order of ``(t - 1) / time_dt`` solver steps, all of which are kept for
    backward. Large ``max_factor`` with small ``time_dt`` is therefore expensive;
    confirm against the torchdiffeq documentation and by profiling.
    """

    def __init__(self, dim, max_position_embeddings=2048, rope_scaling=None, base=10000, device=None) -> None:
        super().__init__()
        self.max_t = rope_scaling["max_factor"]
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self.proj_func = ODELinear(dim, rope_scaling["param_factor"], rope_scaling["act"], base)
        # Inference-only caches: solved frequencies per integer factor, the
        # stacked (cos, sin) table, and the factor that table was built for.
        self.rope_cached = None
        self.max_t_cached = 0
        self.freq_cached = None
        self.time_dt = rope_scaling["time_dt"]
        self.ode_args = {
            "method": "rk4",
            "options": {"step_size": self.time_dt},
        }

    def sample_random_times(self, max_t, device):
        """Return a 1-element long tensor drawn uniformly from ``[1, max_t)``."""
        return torch.randint(1, max_t, (1,), dtype=torch.long, device=device)

    def get_random_position_ids(self, n=2048, max=8192):
        """Return ``n`` distinct position ids from ``[0, max)`` in ascending order."""
        positions = torch.randperm(max)[:n].sort().values
        return positions

    def _solve_inv_freq(self, time_grid, device):
        """Integrate the log inverse frequencies over ``time_grid``.

        Returns one row of inverse frequencies per entry of ``time_grid``.
        """
        solution = odeint(
            self.proj_func,
            torch.log(self.inv_freq.to(device, dtype=torch.float32)),
            time_grid,
            **self.ode_args
        )
        return torch.exp(solution)

    def get_continuous_freq(self, time_grid, ex_positions, device):
        """Solve the ODE over ``time_grid``.

        Returns:
            For a two-point grid ``[t0, t1]``: the rotary angles at ``t1`` for
            ``ex_positions``, duplicated along the last dimension. For any
            other grid: the inverse frequencies at every grid point.
        """
        scale_inv_freq = self._solve_inv_freq(time_grid, device)
        if time_grid.size(0) != 2:
            return scale_inv_freq
        freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq[1])
        return torch.cat((freqs, freqs), dim=-1)

    def forward(self, seq_len, do_train=False):
        """Return ``(cos, sin)`` tables for ``seq_len`` positions.

        Args:
            seq_len: Number of positions required.
            do_train: When true, sample a random scaling factor and random
                positions; otherwise use the cached inference tables.

        Returns:
            Two tensors with a leading singleton dimension, each sliced to at
            most ``seq_len`` positions along the position axis.
        """
        device = self.proj_func.ode_up_proj.device
        scale_factor = seq_len // self.max_position_embeddings
        if do_train:
            # A training call precedes a parameter update, so tables solved
            # earlier for inference would be stale afterwards.
            self.freq_cached = None
            self.rope_cached = None
            self.max_t_cached = 0
            # Convert to a Python int once; the value drives host-side
            # arithmetic and control flow below.
            t_val = int(self.sample_random_times(self.max_t + 1, device)[0])
            if scale_factor < 1.0:
                scale_factor = 1
            sampled_position_ids = self.get_random_position_ids(
                n=seq_len - 2, max=seq_len * t_val - 2
            ).float()
            # First and last positions are pinned; the interior is random.
            ex_positions = torch.cat([
                torch.tensor([0]),
                (sampled_position_ids + 1) / scale_factor,
                torch.tensor([seq_len * t_val // scale_factor - 1]),
            ]).to(device, dtype=torch.float32)
        else:
            t_val = scale_factor if seq_len % self.max_position_embeddings == 0.0 else scale_factor + 1
            t_val = t_val if t_val <= self.max_t else self.max_t
            # seq_len == 0 would give t_val == 0 and index the cache at -1.
            t_val = t_val if t_val >= 1 else 1
            ex_positions = torch.arange(
                0, self.max_position_embeddings * t_val, dtype=torch.float32
            ).to(device)

        if t_val == 1.0:
            # No scaling needed: plain rotary frequencies.
            scale_inv_freq = self.inv_freq.to(device)
            freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
            embed = torch.cat((freqs, freqs), dim=-1)
            cos, sin = embed.cos(), embed.sin()
        elif do_train:
            time_grid = torch.tensor([1.0, t_val]).float().to(device)
            embed = self.get_continuous_freq(time_grid, ex_positions, device)
            cos, sin = embed.cos(), embed.sin()
        else:
            if self.freq_cached is None:
                time_grid = torch.arange(1.0, self.max_t + 1.0, dtype=torch.float32).to(device)
                # The cache is reused across calls, so it must not hold an
                # autograd graph. The solver helper is called directly so a
                # two-point grid (max_factor == 2) still yields the table.
                with torch.no_grad():
                    self.freq_cached = self._solve_inv_freq(time_grid, device)
            if t_val != self.max_t_cached:
                scale_inv_freq = self.freq_cached[int(t_val - 1.0)]
                freqs = torch.outer(ex_positions.float().squeeze(), scale_inv_freq)
                embed = torch.cat((freqs, freqs), dim=-1)
                self.rope_cached = torch.cat(
                    (embed.cos()[None, :, :], embed.sin()[None, :, :]), dim=0
                )
                self.max_t_cached = t_val
            cos, sin = self.rope_cached
        return cos[None, :seq_len], sin[None, :seq_len]
但是训练速度十分缓慢。按照您论文中的设置,在模型之上微调一个epoch:在我的模型中加入如上实现后,微调一个epoch需要两个小时,并且出现了loss和梯度为NaN的情况;而去掉CLEX后,训练时长表现正常,约三分钟一个epoch。我的实验使用4个Tesla V100-SXM2-32GB GPU。
感谢您的回复!
Hello,
I also experimented with integrating CLEX into my model and similarly observed a significant reduction in training speed, as well as `inf` gradient norms. Increasing the timestep from 0.01 to 0.1 somewhat alleviated the training slowdown, but unfortunately did not resolve the issue of infinite gradients.
Have you managed to solve the issue?
Hello,
I also experimented with integrating CLEX into my model and observed similarly a significant reduction in training speed and
`inf` gradient norms. Increasing the timestep from `0.01` to `0.1` somewhat alleviated the training slowdown, but unfortunately did not resolve the issue of infinite gradients. Have you managed to solve the issue?
Sorry I haven't solved this problem yet🥺