AdaSpeech
maybe there is a bug in fastspeech.py
Hello rishikksh20, thanks for your contribution! I found a problem when training with this code. In fastspeech.py, line 415:
if avg_mel is not None:
    avg_mel = avg_mel.unsqueeze(0)
    # inference
    before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, avg_mel=avg_mel,
                                                 is_inference=True,
                                                 phn_level_predictor=phn_level_predictor)  # (L, odim)
else:
    before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, is_inference=True,
                                                 phn_level_predictor=phn_level_predictor)  # (L, odim)
# inference
_, outs, _, _, _ = self._forward(xs, ilens, is_inference=True)  # (L, odim)
return outs[0]
I think the last inference call is not needed: outs has already been computed in the if/else above, and this extra call does not even pass ys. It can simply be commented out, like below:
if avg_mel is not None:
    avg_mel = avg_mel.unsqueeze(0)
    # inference
    before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, avg_mel=avg_mel,
                                                 is_inference=True,
                                                 phn_level_predictor=phn_level_predictor)  # (L, odim)
else:
    before_outs, outs, d_outs, _ = self._forward(xs, ilens=ilens, ys=ref_mel, is_inference=True,
                                                 phn_level_predictor=phn_level_predictor)  # (L, odim)
# inference
# _, outs, _, _, _ = self._forward(xs, ilens, is_inference=True)  # (L, odim)
return outs[0]
Also, in fastspeech.py, line 182:
def _forward(
    self,
    xs: torch.Tensor,
    ilens: torch.Tensor,
    ys: torch.Tensor = None,
    olens: torch.Tensor = None,
    ds: torch.Tensor = None,
    es: torch.Tensor = None,
    ps: torch.Tensor = None,
    is_inference: bool = False,
    phn_level_predictor: bool = False,
    avg_mel: torch.Tensor = None,
) -> Sequence[torch.Tensor]:
    # forward encoder
    x_masks = self._source_mask(
        ilens
    )  # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])
    hs, _ = self.encoder(
        xs, x_masks
    )  # (B, Tmax, adim) -> torch.Size([32, 121, 256])
    ## AdaSpeech Specific ##
    uttr = self.utterance_encoder(ys.transpose(1, 2)).transpose(1, 2)
    hs = hs + uttr.repeat(1, hs.size(1), 1)
The unconditional call to self.utterance_encoder(ys.transpose(1, 2)) means ys should not be None, but then line 24 of evaluation.py,
_, after_outs, d_outs, e_outs, p_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel]
will throw an error, presumably because the positional arguments no longer match the new signature: out_length_ is bound to ys, so the utterance encoder receives a lengths tensor instead of a reference mel.
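For reference, here is a minimal sketch (not the repo's actual code) of one way to avoid the crash, assuming it is acceptable to skip the utterance-level conditioning when no reference mel is available; the alternative would be to pass a proper reference mel as ys from evaluation.py. It only changes the AdaSpeech-specific lines quoted above:

# Sketch only: guard the utterance conditioning so ys=None does not crash.
## AdaSpeech Specific ##
if ys is not None:
    # apply the utterance embedding only when a reference mel is given
    uttr = self.utterance_encoder(ys.transpose(1, 2)).transpose(1, 2)
    hs = hs + uttr.repeat(1, hs.size(1), 1)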