
torch.load fails due to weights_only=True in PyTorch >=2.6 when loading quantized model

Open umerkayvyro opened this issue 7 months ago • 3 comments

When running app/gradio_app.py, loading the model via FluxTransformer2DModel.from_pretrained(...) fails with a pickle.UnpicklingError due to weights_only=True being the new default in PyTorch 2.6+.

Error:

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options...

Specifically, the failure occurs in this call (imports added for context):

import torch
from diffusers import FluxTransformer2DModel

transformer_model = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-schell-int8wo-improved",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
)

Temporary Fix:

I patched torch.load globally as follows to resolve the issue:

original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    # Force the pre-2.6 default so checkpoints containing pickled
    # (non-tensor) objects can still be deserialized.
    kwargs['weights_only'] = False
    return original_torch_load(*args, **kwargs)

torch.load = patched_torch_load

This forces weights_only=False, which restores compatibility with model files containing pickled classes like torchao.dtypes.affine_quantized_tensor.AffineQuantizedTensor.
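A less invasive alternative is to keep `weights_only=True` and allowlist the specific classes the checkpoint pickles. This is only a sketch: it assumes a PyTorch version that provides `torch.serialization.add_safe_globals` (2.4+) and that torchao is installed; other checkpoints may pickle additional classes that also need allowlisting.

```python
import torch

try:
    from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
    # Register the pickled class as safe, so weights_only=True loading
    # accepts it instead of raising UnpicklingError.
    torch.serialization.add_safe_globals([AffineQuantizedTensor])
except ImportError:
    # torchao (or this class path) is unavailable in this environment;
    # fall back to the weights_only=False monkeypatch above.
    pass
```

Note that `weights_only=False` should only be used with checkpoints from a trusted source, since unpickling can execute arbitrary code; the allowlist keeps that exposure limited to the named classes.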

Suggestion:

Consider adding this workaround (or a conditional variant) in gradio_app.py.

umerkayvyro avatar May 12 '25 14:05 umerkayvyro

Thanks, this might help the community while we are unifying everything.

saliou-k avatar May 12 '25 16:05 saliou-k

AttributeError: Can't get attribute 'PlainAQTLayout' on <module 'torchao.dtypes.affine_quantized_tensor' from 'D:\conda\envs\zenctrl\lib\site-packages\torchao\dtypes\affine_quantized_tensor.py'>

Jandown avatar May 13 '25 03:05 Jandown

@Jandown that is not an issue I was able to fix.

umerkayvyro avatar May 15 '25 13:05 umerkayvyro