Confusion about TimeDistributed
Thanks! In the fastText model there is a TimeDistributed layer right after the Embedding layer, and I'm not sure what it does. I also noticed this step appears in all three of your models. I'm a PyTorch beginner, so could you explain this point in detail? (I also don't quite follow how the dimensions change here.) Much appreciated!
self.tdfc1 = nn.Linear(D, 512)
self.td1 = TimeDistributed(self.tdfc1)
self.tdbn1 = nn.BatchNorm2d(1)
self.tdfc2 = nn.Linear(D, 512)
self.td2 = TimeDistributed(self.tdfc2)
self.tdbn2 = nn.BatchNorm2d(1)
self.fc1 = nn.Linear(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
self.fc2 = nn.Linear(512, C)
....

def forward(self, x, y):
    if self.opt['use_char_word']:
        x = self.embed_char(x.long())
        y = self.embed_word(y.long())
    elif self.opt['use_word_char']:
        x = self.embed_word(x.long())
        y = self.embed_char(y.long())
    else:
        x = self.embed(x.long())
        y = self.embed(y.long())
    if self.opt['static']:
        x = x.detach()
    x = F.relu(self.tdbn1(self.td1(x).unsqueeze(1))).squeeze(1)
    if self.opt['static']:
        y = y.detach()
    y = F.relu(self.tdbn2(self.td2(y).unsqueeze(1))).squeeze(1)
    x = x.mean(1).squeeze(1)
    y = y.mean(1).squeeze(1)
    x = torch.cat((x, y), 1)
    x = F.relu(self.bn1(self.fc1(x)))
    logit = self.fc2(x)
    return logit
......
class TimeDistributed(nn.Module):
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self.module = module

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        n, t = x.size(0), x.size(1)
        # merge batch and seq dimensions
        x_reshape = x.contiguous().view(t * n, x.size(2))
        y = self.module(x_reshape)
        # reshape back to (batch, seq, features)
        y = y.contiguous().view(n, t, y.size()[1])
        return y
You can think of it as a transform applied to the original embedding: a fully connected layer applied independently at every timestep, which tends to improve model performance. The embedding dimension here starts at 256 and is projected to another size, for example 512 in this model.
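A minimal sketch of the shape changes, assuming a batch of 32 sentences of 50 tokens and an embedding dim of 256 (these numbers are illustrative; only the 512 comes from the snippet above):

import torch
import torch.nn as nn

class TimeDistributed(nn.Module):
    """Applies `module` to every timestep by folding (batch, seq) into one dim."""
    def __init__(self, module):
        super(TimeDistributed, self).__init__()
        self.module = module

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.module(x)
        n, t = x.size(0), x.size(1)
        x_reshape = x.contiguous().view(n * t, x.size(2))   # (N*T, D)
        y = self.module(x_reshape)                           # (N*T, 512)
        return y.contiguous().view(n, t, y.size(-1))         # (N, T, 512)

emb = torch.randn(32, 50, 256)                # (N=32, T=50, D=256), output of the Embedding layer
td = TimeDistributed(nn.Linear(256, 512))     # same idea as self.td1 = TimeDistributed(self.tdfc1)
out = td(emb)
print(out.shape)                              # torch.Size([32, 50, 512])
# in the model: .unsqueeze(1) -> (32, 1, 50, 512) so BatchNorm2d(1) can be applied,
# .squeeze(1) -> (32, 50, 512), then .mean(1) -> (32, 512) before torch.cat

Note that nn.Linear already operates on the last dimension of a 3-D input, so the wrapper mainly mirrors Keras's TimeDistributed and makes the per-timestep application explicit; the only dimension it changes is the last one (256 -> 512).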