load_state_dict doesn't support torch._subclasses.fake_tensor.FakeTensorMode

Open thiagocrepaldi opened this issue 1 year ago • 3 comments

System Info

  • transformers version: 4.36.0
  • Platform: Linux-6.5.0-15-generic-x86_64-with-glibc2.31
  • Python version: 3.11.5
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.24.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0a0+git78a84f1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

When PyTorch's FakeTensorMode is active, the underlying storage is changed to an UntypedStorage so that memory for the parameters is not actually allocated.

As a consequence, safetensors' get_tensor, which transformers calls while loading the checkpoint, fails with ValueError: could not determine the shape of object type 'torch.storage.UntypedStorage'
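A minimal, transformers-independent illustration of that behaviour, assuming a recent PyTorch in which the private module torch._subclasses.fake_tensor exposes FakeTensorMode and FakeTensor. A tensor far larger than host memory can be created under the mode, because only its metadata (shape, dtype, device) is tracked.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    # 100_000 x 100_000 float32 would need ~40 GB if it were really allocated.
    big = torch.empty(100_000, 100_000, dtype=torch.float32)

# Only metadata exists; no real buffer backs this tensor.
print(type(big).__name__, tuple(big.shape), big.dtype)
print("nominal size in GB:", big.numel() * big.element_size() / 1e9)
```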

from torch._subclasses import fake_tensor
import transformers

fake_mode = fake_tensor.FakeTensorMode(allow_non_fake_inputs=False)
with fake_mode:
    fake_model = transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2") 

Error:

Loading checkpoint shards:   0%|                                           | 0/19 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/opt/pytorch/test_mixtral.py", line 9, in <module>
    model = AutoModelForCausalLM.from_pretrained(model_id)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/transformers/modeling_utils.py", line 3694, in from_pretrained
    ) = cls._load_pretrained_model(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4079, in _load_pretrained_model
    state_dict = load_state_dict(shard_file)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/transformers/modeling_utils.py", line 510, in load_state_dict
    return safe_load_file(checkpoint_file)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/safetensors/torch.py", line 310, in load_file
    result[k] = f.get_tensor(k)
                ^^^^^^^^^^^^^^^
ValueError: could not determine the shape of object type 'torch.storage.UntypedStorage'

Expected behavior

transformers (through safetensors' get_tensor) should be able to load a checkpoint as fake tensors while FakeTensorMode is active.

thiagocrepaldi, Feb 13 '24 20:02

Hi @thiagocrepaldi, thanks for raising this issue!

I'm going to cc in @Narsil, the king of safetensors here.

If you want to be able to create an empty model, you can use accelerate's init_empty_weights utility:

from accelerate import init_empty_weights

with init_empty_weights():
    my_model = ModelClass(...)

amyeroberts, Feb 14 '24 15:02

Thanks, I will look into it. This API is specific to transformers, whereas a with torch._subclasses.fake_tensor.FakeTensorMode(): block should work with any model, transformers or otherwise.

It is because of this generality that we feel this is an important feature to support in transformers. @ezyang from the PyTorch project might have some insight on this.
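To illustrate the generality argument, here is a sketch with an arbitrary plain PyTorch module (an assumption, unrelated to transformers) that is built and run entirely under FakeTensorMode. The forward pass propagates shapes and dtypes, and the result is a FakeTensor; no real weights are allocated.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    # Any nn.Module works: its parameters are created as fake tensors.
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 1024),
    )
    out = model(torch.randn(8, 4096))

print(type(out).__name__, tuple(out.shape))
```

Constructing and running a module this way never reads a checkpoint; the failure in this issue only appears once real weights are loaded through safetensors.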

thiagocrepaldi, Feb 15 '24 15:02

This is also tracked on the PyTorch side by https://github.com/pytorch/pytorch/issues/106732

PyTorch's torch.load gains FakeTensorMode support with #119990, but the issue above still reproduces, because safetensors does not go through torch.load. Note that the PyTorch-side fix also uses torch._subclasses.fake_tensor.FakeTensorMode, like my proposed (and rejected) PR for HF: https://github.com/huggingface/safetensors/pull/318

thiagocrepaldi, Feb 15 '24 21:02

I wouldn't have accepted https://github.com/huggingface/safetensors/pull/318 lol. But if the relevant library in HF/safetensors is directly manipulating storages, some sort of PR will be necessary

ezyang, Feb 19 '24 00:02

get_tensor is implemented in Rust and instantiates its own Storage:

    pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
        let info = self.metadata.info(name).ok_or_else(|| {
            SafetensorError::new_err(format!("File does not contain tensor {name}",))
        })?;
        // let info = tensors.get(name).ok_or_else(|| {
        //     SafetensorError::new_err(format!("File does not contain tensor {name}",))
        // })?;

        match &self.storage.as_ref() {
            Storage::Mmap(mmap) => {
                let data =
                    &mmap[info.data_offsets.0 + self.offset..info.data_offsets.1 + self.offset];

                let array: PyObject = Python::with_gil(|py| PyByteArray::new(py, data).into_py(py));

                create_tensor(
                    &self.framework,
                    info.dtype,
                    &info.shape,
                    array,
                    &self.device,
                )
            }
            Storage::TorchStorage(storage) => {
                Python::with_gil(|py| -> PyResult<PyObject> {
                    let torch = get_module(py, &TORCH_MODULE)?;
                    let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
                    let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8, false)?;
                    let kwargs = [(intern!(py, "dtype"), torch_uint8)].into_py_dict(py);
                    let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py);
                    let shape = info.shape.to_vec();
                    let shape: PyObject = shape.into_py(py);

                    let start = (info.data_offsets.0 + self.offset) as isize;
                    let stop = (info.data_offsets.1 + self.offset) as isize;
                    let slice = PySlice::new(py, start, stop, 1);
                    let storage: &PyObject = storage
                        .get(py)
                        .ok_or_else(|| SafetensorError::new_err("Could not find storage"))?;
                    let storage: &PyAny = storage.as_ref(py);
                    let storage_slice = storage
                        .getattr(intern!(py, "__getitem__"))?
                        .call1((slice,))?;

                    let sys = PyModule::import(py, intern!(py, "sys"))?;
                    let byteorder: String = sys.getattr(intern!(py, "byteorder"))?.extract()?;

                    let mut tensor = torch
                        .getattr(intern!(py, "asarray"))?
                        .call((storage_slice,), Some(kwargs))?
                        .getattr(intern!(py, "view"))?
                        .call((), Some(view_kwargs))?;

                    if byteorder == "big" {
                        let inplace_kwargs =
                            [(intern!(py, "inplace"), false.into_py(py))].into_py_dict(py);
                        if info.dtype == Dtype::BF16 {
                            let torch_f16: PyObject = get_pydtype(torch, Dtype::F16, false)?;
                            tensor = tensor.getattr(intern!(py, "to"))?.call(
                                (),
                                Some([(intern!(py, "dtype"), torch_f16)].into_py_dict(py)),
                            )?;
                        }

                        let numpy = tensor
                            .getattr(intern!(py, "numpy"))?
                            .call0()?
                            .getattr("byteswap")?
                            .call((), Some(inplace_kwargs))?;
                        tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;

                        if info.dtype == Dtype::BF16 {
                            let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16, false)?;
                            tensor = tensor.getattr(intern!(py, "to"))?.call(
                                (),
                                Some([(intern!(py, "dtype"), torch_bf16)].into_py_dict(py)),
                            )?;
                        }
                    }

                    tensor = tensor.getattr(intern!(py, "reshape"))?.call1((shape,))?;
                    if self.device != Device::Cpu {
                        let device: PyObject = self.device.clone().into_py(py);
                        let kwargs = PyDict::new(py);
                        tensor = tensor
                            .getattr(intern!(py, "to"))?
                            .call((device,), Some(kwargs))?;
                    }
                    Ok(tensor.into_py(py))
                    // torch.asarray(storage[start + n : stop + n], dtype=torch.uint8).view(dtype=dtype).reshape(shape)
                })
            }
        }
    }
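For readers less familiar with the format: judging from the code above, start and stop are absolute byte positions computed as info.data_offsets plus self.offset. Assuming self.offset is the size of the file prefix (an 8-byte little-endian header length followed by the JSON header), as in the published safetensors layout, the same arithmetic can be sketched in pure Python. The helper names read_safetensors_header and byte_range are hypothetical.

```python
import json
import struct
from typing import Any


def read_safetensors_header(path: str) -> tuple[dict[str, Any], int]:
    """Hypothetical helper: return (per-tensor metadata, absolute start of the data section)."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional string-to-string map, not a tensor
    return header, 8 + header_len


def byte_range(info: dict[str, Any], data_start: int) -> tuple[int, int]:
    """Mirror of `info.data_offsets.0 + self.offset .. info.data_offsets.1 + self.offset`."""
    begin, end = info["data_offsets"]
    return data_start + begin, data_start + end
```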

thiagocrepaldi, Feb 20 '24 17:02

Well, big rip. Then yes, I agree: I would route the code to avoid calling into Rust when fake mode is enabled.
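A sketch of that routing, under loud assumptions: the helper names are hypothetical and not part of safetensors or transformers; fake mode is detected by scanning PyTorch's private dispatch-mode stack (torch.utils._python_dispatch._get_current_dispatch_mode_stack); and, when fake mode is active, the file's JSON header is parsed in Python and only empty tensors with the recorded shape and dtype are created, since fake tensors carry no data.

```python
import json
import struct

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

# Partial mapping from safetensors dtype strings to torch dtypes (extend as needed).
_DTYPES = {
    "F64": torch.float64, "F32": torch.float32, "F16": torch.float16,
    "BF16": torch.bfloat16, "I64": torch.int64, "I32": torch.int32,
    "I16": torch.int16, "I8": torch.int8, "U8": torch.uint8, "BOOL": torch.bool,
}


def fake_mode_active() -> bool:
    """True if a FakeTensorMode is currently on the (private) dispatch-mode stack."""
    return any(isinstance(m, FakeTensorMode) for m in _get_current_dispatch_mode_stack())


def load_file_fake_aware(path: str) -> dict[str, torch.Tensor]:
    """Hypothetical wrapper: bypass the Rust reader when fake mode is enabled."""
    if not fake_mode_active():
        from safetensors.torch import load_file  # real weights: normal Rust path
        return load_file(path)
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)
    # Fake tensors only need shape and dtype, so the data section is never read.
    return {
        name: torch.empty(info["shape"], dtype=_DTYPES[info["dtype"]])
        for name, info in header.items()
    }
```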

ezyang, Feb 20 '24 18:02

Thank you @amyeroberts. Is this valid for any transformers model? is there a way the model at hand belongs is a transformers model so that we don't use init_empty_weights for a non-transformers model?

@Narsil Did you have a chance to look into this one?

thiagocrepaldi, Feb 23 '24 20:02

Thank you @amyeroberts. Is this valid for any transformers model?

Yes, this is for all transformers models.

is there a way the model at hand belongs is a transformers model so that we don't use init_empty_weights for a non-transformers model?

Sorry, I don't understand the question. Is this about how to identify if a model is a transformers model?

amyeroberts, Feb 26 '24 10:02

@amyeroberts, would you be able to help me make the Stable Diffusion XL model work with the init_empty_weights API?

This is what I have so far:

import torch
from diffusers import DiffusionPipeline
from accelerate import init_empty_weights
from accelerate import load_checkpoint_and_dispatch


model_id = "stabilityai/stable-diffusion-xl-base-1.0"
with init_empty_weights():
    model = DiffusionPipeline.from_pretrained(model_id, low_cpu_mem_usage=False, use_safetensors=True)
model = load_checkpoint_and_dispatch(
    model, checkpoint=model_id, device_map="auto"
)

random_input = torch.randn(1, 4, 256, 256)
timestep = torch.tensor([1.0])
encoder_hidden_states = torch.randn(1, 1, 2048)
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 2560),
    "time_ids": torch.tensor([1]),
}
args = (random_input, timestep, encoder_hidden_states, None, None, None, None, added_cond_kwargs)

model(*args)

but it errors out with

Traceback (most recent call last):
  File "/opt/pytorch/test_sdxl_export_hf.py", line 10, in <module>
    model = load_checkpoint_and_dispatch(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/accelerate/big_modeling.py", line 567, in load_checkpoint_and_dispatch
    max_memory = get_balanced_memory(
                 ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 946, in get_balanced_memory
    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 701, in compute_module_sizes
    for name, tensor in named_module_tensors(model, recurse=True):
  File "/opt/conda/envs/ptca/lib/python3.11/site-packages/accelerate/utils/modeling.py", line 475, in named_module_tensors
    for named_parameter in module.named_parameters(recurse=recurse):
                           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/diffusers/src/diffusers/configuration_utils.py", line 142, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'StableDiffusionXLPipeline' object has no attribute 'named_parameters'

thiagocrepaldi, Mar 01 '24 21:03

@thiagocrepaldi This is more of a question for the diffusers library than for transformers.

amyeroberts, Mar 01 '24 21:03

Some context might help. I am working on the PyTorch ONNX exporter, which should export any model (not only Hugging Face's) without allocating memory for the tensors. My question was how I could identify that a model belongs to transformers, so that we could selectively use accelerate.init_empty_weights. I managed to use inspect.getmodule for that, but I was wondering whether Hugging Face models have a special attribute or method that identifies them.

thiagocrepaldi, Mar 05 '24 17:03

@thiagocrepaldi I don't understand. If you're using transformers models at some point in the pipeline, shouldn't it be explicit that transformers is being used? Can't you just add some metadata on your side to reflect the source of the model?

amyeroberts, Mar 05 '24 19:03

"I" in this case is a generic PyTorch API, such as torch.onnx.dynamo_export(model, *model_args, **model_kwargs). Internally, torch.onnx.dynamo_export would have to determine whether or not to use accelerate.init_empty_weights, depending on whether the model argument is a HF model.

I can actually solve this problem using inspect.getmodule(model), but I was wondering whether HF has a special way of detecting this without having to use inspect. If there is none, that is OK.
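A best-effort sketch of such a module-based check, with hypothetical helper names. It uses inspect.getmodule on the model's class and, only if transformers has already been imported, an isinstance check against transformers.PreTrainedModel; neither is guaranteed to be an exclusive marker of a transformers model.

```python
import inspect
import sys


def defined_in_package(obj: object, package: str) -> bool:
    """True if the class of `obj` was defined in `package` or one of its submodules."""
    module = inspect.getmodule(type(obj))
    name = module.__name__ if module is not None else type(obj).__module__
    return name == package or name.startswith(package + ".")


def looks_like_transformers_model(model: object) -> bool:
    """Hypothetical heuristic; avoids importing transformers if it is not loaded yet."""
    tf = sys.modules.get("transformers")
    if tf is not None and isinstance(model, tf.PreTrainedModel):
        return True  # also catches user-defined subclasses of PreTrainedModel
    return defined_in_package(model, "transformers")
```

The prefix match (package + ".") keeps a package such as "transformers_foo" from being mistaken for transformers, while classes that merely subclass PreTrainedModel in user code are caught only by the isinstance branch.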

thiagocrepaldi, Mar 05 '24 19:03

Hi @Narsil did you have a chance to look into the safetensors + FakeTensorMode issue?

thiagocrepaldi, Mar 05 '24 19:03

@thiagocrepaldi I don't think we can guarantee any model properties that you could reliably use. The transformers models are just PyTorch models. They will have properties associated with PreTrainedModel, but that doesn't mean these are exclusive to transformers models.

amyeroberts, Mar 07 '24 09:03

My comment from last year still stands.

The torch._subclasses.fake_tensor module is private, so I don't think we should do anything about it until it's officially supported.

That being said, I don't think support would necessarily require a lot of change. When you are loading a file, you are, well, loading a file (if you are using fake tensors, maybe not calling load_file in the first place is more appropriate). safetensors does not use any private part of torch, therefore I could most definitely reproduce the crash using pure Python/torch.

Narsil, Mar 21 '24 10:03