
[HANDS-ON BUG] Unit 8 part 2, sample-factory 2.1.1 not compatible with pytorch 2.6+

Open B-ramB opened this issue 6 months ago • 5 comments

Describe the bug

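Sample-factory loads checkpoints with this PyTorch call: checkpoint_dict = torch.load(latest_checkpoint, map_location=device)

PyTorch 2.6 changed the default of torch.load() from weights_only=False to weights_only=True, so this call now needs an explicit weights_only=False argument to load these checkpoints.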
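A minimal sketch of the behaviour described above. The helper names (defaults_to_weights_only, load_full_checkpoint) are hypothetical and not part of sample-factory or PyTorch; only torch.load and its weights_only parameter are real. Note that weights_only=False unpickles arbitrary objects, so it should only be used on checkpoints you trust.

```python
import re


def defaults_to_weights_only(torch_version: str) -> bool:
    """Return True if this PyTorch version defaults torch.load to weights_only=True.

    Hypothetical helper: it only compares the major.minor prefix against 2.6.
    """
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {torch_version!r}")
    return (int(match.group(1)), int(match.group(2))) >= (2, 6)


def load_full_checkpoint(path, device="cpu"):
    """Load a trusted checkpoint with full pickle support.

    torch is imported lazily so the version helper above works without it.
    """
    import torch

    # Explicit weights_only=False restores the pre-2.6 behaviour.
    return torch.load(path, map_location=device, weights_only=False)
```

In a notebook, defaults_to_weights_only(torch.__version__) would tell you whether the installed PyTorch is affected.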

Material

How to Fix

Install the latest sample-factory with pip: !pip install sample-factory==2.1.3
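To confirm which versions the notebook actually picked up after the install, something like the following can help. It is a small sketch using only the standard library; installed_version is a hypothetical helper name, and the distribution names are assumed to match the pip package names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Print the versions relevant to this issue.
for name in ("sample-factory", "torch", "numpy"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```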

B-ramB avatar Jun 24 '25 20:06 B-ramB

PR: https://github.com/huggingface/deep-rl-class/pull/627

B-ramB avatar Jun 24 '25 20:06 B-ramB

I fixed this by running this code:

import torch.serialization
import _codecs
import numpy as np
torch.serialization.add_safe_globals([np.core.multiarray.scalar, np.dtype, np.dtypes.Float64DType, _codecs.encode])

Run it before the cell that contains: from sample_factory.enjoy import enjoy cfg = parse_vizdoom_cfg....
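One caveat with the snippet above: NumPy 2.0 renamed numpy.core to numpy._core, so np.core.multiarray.scalar may emit a deprecation warning there. The sketch below resolves whichever path exists first. resolve_first and allowlist_numpy_globals are hypothetical helpers, and whether the resulting allowlist matches what a given checkpoint pickled has not been verified against sample-factory.

```python
import importlib


def resolve_first(dotted_names):
    """Return the first object that can be resolved from a list of dotted paths."""
    for name in dotted_names:
        module_name, _, attr = name.rpartition(".")
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr)
        except (ImportError, AttributeError, ValueError):
            continue  # try the next candidate path
    raise LookupError(f"none of {list(dotted_names)} could be resolved")


def allowlist_numpy_globals():
    """Register the NumPy types from the comment above as safe globals."""
    import torch  # imported lazily; requires a PyTorch with add_safe_globals

    scalar = resolve_first([
        "numpy._core.multiarray.scalar",  # NumPy 2.x location
        "numpy.core.multiarray.scalar",   # NumPy 1.x location
    ])
    torch.serialization.add_safe_globals([
        scalar,
        resolve_first(["numpy.dtype"]),
        resolve_first(["numpy.dtypes.Float64DType"]),
        resolve_first(["_codecs.encode"]),
    ])
```

Calling allowlist_numpy_globals() once, before the enjoy cell, would take the place of the four lines above.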

blanck avatar Jun 29 '25 22:06 blanck

It worked, thank you!

0xfabrica avatar Jul 03 '25 19:07 0xfabrica

🐛 Issue: Can't load Sample Factory checkpoints in PyTorch 2.6+

Problem: PyTorch 2.6 changed the default of torch.load() to weights_only=True, which breaks loading of older sample-factory checkpoints that were saved with full pickle. Under that default, loading fails unless the non-tensor types stored in the checkpoint are allowlisted. Passing weights_only=False bypasses the restriction, but it unpickles arbitrary objects and is only appropriate for trusted files.

Error:

_pickle.UnpicklingError: Can only build Tensor, Parameter, OrderedDict or types allowlisted... but got <class 'numpy.dtypes.Float64DType'>

Fix: To load old checkpoints under the new default (weights_only=True), allowlist the NumPy types before torch.load() is called:

import torch
import numpy as np

torch.serialization.add_safe_globals([
    np.core.multiarray.scalar,
    np.dtype,
    np.dtype('float64').__class__,
])

Suggestion: Sample Factory should provide a script that re-saves old checkpoints in a form that loads with weights_only=True on PyTorch 2.6+. PyTorch could also improve the error message or ease the migration.
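A rough sketch of what such a re-save script might look like. No script like this is known to exist in sample-factory; to_plain and resave_checkpoint are hypothetical names, and the approach (turning NumPy scalars and arrays into plain Python values) is an assumption that has not been tested against real sample-factory checkpoints. Only run it on checkpoints you trust, since the first load uses full pickle.

```python
from collections import OrderedDict


def to_plain(obj):
    """Recursively replace NumPy scalars/arrays with plain Python values.

    NumPy objects are detected by module name so this file imports without NumPy.
    Anything unrecognised (tensors, dtype objects, ...) is returned unchanged
    and may still need allowlisting.
    """
    if isinstance(obj, OrderedDict):
        return OrderedDict((k, to_plain(v)) for k, v in obj.items())
    if isinstance(obj, dict):
        return {k: to_plain(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_plain(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_plain(v) for v in obj)
    is_numpy = type(obj).__module__.split(".")[0] == "numpy"
    if is_numpy and hasattr(obj, "tolist"):
        return obj.tolist()  # np.float64 -> float, ndarray -> nested lists
    return obj


def resave_checkpoint(src, dst, device="cpu"):
    """Load a trusted checkpoint with full pickle and save a simplified copy."""
    import torch  # imported lazily

    checkpoint = torch.load(src, map_location=device, weights_only=False)
    torch.save(to_plain(checkpoint), dst)
```

Arrays become nested lists here, which changes their type; code that expects NumPy arrays in the checkpoint would need to convert them back.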

hosseinkamyab avatar Jul 24 '25 09:07 hosseinkamyab

What worked for me: after installing sample-factory, I installed torch==2.5.0.

bohlinz avatar Aug 03 '25 09:08 bohlinz