W-Net-PyTorch
W-Net-PyTorch copied to clipboard
推断
你在推断的时候out = wnet(src_tensor, target_tensor),输入了target的tensor,这是不是影响结果了。 正常推断不是只有src么
你在推断的时候out = wnet(src_tensor, target_tensor),输入了target的tensor,这是不是影响结果了。 正常推断不是只有src么
推断的时候,需要提供两个信息,一个是字体,一个是字,所以是两个输入
你的target_tensor是
target_img = generate_img(target_word, target_font_file, font_size=60)
target_tensor = totensor(target_img).unsqueeze(0)
这么得到的,送入到网络forward
def forward(self, lx, rx):
rout = self.right(rx)
...
de_0 = self.deconv1(torch.cat([lout_5, rout_5], dim=1))
...
return de...
这里网络就加载了tgt_tensor
你的target_tensor是
target_img = generate_img(target_word, target_font_file, font_size=60) target_tensor = totensor(target_img).unsqueeze(0)
这么得到的,送入到网络forward
def forward(self, lx, rx): rout = self.right(rx) ... de_0 = self.deconv1(torch.cat([lout_5, rout_5], dim=1)) ... return de...
这里网络就加载了tgt_tensor
对,这个地方确实写错了,谢谢指正