[HANDS-ON BUG] Unit 8 part 2, sample-factory 2.1.1 not compatible with pytorch 2.6+
Describe the bug
Sample-factory calls torch.load() without the weights_only argument:
checkpoint_dict = torch.load(latest_checkpoint, map_location=device)
PyTorch 2.6+ changed the default of weights_only from False to True, so this call now has to pass weights_only=False explicitly to load these checkpoints.
How to Fix
pip install the latest sample-factory:
!pip install sample-factory==2.1.3
PR: https://github.com/huggingface/deep-rl-class/pull/627
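To check whether a notebook environment is likely to hit this bug, a small stdlib-only sketch like the following may help. The helper names (parse_version, installed_version, likely_affected) are hypothetical, and the 2.1.3 cutoff is an assumption read off this thread (2.1.1 fails, 2.1.3 is the suggested fix), not something confirmed by the sample-factory changelog.

```python
from importlib import metadata


def parse_version(text):
    """Return the leading numeric release parts, e.g. '2.6.0+cu124' -> (2, 6, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def likely_affected(torch_version, sample_factory_version):
    """True when torch is 2.6+ and sample-factory predates 2.1.3 (assumed cutoff)."""
    return (parse_version(torch_version) >= (2, 6)
            and parse_version(sample_factory_version) < (2, 1, 3))


if __name__ == "__main__":
    torch_v = installed_version("torch")
    sf_v = installed_version("sample-factory")
    print("torch:", torch_v, "| sample-factory:", sf_v)
    if torch_v and sf_v:
        print("likely affected:", likely_affected(torch_v, sf_v))
```

In a notebook, a freshly pip-installed version is usually only picked up after restarting the runtime, so it is worth running this check after the restart.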
I fixed this by running this code:
import torch.serialization
import _codecs
import numpy as np
torch.serialization.add_safe_globals([np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType, _codecs.encode])
before the cell with:
from sample_factory.enjoy import enjoy
cfg = parse_vizdoom_cfg....
It worked, thank you!
🐛 Issue: Can't load Sample Factory checkpoints in PyTorch 2.6+
Problem:
PyTorch 2.6 changed the default of torch.load() to weights_only=True, which breaks loading of older sample-factory checkpoints that were saved with full pickle. Under weights_only=True, loading fails unless the non-tensor types stored in the checkpoint are allowlisted.
Error:
_pickle.UnpicklingError: Can only build Tensor, Parameter, OrderedDict or types allowlisted... but got <class 'numpy.dtypes.Float64DType'>
Fix:
To load old checkpoints under the new weights_only=True default, allowlist the required types before torch.load() is called:
import torch
import numpy as np
torch.serialization.add_safe_globals([
    np.core.multiarray.scalar,
    np.dtype,
    np.dtype('float64').__class__,
])
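For anyone wondering what "allowlisted" means in the error above, here is a stdlib-only sketch of the general idea: an unpickler that refuses any global it has not been told to trust. This is not PyTorch's actual implementation, and AllowlistUnpickler and restricted_loads are hypothetical names. The allowed argument plays a role loosely similar to the list passed to add_safe_globals.

```python
import io
import pickle
from collections import OrderedDict
from fractions import Fraction


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals named in an explicit allowlist."""

    def __init__(self, file, allowed):
        super().__init__(file)
        # Map (module, qualified name) -> object for every allowlisted global.
        self._allowed = {(obj.__module__, obj.__qualname__): obj for obj in allowed}

    def find_class(self, module, name):
        try:
            return self._allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"global {module}.{name} is not allowlisted"
            ) from None


def restricted_loads(data, allowed=()):
    """Unpickle bytes, rejecting any global that is not in `allowed`."""
    return AllowlistUnpickler(io.BytesIO(data), allowed).load()


# A payload that needs two globals: OrderedDict and Fraction.
payload = pickle.dumps(OrderedDict(step=Fraction(1, 2)))

try:
    restricted_loads(payload, allowed=[OrderedDict])
except pickle.UnpicklingError as err:
    print("blocked:", err)

# Allowlisting the missing type lets the same payload load.
print(restricted_loads(payload, allowed=[OrderedDict, Fraction]))
```

The numpy types in the fix above are needed for the same reason: the checkpoint pickles reference them, and the restricted loader rejects anything it was not told about.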
Suggestion:
Sample Factory should provide a script that re-saves old checkpoints in a form that loads under weights_only=True on PyTorch 2.6+. PyTorch could also improve the error message or make migration easier.
In my case, after installing sample-factory I installed torch==2.5.0.