COMET
Explicitly set `weights_only` to `False` in `torch.load()`
Running the example `comet-score` command with PyTorch >= 2.6 fails with the following error:
```text
[rank0]: Traceback (most recent call last):
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/bin/comet-score", line 7, in <module>
[rank0]: sys.exit(score_command())
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/lib/python3.11/site-packages/comet/cli/score.py", line 203, in score_command
[rank0]: outputs = model.predict(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/lib/python3.11/site-packages/comet/models/base.py", line 663, in predict
[rank0]: predictions = pred_writer.gather_all_predictions()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/lib/python3.11/site-packages/comet/models/predict_writer.py", line 99, in gather_all_predictions
[rank0]: [
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/lib/python3.11/site-packages/comet/models/predict_writer.py", line 100, in <listcomp>
[rank0]: flatten_predictions(torch.load(os.path.join(self.output_dir, f)))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/mnt/users_home/cpii.local/yxing/miniconda3/envs/comet/lib/python3.11/site-packages/torch/serialization.py", line 1529, in load
[rank0]: raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
[rank0]: _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
[rank0]: (1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
[rank0]: (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
[rank0]: WeightsUnpickler error: Unsupported global: GLOBAL comet.models.utils.Prediction was not an allowed global by default. Please use `torch.serialization.add_safe_globals([comet.models.utils.Prediction])` or the `torch.serialization.safe_globals([comet.models.utils.Prediction])` context manager to allowlist this global if you trust this class/function.
[rank0]: Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
```
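Explicitly passing `weights_only=False` to the `torch.load()` call in `comet/models/predict_writer.py` (line 100, inside `gather_all_predictions`) fixed it for me.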