Allow tensor subclasses and add `torch.serialization.mark_safe_globals` that allows users to allowlist classes for `weights_only` load
Conditions for allowlisting tensor subclasses
We allowlist tensor subclasses that
(1) Do not override `__setstate__`, `__getattr__`, `__setattr__`
(2) Use the generic `tp_alloc`
(3) Are in a module that has been imported by the user

Note that we use `inspect.getattr_static(sys.modules[module], name)` to get the class, as this method claims to have no code execution.
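To illustrate why a static lookup is preferred here, the following stdlib-only sketch compares `inspect.getattr_static` with a plain `getattr` on a module that defines a module-level `__getattr__` hook. The module name `demo_module` and the `Payload` class are hypothetical stand-ins, not part of this PR.

```python
import inspect
import sys
import types

calls = []  # records every invocation of the module-level __getattr__ hook

# Hypothetical module standing in for a user-imported module.
demo = types.ModuleType("demo_module")


class Payload:
    pass


def _module_getattr(name):
    # Arbitrary user code that a dynamic attribute lookup would execute.
    calls.append(name)
    raise AttributeError(name)


demo.Payload = Payload
demo.__getattr__ = _module_getattr
sys.modules["demo_module"] = demo

# Static lookups read the module's __dict__ directly and run no hooks.
found = inspect.getattr_static(sys.modules["demo_module"], "Payload")
missing = inspect.getattr_static(sys.modules["demo_module"], "Missing", None)
calls_after_static = list(calls)

# A regular getattr falls back to the module's __getattr__ hook.
dynamic = getattr(sys.modules["demo_module"], "Missing", None)
calls_after_dynamic = list(calls)
```

The static lookup finds `Payload` and reports `Missing` as absent without ever invoking the hook, whereas the plain `getattr` call runs it.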
The rationale for the first two conditions is as follows:
The rebuild func provided by `Tensor.__reduce_ex__` is `torch._tensor._rebuild_from_type_v2`, which is defined as follows (note the call to `getattr`, the call to `Tensor.__setstate__`, and the call to `as_subclass`, as well as the call to `_set_obj_state`, which calls `setattr`):
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/_tensor.py#L57-L71
`as_subclass` is implemented with a call to `THPVariable_NewWithVar` that will eventually call `tp_alloc` here:
https://github.com/pytorch/pytorch/blob/4e66aaa01092ddc8822bbca315b673329c76f4cd/torch/csrc/autograd/python_variable.cpp#L2053
The `func` arg to `_rebuild_from_type_v2` for wrapper subclasses is `torch._utils._rebuild_wrapper_subclass`, which will similarly call into `THPVariable_NewWithVar` and hit the above `tp_alloc`.
Note that we do not call `tp_init` or `tp_new` (i.e. `cls.__init__` or `cls.__new__`) when unpickling.
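Since the rebuild path only reaches user-defined code through `__setstate__`, `getattr` and `setattr`, condition (1) amounts to checking that a subclass leaves those hooks untouched. A minimal sketch of such a check, assuming a plain Python base class as a stand-in for `torch.Tensor` (the names `Base`, `Plain`, `Custom` and `overridden_hooks` are hypothetical and not taken from the PR):

```python
import inspect

HOOKS = ("__setstate__", "__getattr__", "__setattr__")


class Base:
    """Stand-in for torch.Tensor in this sketch."""

    def __setstate__(self, state):
        self.__dict__.update(state)


class Plain(Base):
    """Subclass that inherits every hook unchanged."""


class Custom(Base):
    """Subclass with its own __setstate__, which would run during unpickling."""

    def __setstate__(self, state):
        self.__dict__.update(state)


def overridden_hooks(cls, base):
    """Return the hook names whose definition on cls differs from base."""
    changed = []
    for name in HOOKS:
        # getattr_static avoids triggering descriptors or __getattr__.
        sub_attr = inspect.getattr_static(cls, name, None)
        base_attr = inspect.getattr_static(base, name, None)
        if sub_attr is not base_attr:
            changed.append(name)
    return changed


plain_result = overridden_hooks(Plain, Base)
custom_result = overridden_hooks(Custom, Base)
```

Under this check `Plain` would qualify for allowlisting, while `Custom` would be rejected because of its `__setstate__` override.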
How do we check something is a tensor subclass/constraints around imports
In order to check whether `bla` is a tensor subclass in the bytecode `GLOBAL module.name`, we need to do an `issubclass` check, which entails converting the global string to the appropriate type. We do not arbitrarily import modules, but will perform this check as long as the given subclass (given by `module.name`) has already been imported by the user (i.e. `module in sys.modules` and `issubclass(getattr(sys.modules[module], name), torch.Tensor)`).
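The check described above can be sketched with the standard library alone. `BaseTensor` is a stand-in for `torch.Tensor`, and the helper name `is_imported_subclass` plus the module name `demo_tensors` are hypothetical; this is not the code from the PR.

```python
import inspect
import sys
import types


class BaseTensor:
    """Stand-in for torch.Tensor in this sketch."""


def is_imported_subclass(module, name, base):
    """Resolve module.name only if module is already imported, then test it."""
    mod = sys.modules.get(module)
    if mod is None:
        # Never import a module on behalf of the pickle being loaded.
        return False
    # Static lookup: no module-level __getattr__ or descriptor code runs.
    obj = inspect.getattr_static(mod, name, None)
    return isinstance(obj, type) and issubclass(obj, base)


# Simulate a module the user has already imported.
demo = types.ModuleType("demo_tensors")


class MyTensor(BaseTensor):
    pass


class NotATensor:
    pass


demo.MyTensor = MyTensor
demo.NotATensor = NotATensor
sys.modules["demo_tensors"] = demo

accepted = is_imported_subclass("demo_tensors", "MyTensor", BaseTensor)
rejected_type = is_imported_subclass("demo_tensors", "NotATensor", BaseTensor)
rejected_module = is_imported_subclass("module_never_imported", "MyTensor", BaseTensor)
```

Only `MyTensor` passes: `NotATensor` fails the `issubclass` test, and the unknown module is rejected without being imported.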
This PR also allowlisted `torch._utils._rebuild_wrapper_subclass` and `torch.device` (used by `_rebuild_wrapper_subclass`).
API for allow listing
This PR also added `torch.serialization.mark_safe_globals`, which enables users to allowlist globals they have deemed safe (for example, they could allowlist a tensor subclass with a custom `__setstate__` if they have checked that this is safe).
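The allowlisting idea can be illustrated with Python's own `pickle` module, whose documentation describes restricting globals by overriding `Unpickler.find_class`. The sketch below is a stdlib analogue only, not PyTorch's `weights_only` unpickler; the names `mark_safe`, `AllowlistUnpickler` and `restricted_loads` are hypothetical.

```python
import io
import pickle
from fractions import Fraction

# Maps (module, qualified name) to the object the user has vouched for.
_safe_globals = {}


def mark_safe(objs):
    """Hypothetical helper: register globals the user has deemed safe."""
    for obj in objs:
        _safe_globals[(obj.__module__, obj.__qualname__)] = obj


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Resolve only allowlisted globals; never import anything here.
        try:
            return _safe_globals[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"global {module}.{name} is not allowlisted"
            ) from None


def restricted_loads(data):
    return AllowlistUnpickler(io.BytesIO(data)).load()


payload = pickle.dumps(Fraction(1, 2))

# Before allowlisting, the load is refused.
try:
    restricted_loads(payload)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# After the user vouches for Fraction, the same payload loads.
mark_safe([Fraction])
restored = restricted_loads(payload)
```

The design point mirrors the one made above: the loader stays restrictive by default, and the user takes responsibility for each global they add.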
Next steps:
- Add testing and allowlist required classes for all in-core tensor subclasses (e.g. `DTensor`, `FakeTensor`, etc.)
Stack from ghstack (oldest at bottom):
- -> #124331
:white_check_mark: No Failures
As of commit e91a1de143a89b77aed03473f786194ea12955a4 with merge base 8619fe6214cd8f31345ae73c5b90024a0233dc40:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
@pytorchbot merge
Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours).