
Allow tensor subclasses and add `torch.serialization.mark_safe_globals` that allows users to allowlist classes for `weights_only` load

mikaylagawarecki opened this issue 10 months ago.

Conditions for allowlisting tensor subclasses

We allowlist tensor subclasses that (1) do not override `__setstate__`, `__getattr__`, or `__setattr__`; (2) use the generic `tp_alloc`; and (3) are defined in a module that the user has already imported.

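Note that we use `inspect.getattr_static(sys.modules[module], name)` to look up the class, since this function is documented to retrieve attributes without triggering dynamic lookup via the descriptor protocol, `__getattr__()`, or `__getattribute__()`, i.e. without executing user code.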
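A minimal sketch of condition (1), written as a hypothetical helper `overrides_forbidden_methods` (not PyTorch's actual code). It compares what `inspect.getattr_static` finds on the candidate class against what it finds on the base class, so nothing defined on the candidate runs during the check. `Base` is a stand-in for `torch.Tensor` so the sketch runs without PyTorch; condition (2), the `tp_alloc` check, happens at the C level and is not modeled here.

```python
import inspect

# Methods a tensor subclass must not override to be allowlisted implicitly
# (condition (1) above).
FORBIDDEN_OVERRIDES = ("__setstate__", "__getattr__", "__setattr__")

# Sentinel so that "missing on both classes" compares equal.
_MISSING = object()


def overrides_forbidden_methods(cls, base):
    """Hypothetical helper: True if `cls` overrides any FORBIDDEN_OVERRIDES
    entry relative to `base` (e.g. base=torch.Tensor).

    inspect.getattr_static walks the MRO and returns raw class attributes
    without running descriptors, __getattr__ or __getattribute__.
    """
    for name in FORBIDDEN_OVERRIDES:
        own = inspect.getattr_static(cls, name, _MISSING)
        inherited = inspect.getattr_static(base, name, _MISSING)
        if own is not inherited:
            return True
    return False


class Base:  # stand-in for torch.Tensor, which defines __setstate__
    def __setstate__(self, state):
        pass


class Plain(Base):
    pass


class CustomSetattr(Base):
    def __setattr__(self, name, value):
        # Arbitrary code here would run for every attribute restored via setattr.
        super().__setattr__(name, value)


print(overrides_forbidden_methods(Plain, Base))          # False
print(overrides_forbidden_methods(CustomSetattr, Base))  # True
```

Comparing the raw attributes by identity means a class that merely inherits a method passes, while any redefinition, even one that just calls `super()`, is treated as an override and would need explicit allowlisting.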

The rationale for the first two conditions is as follows:

The rebuild function returned by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as follows (note the call to `getattr`, the use of `Tensor.__setstate__`, the call to `as_subclass`, and the call to `_set_obj_state`, which calls `setattr`):

https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71
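A small check of this, assuming a local PyTorch install; `MyTensor` is a hypothetical plain subclass. The snippet calls `__reduce_ex__` directly (as `pickle` would) and then rebuilds the tensor with the returned function. The exact contents of `args` and `state` vary across PyTorch versions, so only the `(rebuild_fn, (func, new_type, args, state))` shape is relied on.

```python
import torch


class MyTensor(torch.Tensor):
    """Hypothetical plain subclass: no custom __setstate__/__getattr__/__setattr__."""


t = torch.ones(3).as_subclass(MyTensor)

# For subclasses, Tensor.__reduce_ex__ returns
# (_rebuild_from_type_v2, (func, new_type, args, state)).
rebuild_fn, (func, new_type, args, state) = t.__reduce_ex__(2)
print(rebuild_fn is torch._tensor._rebuild_from_type_v2)  # True
print(new_type is MyTensor)                                # True

# Unpickling calls rebuild_fn(func, new_type, args, state): func(*args) builds a
# plain Tensor, as_subclass swaps in MyTensor, then the Python state is restored.
restored = rebuild_fn(func, new_type, args, state)
print(type(restored) is MyTensor, torch.equal(restored, t))  # True True
```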

`as_subclass` is implemented with a call to `THPVariable_NewWithVar`, which eventually calls `tp_alloc` here: https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053

For wrapper subclasses, the `func` argument to `_rebuild_from_type_v2` is `torch._utils._rebuild_wrapper_subclass`, which similarly calls into `THPVariable_NewWithVar` and hits the same `tp_alloc` call.

Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling.
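One way to observe this, assuming a local PyTorch install: the hypothetical subclass `Counting` below counts calls to its `__new__` and `__init__`. Rebuilding it from its `__reduce_ex__` tuple (the same call the unpickler makes) leaves both counters at zero, because the type is swapped in by `as_subclass` rather than by running the class constructor.

```python
import torch


class Counting(torch.Tensor):
    """Hypothetical subclass that records constructor calls."""

    new_calls = 0
    init_calls = 0

    def __new__(cls, *args, **kwargs):
        Counting.new_calls += 1
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs):
        Counting.init_calls += 1


# as_subclass swaps the Python type without calling __new__/__init__.
t = torch.arange(4.0).as_subclass(Counting)

# Do what the unpickler does with the reduce tuple: call the rebuild function.
rebuild_fn, rebuild_args = t.__reduce_ex__(2)
restored = rebuild_fn(*rebuild_args)

print(type(restored).__name__)                  # Counting
print(Counting.new_calls, Counting.init_calls)  # 0 0
```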

How do we check something is a tensor subclass/constraints around imports

To check whether a `GLOBAL module.name` opcode in the pickle refers to a tensor subclass, we need an `issubclass` check, which requires converting the global's string name into the corresponding type. We do not import modules on the user's behalf; the check is performed only if the module defining the subclass has already been imported by the user (i.e. `module in sys.modules and issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
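A minimal sketch of this lookup, using a hypothetical helper `resolve_imported_subclass` (not the function PyTorch uses). It refuses to import anything, resolves the name with `inspect.getattr_static`, and only then runs the `issubclass` check against a base class such as `torch.Tensor`. The demo uses `collections.OrderedDict` and `dict` as stand-ins so it runs without PyTorch.

```python
import collections  # stand-in for a user module that defines tensor subclasses
import inspect
import sys
from typing import Optional


def resolve_imported_subclass(module: str, name: str, base: type) -> Optional[type]:
    """Hypothetical helper: map a pickle GLOBAL `module.name` to a class, or None.

    Returns the class only if `module` was already imported by the user and
    `name` resolves to a subclass of `base` (e.g. base=torch.Tensor).
    """
    mod = sys.modules.get(module)
    if mod is None:
        # Never import on the user's behalf: importing runs module-level code.
        return None
    # getattr_static reads the module's __dict__ directly, so a module-level
    # __getattr__ (PEP 562) is not triggered.
    candidate = inspect.getattr_static(mod, name, None)
    if isinstance(candidate, type) and issubclass(candidate, base):
        return candidate
    return None


print(resolve_imported_subclass("collections", "OrderedDict", dict))     # <class 'collections.OrderedDict'>
print(resolve_imported_subclass("collections", "deque", dict))           # None
print(resolve_imported_subclass("some_unimported_module", "Foo", dict))  # None
```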

This PR also allowlists `torch._utils._rebuild_wrapper_subclass` and `torch.device` (which is used by `_rebuild_wrapper_subclass`).

API for allowlisting

This PR also adds `torch.serialization.mark_safe_globals`, which lets users allowlist globals they have deemed safe (for example, a tensor subclass with a custom `__setstate__`, once they have checked that it is safe).
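The idea behind a user allowlist can be sketched with the standard library's "restricting globals" pattern, which overrides `pickle.Unpickler.find_class`. This is a toy model, not `torch.serialization`'s implementation: `SAFE_GLOBALS`, `mark_safe_globals_sketch`, `AllowlistUnpickler`, and `restricted_loads` are hypothetical names, and `datetime.date` stands in for a user class.

```python
import datetime
import io
import pickle

# Hypothetical user allowlist, keyed by "module.qualname".
SAFE_GLOBALS = {}


def mark_safe_globals_sketch(globals_list):
    """Hypothetical stand-in: record classes/functions the user has deemed safe."""
    for obj in globals_list:
        SAFE_GLOBALS[f"{obj.__module__}.{obj.__qualname__}"] = obj


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Invoked for every global the pickle references; only allowlisted names resolve.
        key = f"{module}.{name}"
        if key in SAFE_GLOBALS:
            return SAFE_GLOBALS[key]
        raise pickle.UnpicklingError(f"global '{key}' is not allowlisted")


def restricted_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()


payload = pickle.dumps(datetime.date(2024, 5, 17))

try:
    restricted_loads(payload)
except pickle.UnpicklingError as err:
    print(err)  # global 'datetime.date' is not allowlisted

mark_safe_globals_sketch([datetime.date])
print(restricted_loads(payload))  # 2024-05-17
```

Resolution is by exact `module.name` key, so allowlisting one class does not implicitly admit anything else from the same module.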

Next steps:

  • Add testing and allowlist required classes for all in-core tensor subclasses (e.g. DTensor, FakeTensor etc.)

Stack from ghstack (oldest at bottom):

  • -> #124331

mikaylagawarecki, Apr 17 '24 22:04

Helpful Links

See artifacts and rendered test results at hud.pytorch.org/pr/124331

Note: Links to docs will display an error until the docs builds have been completed.

No Failures

As of commit e91a1de143a89b77aed03473f786194ea12955a4 with merge base 8619fe6214cd8f31345ae73c5b90024a0233dc40: Looks good so far! There are no failures yet.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot], Apr 17 '24 22:04

@pytorchbot merge

mikaylagawarecki, May 17 '24 14:05

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot, May 17 '24 14:05