StyleTTS2

feat: Improve model checkpoint loading

Open · 5Hyeons opened this issue 8 months ago

Summary

This PR fixes checkpoint loading in the second stage of training when running on a single GPU. In the second stage, every parameter name carries a 'module.' prefix, so the keys no longer match those saved by the first stage, and the first-stage parameters are not actually loaded.
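To make the mismatch concrete, here is a minimal sketch that uses plain dicts and key lists in place of PyTorch state_dicts. The helper `prefix_direction` and the example keys are hypothetical and not taken from the PR; the sketch only checks whether checkpoint keys and model keys differ by a leading 'module.'.

```python
PREFIX = "module."

def prefix_direction(ckpt_keys, model_keys):
    """Hypothetical helper: how must checkpoint keys change to match the model?

    'none'  -> keys already match
    'add'   -> the model expects a 'module.' prefix that the checkpoint lacks
    'strip' -> the checkpoint has a 'module.' prefix that the model lacks
    None    -> the keys differ in some other way
    """
    ckpt, model = set(ckpt_keys), set(model_keys)
    if ckpt == model:
        return "none"
    if {PREFIX + k for k in ckpt} == model:
        return "add"
    if all(k.startswith(PREFIX) for k in ckpt) and {k[len(PREFIX):] for k in ckpt} == model:
        return "strip"
    return None

# Hypothetical keys: a first-stage checkpoint vs. a second-stage (prefixed) model.
first_stage = ["encoder.weight", "encoder.bias"]
second_stage = ["module.encoder.weight", "module.encoder.bias"]
print(prefix_direction(first_stage, second_stage))  # add
```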

Changes

  • Checkpoint loading now handles state_dict keys that do not match the model's keys.
  • If direct loading fails, a new state_dict with the keys adjusted for the 'module.' prefix is created and loaded instead.
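The fallback described above can be sketched in plain Python as follows. This is not the PR's code: `remap_state_dict` is a hypothetical helper, and it assumes the only possible difference between checkpoint and model keys is a leading 'module.'. Values are passed through untouched, so in real use they would be tensors.

```python
from collections import OrderedDict

PREFIX = "module."

def _same(key):
    return key

def _add_prefix(key):
    return PREFIX + key

def _strip_prefix(key):
    return key[len(PREFIX):] if key.startswith(PREFIX) else key

def remap_state_dict(state_dict, model_keys):
    """Hypothetical helper: return a copy of state_dict whose keys match model_keys.

    Tries the keys unchanged first (direct load), then with 'module.' added,
    then with it stripped. Raises KeyError if none of these matches exactly.
    """
    wanted = set(model_keys)
    for rename in (_same, _add_prefix, _strip_prefix):
        remapped = OrderedDict((rename(k), v) for k, v in state_dict.items())
        if set(remapped) == wanted:
            return remapped
    raise KeyError("checkpoint keys do not match the model, even after adjusting the 'module.' prefix")

ckpt = OrderedDict([("encoder.weight", 1.0), ("encoder.bias", 0.0)])
model_keys = ["module.encoder.weight", "module.encoder.bias"]
print(list(remap_state_dict(ckpt, model_keys)))
# ['module.encoder.weight', 'module.encoder.bias']
```

In PyTorch, the remapped dict would then be passed to the module's `load_state_dict`, ideally with the default `strict=True`, so that any remaining mismatch raises an error instead of being skipped.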

Notes

  • Previously, print('%s loaded' % key) reported every module as loaded. Because load_state_dict was called with strict=False, however, parameters whose keys did not match were silently skipped rather than loaded, and no error was raised. This PR ensures that the parameters are actually loaded.
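PyTorch's `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`, which is what a log line should check instead of printing unconditionally. The sketch below reproduces that key comparison on plain key lists so it runs without PyTorch; `report_load` is a hypothetical helper, not part of StyleTTS2 or the PR.

```python
def report_load(name, ckpt_keys, model_keys):
    """Hypothetical helper: describe what a strict=False load would actually do."""
    ckpt, model = set(ckpt_keys), set(model_keys)
    loaded = ckpt & model
    missing = sorted(model - ckpt)     # model parameters left at their initial values
    unexpected = sorted(ckpt - model)  # checkpoint entries that were ignored
    if not loaded:
        return f"{name}: nothing loaded ({len(unexpected)} unexpected keys)"
    if missing or unexpected:
        return f"{name}: partially loaded ({len(loaded)} keys; missing={missing}, unexpected={unexpected})"
    return f"{name} loaded"

# Keys differ only by the 'module.' prefix, so nothing is actually loaded.
print(report_load("decoder", ["weight", "bias"], ["module.weight", "module.bias"]))
# decoder: nothing loaded (2 unexpected keys)
```

With real modules, the same check can use the value returned by `load_state_dict(..., strict=False)` and inspect its `missing_keys` and `unexpected_keys` before printing a "loaded" message.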

Related Issue

5Hyeons, Jun 13 '24 05:06