
Save tensors in lower precision

Open LuukSuurmeijer opened this pull request 1 month ago • 1 comment

Description

Added support for saving attributions in a lower tensor precision.

Upon saving, tensors are converted to Hugging Face safetensors. They are then optionally quantized to float16, int8, or uint8 (if there are no negative values) using zero-point quantization. The quantization parameters are stored in the safetensors metadata so that the float32 values can be recovered upon loading. Safetensors are bytes objects, so they are base64-encoded before being written to JSON.
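
In pseudocode, the round trip looks roughly like the sketch below. The helper names (quantize_tensor, pack_for_json, unpack_from_json) are illustrative only, not the actual functions added in this PR, and the sketch only covers the 8-bit case:

import base64
import json
import struct

import torch
from safetensors.torch import load as st_load, save as st_save


def quantize_tensor(x: torch.Tensor):
    """Zero-point (asymmetric) quantization of a float32 tensor to 8 bits."""
    # uint8 if all values are non-negative, int8 otherwise
    qmin, qmax, dtype = (0, 255, torch.uint8) if x.min() >= 0 else (-128, 127, torch.int8)
    scale = (x.max() - x.min()).item() / (qmax - qmin) or 1.0  # guard against constant tensors
    zero_point = int(round(qmin - x.min().item() / scale))
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(dtype)
    return q, scale, zero_point


def pack_for_json(x: torch.Tensor) -> str:
    """Quantize, keep the quantization params in the safetensors metadata, base64-encode for JSON."""
    q, scale, zero_point = quantize_tensor(x)
    blob = st_save({"attr": q}, metadata={"scale": str(scale), "zero_point": str(zero_point)})
    return base64.b64encode(blob).decode("ascii")


def unpack_from_json(encoded: str) -> torch.Tensor:
    """Decode the bytes, read the params back from the header, dequantize to float32."""
    blob = base64.b64decode(encoded)
    (header_length,) = struct.unpack("<Q", blob[:8])  # 8-byte little-endian header length
    meta = json.loads(blob[8 : 8 + header_length])["__metadata__"]
    q = st_load(blob)["attr"]
    return (q.float() - int(meta["zero_point"])) * float(meta["scale"])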

List of changes:

  • save has an extra parameter scores_precision with default value float32 (see the short usage sketch after this list).
  • FeatureAttributionSequenceOutput has two new private methods, _convert_to_safetensors and _recover_from_safetensors, which convert the object's tensors from torch tensors to safetensors and vice versa. They are used in saving and loading, respectively.
  • Two new util functions in torch_utils, convert_to_safetensor and dequantize_safetensor, which convert a single tensor in each direction.
  • Two new unit tests in test_attribution.py for saving and loading in float16 / float8.
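
For example (model and precision chosen just for illustration):

import inseq
from inseq import FeatureAttributionOutput

model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "attention")
out = model.attribute("This is a test.", show_progress=False)
# Scores are quantized to float16 on save and recovered as float32 on load
out.save("attr_fp16.json", scores_precision="float16", overwrite=True)
restored = FeatureAttributionOutput.load("attr_fp16.json")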

This is my first PR on this project and first time properly diving into inseq, so please be critical and help me improve the feature! There are several points where I am not sure about the implementation:

  • I have to deepcopy the objects while saving; is that really necessary?
  • Is it a good idea to introduce new private methods on FeatureAttributionSequenceOutput?
  • Does the output JSON preserve enough readability?
  • Should I include unit tests for the new torch_utils functions? I saw that most of them do not have unit tests, but I am happy to add them.

All tests run clean with no errors.

Related Issue

issue 202

Type of Change

  • 🥂 Improvement (non-breaking change which improves an existing feature)
  • 🚀 New feature (non-breaking change which adds functionality)

Checklist

  • [x] I've read the CODE_OF_CONDUCT.md document.
  • [x] I've read the CONTRIBUTING.md guide.
  • [x] I've successfully run the style checks using make fix-style.
  • [x] I've written tests for all new methods and classes that I created and successfully ran make test.
  • [x] I've written the docstring in Google format for all the methods and classes that I used.

LuukSuurmeijer · May 10 '24 13:05

Hey @LuukSuurmeijer, thanks a lot for this PR!

I had a look and added some very minor fixes (a Literal type for the allowed precision strings and a docstring for the new save parameter). I also made sure the code works fine when compress=False but a different precision is specified. In one of my tests, however, I ran into a weird issue. If you run the following code:


import torch
from inseq import load_model, FeatureAttributionOutput

saliency_mt_model = load_model("Helsinki-NLP/opus-mt-en-it", "attention")

out_path = "tmp_attr_8bit.json"
out = saliency_mt_model.attribute("This is a test.", device="cpu", show_progress=False)
out.save(out_path, scores_precision="float8", overwrite=True)
loaded_out = FeatureAttributionOutput.load(out_path)
assert torch.allclose(
    out.sequence_attributions[0].source_attributions,
    loaded_out.sequence_attributions[0].source_attributions,
    atol=1e-02,
)

You get an error when parsing the JSON metadata header. From a quick exploration, it seems to be caused by the header selection json.loads(safetensor[8 : (7 + header_length)])["__metadata__"] in dequantize_safetensor, which cuts the JSON one character too short. This is puzzling because the same code works fine for other precisions, and I confirm that it matches the Hugging Face example you referred to. If we cannot find the cause, we might want to set up some error handling: either 1) brute-force the extraction character by character until a valid JSON is formed (not ideal), or at least 2) raise an informative error about the problem, suggesting to try another precision.
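
For what it's worth, if I read the safetensors layout correctly, the blob starts with an 8-byte little-endian unsigned integer giving the byte length of the JSON header, followed by exactly that many bytes of JSON (which the library may pad with trailing spaces). A slice ending at 7 + header_length would then drop the last header byte; when that byte happens to be a padding space the JSON still parses, which could explain why only some precisions fail. A possible parsing sketch under that assumption (read_safetensors_metadata is just an illustrative name, not code from this PR):

import json
import struct

def read_safetensors_metadata(blob: bytes) -> dict:
    # Bytes 0-7: little-endian u64 holding the byte length of the JSON header
    (header_length,) = struct.unpack("<Q", blob[:8])
    # The header spans bytes 8 .. 8 + header_length (upper bound exclusive),
    # so this slice keeps the closing brace that 7 + header_length would drop
    header = json.loads(blob[8 : 8 + header_length])
    return header.get("__metadata__", {})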

gsarti · May 11 '24 11:05