In training Stage1 after 49th epoch getting RuntimeError: you can only change requires_grad flags of leaf variables, g_loss.requires_grad = True
I am getting the mentioned error in this part of the code: if epoch >= TMA_epoch: # start TMA training loss_s2s = 0 for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths): loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length]) loss_s2s /= texts.size(0)
loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
print(f'the shape of both wav and y_rec respectively {wav.shape} and {y_rec.shape}')
loss_slm = wl(wav.detach(), y_rec.squeeze(1)).mean()
g_loss = loss_params.lambda_mel * loss_mel + \
loss_params.lambda_mono * loss_mono + \
loss_params.lambda_s2s * loss_s2s + \
loss_params.lambda_gen * loss_gen_all + \
loss_params.lambda_slm * loss_slm
print(f'Generator loss is {g_loss}')
running_loss += accelerator.gather(loss_mel).mean().item()
#print(f"g-loss is {type(g_loss)}")
optimizer.zero_grad()
g_loss.requires_grad = True
g_loss.backward()
#accelerator.backward(g_loss)
optimizer.step()
# g_loss.requires_grad = True
# g_loss.backward()
I have the same issue, but for this code snippet:
d_loss = self._dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
Getting into details, it's the error in the forward method, in WavLMLoss class:
def forward(self, wav, y_rec):
with torch.no_grad():
wav_16 = self.resample(wav)
wav_embeddings = self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states
y_rec_16 = self.resample(y_rec)
y_rec_embeddings = self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True)
y_rec_embeddings = y_rec_embeddings.hidden_states
floss = 0
for er, eg in zip(wav_embeddings, y_rec_embeddings):
floss += torch.mean(torch.abs(er - eg))
return floss.mean()
self.wavlm(input_values=y_rec_16.squeeze(), output_hidden_states=True) is giving me the exact same error and I don't know why.
What is your dependencies versions for this project?
Found the solution. you need to write self.wavlm.eval() in the start of forward method WavLMLoss class in losses module. Worked for me.