Stage 2 Training Fails with NaN Loss on Single GPU Due to Inconsistent Checkpoint Keys
Description
When trying to start stage 2 training after completing stage 1 on a single A100 80GB GPU with my Korean dataset, I encountered an issue where g_loss becomes NaN.
Upon investigation, it was found that the y_rec_gt_pred output from model.decoder was NaN:
ipdb> y_rec_gt_pred
tensor([[[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0')
The s variable, an input to the decoder, had abnormally large values:
ipdb> p s
tensor([[ 4.1757e+18, -2.2990e+18, -4.4499e+18, -4.2868e+18, 2.8710e+18,
1.4094e+17, -2.4173e+18, -5.7211e+18, 1.2887e+18, 1.2334e+18,
...
-3.6897e+18, 2.0664e+17, -3.9657e+18, 2.1473e+18, 2.9162e+18,
-2.3997e+18, 4.6772e+18, 3.3755e+17, -1.0300e+17, -1.7092e+18,
2.6885e+18, -3.8825e+18, -2.4909e+18]], device='cuda:0',
grad_fn=<GatherBackward>)
The s input is derived from the style encoder:
s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
The inputs st and gt were within normal ranges, and I found that the style encoder's weights were not properly loaded from the checkpoint:
print(type(model.style_encoder.shared[0]))
print('params weight:')
print(params['style_encoder']['shared.0.weight_orig'][0])
print('model weight:')
print(model.style_encoder.shared[0].weight[0], model.style_encoder.shared[0].weight[0].device)
print('params bias:')
print(params['style_encoder']['shared.0.bias'][0])
print('model bias:')
print(model.style_encoder.shared[0].bias[0], model.style_encoder.shared[0].bias[0].device)
<class 'torch.nn.modules.conv.Conv2d'>
params weight:
tensor([[[-0.0216, -0.2234, -0.3011],
[ 0.3677, 0.4262, 0.0299],
[ 0.3779, 0.2360, 0.1636]]])
model weight:
tensor([[[ 0.1406, -0.1316, -0.1859],
[-0.2805, -0.0880, -0.2270],
[ 0.1440, -0.0139, 0.1625]]]) cpu
params bias:
tensor(0.1457)
model bias:
tensor(-0.0927, device='cuda:0', grad_fn=<SelectBackward0>) cuda:0
Cause
The issue arises from inconsistent key names when loading checkpoints between the first and second stages. In the second stage, the MyDataParallel class is used, which prefixes all model keys with 'module.'. However, when training on a single GPU, the first stage does not apply this prefix when saving checkpoints (see https://github.com/yl4579/StyleTTS2/issues/120).
This inconsistency prevents the proper loading of the model parameters, leading to NaN values in the loss calculation.
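A minimal sketch of the mismatch (a standalone example, not code from the repository): wrapping a module in nn.DataParallel changes the names under which its parameters appear in the state_dict.

import torch.nn as nn

net = nn.Linear(4, 4)

# A plain (unwrapped) module exposes its parameters as 'weight', 'bias', ...
print(list(net.state_dict().keys()))        # ['weight', 'bias']

# Wrapping it in DataParallel (or MyDataParallel) prefixes every key with 'module.'
wrapped = nn.DataParallel(net)
print(list(wrapped.state_dict().keys()))    # ['module.weight', 'module.bias']

# Loading an un-prefixed first-stage checkpoint into the wrapped second-stage
# model with strict=True therefore fails on missing/unexpected keys, and the
# affected modules keep whatever weights they were initialized with.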
Solution
To address this, I’ve updated the load_checkpoint function to handle cases where the checkpoint keys do not match the model keys, by creating a new state_dict with matching keys if direct loading fails.
Updated load_checkpoint Function
import torch

def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
    state = torch.load(path, map_location='cpu')
    params = state['net']
    for key in model:
        if key in params and key not in ignore_modules:
            print('%s loaded' % key)
            try:
                # First try a direct, strict load.
                model[key].load_state_dict(params[key], strict=True)
            except:
                # Key names differ (e.g. missing or extra 'module.' prefix):
                # rebuild the state_dict by pairing parameters positionally.
                from collections import OrderedDict
                state_dict = params[key]
                new_state_dict = OrderedDict()
                print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
                for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
                    new_state_dict[k_m] = v_c
                model[key].load_state_dict(new_state_dict, strict=True)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"]
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters
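For reference, a hypothetical call site for the updated loader; the checkpoint path below is a placeholder and ignore_modules is left empty, whereas the actual training script passes whatever its config specifies:

# Hypothetical usage when starting stage 2 from a stage 1 checkpoint.
model, optimizer, start_epoch, iters = load_checkpoint(
    model,
    optimizer,
    'Models/Korean/first_stage.pth',   # placeholder path
    load_only_params=True,
    ignore_modules=[])                 # e.g. modules introduced only in stage 2

Note that the positional fallback assumes the checkpoint and the model enumerate their parameters in the same order, which holds when the only difference between the two is the 'module.' prefix.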
Additionally, I have submitted a PR to address this issue: https://github.com/yl4579/StyleTTS2/pull/253