AttMoE code issue: the modified MoE model code is not provided

Which MoE implementation is used here: the code at https://github.com/XiuzeZhou/mixture-of-experts, i.e. https://github.com/davidmrau/mixture-of-experts? In AttMoE-NASA.ipynb and AttMoE-CALCE.ipynb you define the model below, but the MoE arguments are not the parameters of that original library, so it is unclear which parts were modified afterwards. Moreover, your repository provides neither the modified code nor the source of the original MoE code.
import torch.nn as nn                  # needed for nn.Module / nn.Dropout / nn.Linear
from mixture_of_experts import MoE

class AttMoE(nn.Module):
    def __init__(self, feature_size=16, hidden_dim=8, num_layers=1, nhead=4, dropout=0., dropout_rate=0.2,
                 num_experts=8, device='cpu'):
        super(AttMoE, self).__init__()
        self.feature_size, self.hidden_dim = feature_size, hidden_dim
        self.dropout = nn.Dropout(dropout_rate)
        # Attention is defined elsewhere in the notebooks
        self.cell = Attention(feature_size=feature_size, hidden_dim=hidden_dim, nhead=nhead, dropout=dropout)
        self.linear = nn.Linear(hidden_dim, 1)
        experts = nn.Linear(hidden_dim, hidden_dim)
        # create moe layers based on the number of experts
        self.moe = MoE(dim=hidden_dim, num_experts=num_experts, experts=experts)
        self.moe = self.moe.to(device)

    def forward(self, x):
        out = self.dropout(x)
        out = self.cell(x)                        # cell output shape: (batch_size, seq_len=1, feature_size)
        out, _ = self.moe(out)                    # MoE returns (output, aux_loss)
        out = out.reshape(-1, self.hidden_dim)    # (batch_size, hidden_dim)
        out = self.linear(out)                    # shape: (batch_size, 1)
        return out
GitHub repository of the mixture-of-experts library: https://github.com/lucidrains/mixture-of-experts
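For comparison, the keyword arguments in the snippet above (dim, num_experts, experts) match the lucidrains package linked here rather than davidmrau's implementation. Below is a minimal, self-contained sketch of how that package is typically called; the hidden_dim and num_experts values simply mirror the defaults in the AttMoE definition above, and this is only an illustration of the library's interface under those assumptions, not code taken from the AttMoE repository.

import torch
import torch.nn as nn
from mixture_of_experts import MoE    # pip install mixture-of-experts (lucidrains)

hidden_dim, num_experts = 8, 8        # same defaults as the AttMoE definition above

# The notebook passes a single nn.Linear as the `experts` module; if `experts`
# is omitted, the package builds its own feed-forward expert networks instead.
experts = nn.Linear(hidden_dim, hidden_dim)
moe = MoE(dim=hidden_dim, num_experts=num_experts, experts=experts)

x = torch.randn(4, 1, hidden_dim)     # (batch_size, seq_len, dim)
out, aux_loss = moe(x)                # forward returns (output, auxiliary load-balancing loss)
print(out.shape, aux_loss.item())     # output keeps the input shape: torch.Size([4, 1, 8])

If this reading is right, the change relative to davidmrau's version is mainly the switch to a different MoE package plus the use of a single linear layer as the expert module; whether anything inside the package itself was modified would still need to be confirmed by the authors.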