Distributing ao tensor subclasses in .safetensors checkpoints
Context
The status quo for distributing ao weights on Hugging Face is checkpoints produced by torch.save (examples here). The reasoning is that, by default, safetensors only supports saving state dictionaries of plain tensors, whereas ao weights tend to be tensor subclasses.
@drisspg and I had a conversation last week about what would be necessary for ao weights to be distributed as .safetensors files rather than .pt files on Hugging Face. My understanding from him is that this might be necessary for checkpoints to be marked as official on Hugging Face.
Since ao tensor subclasses are wrapper tensor subclasses holding plain tensors plus metadata, it should theoretically be possible to decompose a subclass into its tensors and metadata, save those to a .safetensors file with safetensors.torch.save_file (with special handling for metadata that is not JSON-serializable), and then reconstruct the subclass with a helper on top of safetensors.torch.load_file.
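As a minimal sketch of that round trip (key names are illustrative, and this only shows a plain inner tensor plus string metadata; safetensors stores the metadata argument as a string-to-string dict in the file header, and safe_open can read it back without the manual header parsing done in the script below):

import json
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file

# Save: plain inner tensor(s) plus string metadata describing how to rebuild
# the subclass. The "layer1.weight:original_weight" key is just a naming
# convention for this sketch.
weight = torch.randn(32, 64)
save_file(
    {"layer1.weight:original_weight": weight},
    "ckpt.safetensors",
    metadata={
        "layer1.weight:tensor_type": "LinearActivationQuantizedTensor",
        "layer1.weight:quant_kwargs": json.dumps({"activation_scale_ub": 0.3}),
    },
)

# Load: tensors come back as plain tensors; the free-form metadata is exposed
# by safe_open(...).metadata().
tensors = load_file("ckpt.safetensors")
with safe_open("ckpt.safetensors", framework="pt") as f:
    meta = f.metadata()
# A helper would then rebuild the subclass from tensors + meta.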
I wrote a simple prototype of what this would look like and am looking
- To stimulate discussion on what saving ao subclasses in safetensors format might look like, and whether this is needed
- For feedback on whether a solution like the one below covers the important cases and is viable
Simple example
We took the more straightforward case of a LinearActivationQuantizedTensor with dynamic quantization to FbgemmFp8Tensor, which I understand from Driss might be the main case to target.
My understanding here is that a LinearActivationQuantizedTensor would take (a small construction example follows this list):
- original_weight_tensor: a plain tensor
- input_quant_func: in this case to_fbgemm_fp8, which maps to FbgemmFp8Tensor.from_float
- quant_kwargs: in this case, activation_scale_ub
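For concreteness, constructing one of these looks like the following (this just mirrors the __main__ block of the script below):

import torch
from torchao.dtypes.fbgemm_fp8_tensor import to_fbgemm_fp8
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized

weight = torch.randn(32, 64, dtype=torch.float32)
fp8_weight = to_linear_activation_quantized(
    weight,
    input_quant_func=to_fbgemm_fp8,
    quant_kwargs={"activation_scale_ub": 0.3},
)
# fp8_weight.original_weight_tensor is the plain tensor to save;
# input_quant_func and quant_kwargs are the non-tensor metadata.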
Given this, I generated a rudimentary script that would
- Extract tensor and non-tensor attributes from LinearActivationQuantizedTensor
- Pass all information necessary to reconstruct the subclass to the metadata argument of safetensors.torch.save_file
  i. In particular, the function torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8 is serialized as the string "{__module__}.{__qualname__}", and the actual function object is looked up from a dict during loading
- Load tensors via safetensors.torch.load_file and manually read metadata from the .safetensors file in order to reconstruct the subclasses
Click me for script
import torch
import json
import inspect
from typing import Dict, Any, Callable, Optional, Union, Tuple, List
from safetensors.torch import save_file, load_file
import torchao
ALLOWED_QUANT_FUNCTIONS = {
"torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8" : torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8,
"torchao.dtypes.fbgemm_fp8_tensor.FbgemmFp8Tensor.from_float" : torchao.dtypes.fbgemm_fp8_tensor.FbgemmFp8Tensor.from_float
# add to me
}
def get_function_path(func: Callable) -> str:
"""Get the import path for a function."""
fullpath = f"{func.__module__}.{func.__qualname__}"
assert fullpath in ALLOWED_QUANT_FUNCTIONS
return fullpath
def create_metadata_for_tensor_subclass(tensor: torch.Tensor) -> Tuple[Dict[str, str], Dict[str, torch.Tensor]]:
"""
Create metadata for tensor subclasses from torchao.
Args:
tensor: A tensor subclass (e.g., LinearActivationQuantizedTensor)
Returns:
Tuple of (metadata, tensors_dict) where:
- metadata: Dictionary with metadata needed to reconstruct the tensor
- tensors_dict: Dictionary with tensors to save
"""
metadata = {}
tensors_dict = {}
if tensor.__class__.__name__ == "LinearActivationQuantizedTensor":
metadata["tensor_type"] = "LinearActivationQuantizedTensor"
quant_func_path = get_function_path(tensor.input_quant_func)
metadata["input_quant_func"] = quant_func_path
if hasattr(tensor, "quant_kwargs"):
metadata["quant_kwargs"] = json.dumps(tensor.quant_kwargs)
tensors_dict["original_weight"] = tensor.original_weight_tensor
else:
raise ValueError(f"Unsupported tensor type: {tensor.__class__.__name__}")
return metadata, tensors_dict
def save_tensor_subclass(tensor: torch.Tensor, file_path: str, additional_metadata: Optional[Dict[str, str]] = None):
"""
Save a tensor subclass with appropriate metadata.
Args:
tensor: The tensor subclass to save
file_path: Path where to save the tensor
additional_metadata: Optional additional metadata to include
"""
metadata, tensors_dict = create_metadata_for_tensor_subclass(tensor)
if additional_metadata:
metadata.update(additional_metadata)
save_file(tensors_dict, file_path, metadata=metadata)
print(f"Saved tensor subclass to {file_path} with metadata")
def save_tensor_subclass_dict(tensor_dict: Dict[str, torch.Tensor], file_path: str,
additional_metadata: Optional[Dict[str, str]] = None):
"""
Save a dictionary of tensor subclasses with appropriate metadata.
Args:
tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
file_path: Path where to save the tensors
additional_metadata: Optional additional metadata to include
"""
combined_metadata = {}
combined_tensors_dict = {}
for tensor_name, tensor in tensor_dict.items():
# TODO: handle case where tensor is a plain tensor
metadata, tensors_dict = create_metadata_for_tensor_subclass(tensor)
prefixed_tensors_dict = {f"{tensor_name}:{key}": value for key, value in tensors_dict.items()}
for key, value in metadata.items():
combined_metadata[f"{tensor_name}:{key}"] = value
combined_tensors_dict.update(prefixed_tensors_dict)
combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
if additional_metadata:
combined_metadata.update(additional_metadata)
save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
print(f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata")
def load_tensor_subclass(file_path: str) -> torch.Tensor:
"""
Load a tensor subclass from a safetensors file.
Args:
file_path: Path to the safetensors file
Returns:
The reconstructed tensor subclass
"""
loaded_tensors = load_file(file_path)
with open(file_path, "rb") as f:
import struct
header_size = struct.unpack("<Q", f.read(8))[0]
header_bytes = f.read(header_size)
header = json.loads(header_bytes)
metadata = header.get("__metadata__", {})
assert "tensor_names" not in metadata
tensor_type = metadata.get("tensor_type")
if tensor_type == "LinearActivationQuantizedTensor":
original_weight = loaded_tensors["original_weight"]
quant_func_path = metadata.get("input_quant_func")
if quant_func_path not in ALLOWED_QUANT_FUNCTIONS:
raise ValueError(f"Security error: Quantization function '{quant_func_path}' is not in the allowed list")
quant_func = ALLOWED_QUANT_FUNCTIONS.get(quant_func_path)
quant_kwargs = json.loads(metadata.get("quant_kwargs", "{}"))
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized
return to_linear_activation_quantized(
original_weight,
input_quant_func=quant_func,
quant_kwargs=quant_kwargs
)
else:
return loaded_tensors
def load_tensor_subclass_dict(file_path: str) -> Dict[str, torch.Tensor]:
"""
Load a dictionary of tensor subclasses from a safetensors file.
Args:
file_path: Path to the safetensors file
Returns:
Dictionary of reconstructed tensor subclasses
"""
loaded_tensors = load_file(file_path)
with open(file_path, "rb") as f:
import struct
header_size = struct.unpack("<Q", f.read(8))[0]
header_bytes = f.read(header_size)
header = json.loads(header_bytes)
metadata = header.get("__metadata__", {})
if "tensor_names" not in metadata:
tensor = load_tensor_subclass(file_path)
return {"tensor": tensor}
tensor_names = json.loads(metadata["tensor_names"])
result = {}
for tensor_name in tensor_names:
tensor_metadata = {}
for key, value in metadata.items():
if key.startswith(f"{tensor_name}:"):
# Remove the prefix
tensor_metadata[key[len(tensor_name)+1:]] = value
tensor_tensors = {}
for key, value in loaded_tensors.items():
if key.startswith(f"{tensor_name}:"):
# Remove the prefix
tensor_tensors[key[len(tensor_name)+1:]] = value
tensor_type = tensor_metadata.get("tensor_type")
if tensor_type == "LinearActivationQuantizedTensor":
original_weight = tensor_tensors["original_weight"]
quant_func_path = tensor_metadata.get("input_quant_func")
if quant_func_path not in ALLOWED_QUANT_FUNCTIONS:
raise ValueError(f"Security error: Quantization function '{quant_func_path}' is not in the allowed list")
quant_func = ALLOWED_QUANT_FUNCTIONS.get(quant_func_path)
quant_kwargs = json.loads(tensor_metadata.get("quant_kwargs", "{}"))
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized
result[tensor_name] = to_linear_activation_quantized(
original_weight,
input_quant_func=quant_func,
quant_kwargs=quant_kwargs
)
else:
result[tensor_name] = tensor_tensors
return result
if __name__ == "__main__":
from torchao.dtypes.fbgemm_fp8_tensor import to_fbgemm_fp8
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized
weight1 = torch.randn(32, 64, dtype=torch.float32)
weight2 = torch.randn(64, 128, dtype=torch.float32)
weight3 = torch.randn(128, 256, dtype=torch.float32)
fp8_weight1 = to_linear_activation_quantized(
weight1,
input_quant_func=to_fbgemm_fp8,
quant_kwargs={"activation_scale_ub": 0.3}
)
fp8_weight2 = to_linear_activation_quantized(
weight2,
input_quant_func=to_fbgemm_fp8,
quant_kwargs={"activation_scale_ub": 0.5}
)
fp8_weight3 = to_linear_activation_quantized(
weight3,
input_quant_func=to_fbgemm_fp8,
quant_kwargs={"activation_scale_ub": 0.7}
)
tensor_dict = {
"layer1.weight": fp8_weight1,
"layer2.weight": fp8_weight2,
"layer3.weight": fp8_weight3
}
print("Saving tensor subclasses...")
print(tensor_dict)
save_tensor_subclass_dict(tensor_dict, "fp8_weights_multi.safetensors")
reconstructed_dict = load_tensor_subclass_dict("fp8_weights_multi.safetensors")
print(f"Loaded {len(reconstructed_dict)} tensors:")
for name, tensor in reconstructed_dict.items():
print(name, tensor)
Questions
- Does a solution like the above sufficiently handle the main cases ao cares about?
  a. Does this approach scale to subclasses other than LinearActivationQuantizedTensor in ao, or are there more non-JSON-serializable attributes that we might need to handle or might not be able to handle? (See the sketch after this list.)
- Do we have control over all the points where these checkpoints would be loaded? (There would need to be additional helper code on top of safetensors.load_file to reconstruct the subclasses.)
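On question 1a, one possible direction for generalizing beyond per-class code is to lean on the __tensor_flatten__ / __tensor_unflatten__ protocol that wrapper tensor subclasses implement. A rough sketch only: it assumes the inner tensors are plain tensors (no recursion) and that the flatten context is JSON-serializable, which is not true today for function-valued attributes like input_quant_func, so those would still need the allowlist/string encoding above.

import json
import torch

def flatten_for_safetensors(name: str, t: torch.Tensor):
    # Sketch: decompose one wrapper subclass into plain tensors + str metadata
    # using the subclass's own flatten protocol.
    attrs, ctx = t.__tensor_flatten__()  # inner tensor attr names + context
    tensors = {}
    for attr in attrs:
        # NOTE: inner tensors may themselves be subclasses (e.g. FbgemmFp8Tensor);
        # a real implementation would recurse here.
        tensors[f"{name}:{attr}"] = getattr(t, attr)
    metadata = {
        f"{name}:type": type(t).__name__,
        f"{name}:attrs": json.dumps(attrs),
        # Only valid if ctx is JSON-serializable; functions in ctx would still
        # need the "{__module__}.{__qualname__}" string + allowlist treatment.
        f"{name}:ctx": json.dumps(ctx),
    }
    return tensors, metadata

Loading would then map the stored type name back to the subclass (again via an allowlist) and call its __tensor_unflatten__ with the inner tensors and the decoded context.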
cc @jerryzh168 @drisspg
- This is probably kind of true today, but not sure if it will always be
- For the most part we do; I think if people plan to use AO for quant they have the module installed and can at least call into it
Not sure if this is "safe" in the safe tensors sense though
> Not sure if this is "safe" in the safe tensors sense though
If I read this correctly, you are saying that this is not as "safe" as safetensors because the constructors of tensor subclasses (e.g. LinearActivationQuantizedTensor) are being called in the helper on top of safetensors.torch.load_file. If the goal is to load custom tensor subclasses whose definitions are constantly in flux, I would argue that it is never possible to be as safe as safetensors (which limits the surface to just plain tensors).
In my view, the main difference between torch.load(weights_only=True) and the above is that the above makes the code used to rebuild the subclass very explicit, whereas torch.load(weights_only=True) executes more opaque bytecode, though we do restrict what functions the unpickler can call.
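For comparison, a rough sketch of what the weights_only=True path involves today (this assumes torch.serialization.add_safe_globals from recent PyTorch releases; the checkpoint name is illustrative, and in practice more classes/functions than shown here may need to be allowlisted):

import torch
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

# weights_only=True refuses to unpickle anything that is not explicitly allowlisted,
# so the subclass (and whatever its reconstruction references) must be registered first.
torch.serialization.add_safe_globals([LinearActivationQuantizedTensor])
state_dict = torch.load("ao_checkpoint.pt", weights_only=True)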
Is what you're looking for here something like neuralmagic/compressed-tensors? I don't have a complete understanding of that, but my read is that it would also need to instantiate a compressor that gets called based on metadata in the config file.
Ohh sorry, I totally meant that as a question; tbh I don't totally understand the value of safetensors vs weights_only=True. My understanding is that you don't want pickle since you can have arbitrary code execution. I think in this case both don't have pickle, so that's good?
From discussing with @mikaylagawarecki offline, I think safetensors is safer than weights_only=True for the reasons she mentioned in https://github.com/pytorch/ao/issues/2338#issuecomment-2963715988. Plus, now we have a request from unsloth to support this, so I feel it's time to build this.
Is it compatible with multiple cards? For example, if the model was quantized using 2 cards and saved as safetensors, can we load it and run inference on 2 or even 4 cards?