
Handle __torch_function__ API properly

Open · vshampor opened this issue 5 years ago

In PyTorch 1.7.0 (at least), the framework allows users to subclass torch.Tensor and to define a dispatch method called __torch_function__ in order to perform custom behavior when a public PyTorch API function is called with one of its arguments being such a subclass. As is, this breaks our TracedTensor approach when handling non-NNCF-wrapped ops.
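For context, the mechanism looks roughly like the sketch below (the `LoggingTensor` name and the call log are illustrative, not part of NNCF): any public PyTorch function that receives an instance of the subclass is routed through the subclass's `__torch_function__` before the actual computation runs.

```python
import torch

calls = []  # illustrative: record which torch functions were dispatched

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # custom behavior: log the call, then delegate to the default impl,
        # which performs the computation and wraps the result back into cls
        calls.append(getattr(func, "__name__", str(func)))
        return super().__torch_function__(func, types, args, kwargs or {})

t = torch.ones(3).as_subclass(LoggingTensor)
out = torch.add(t, 1)  # dispatched through LoggingTensor.__torch_function__
```
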

Previously, when a TracedTensor was supplied as an input to a non-NNCF-wrapped operation, the operation would output a regular Tensor, which would be handled elsewhere in NNCF; the graph would become disjoint, but no crash would occur. Now PyTorch converts the type of the output of such an operation to TracedTensor, but it does not preserve the crucial attributes of the TracedTensor input (.tensor_meta), so the output object is not a proper TracedTensor and further processing of this output by NNCF crashes.
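The failure mode can be reproduced with a minimal stand-in for NNCF's TracedTensor (the `TracedTensorLike` class and its `tensor_meta` payload below are illustrative): PyTorch's default dispatch preserves the subclass type of the output but not the Python attributes attached to the input.

```python
import torch

class TracedTensorLike(torch.Tensor):
    # stand-in for NNCF's TracedTensor; no __torch_function__ override,
    # so PyTorch's default subclass-preserving dispatch applies
    pass

t = torch.ones(2, 2).as_subclass(TracedTensorLike)
t.tensor_meta = {"node_id": 0}  # stand-in for the crucial tracing attribute

# relu here plays the role of a non-NNCF-wrapped operation:
out = torch.relu(t)
# type(out) is TracedTensorLike (type converted by PyTorch), but
# out has no .tensor_meta attribute, so it is not a "proper" traced tensor
```
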

Possible solutions:

  1. Check the .tensor_meta attribute directly in NNCF processing instead of checking whether an operator input's type is equal to TracedTensor. This is dangerous due to possible attribute name conflicts.
  2. Waive support of the __torch_function__ API in NNCF by defining TracedTensor as follows, which also disables a user's own custom dispatch functions:

     ```python
     class TracedTensor(torch.Tensor):
         __torch_function__ = torch._C._disabled_torch_function_impl
         # ... the rest of the code is identical to the current implementation
     ```

  3. Switch NNCF to using the __torch_function__ API for tracing: define this dispatch function to perform all the tracing and hook-related functionality (which NNCF currently does by directly patching functions in the corresponding torch namespaces), but via the framework-recommended way (https://pytorch.org/docs/stable/notes/extending.html). A user's own __torch_function__ overrides could possibly be preserved by calling super().__torch_function__ inside TracedTensor.__torch_function__. This increases the dependency on the PyTorch framework. Also, control flow graphs that use NNCF-patched operations but do not receive a TracedTensor input will not appear in the graph at all (the current approach adds such operations to the graph, although they are not connected to the rest of it).
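A minimal sketch of option 3 (assumed shape only; NNCF's actual tracing and hook logic is elided): the override delegates to the default implementation via super(), then copies tensor_meta from the first traced input onto the output, so the output stays a proper TracedTensor.

```python
import torch

class TracedTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # ...NNCF's tracing/hook-related functionality would run here...
        out = super().__torch_function__(func, types, args, kwargs or {})
        # propagate tensor_meta from the first TracedTensor input so that
        # the output remains a "proper" TracedTensor (illustrative policy)
        if isinstance(out, TracedTensor):
            for a in args:
                meta = getattr(a, "tensor_meta", None)
                if meta is not None:
                    out.tensor_meta = meta
                    break
        return out

t = torch.ones(2).as_subclass(TracedTensor)
t.tensor_meta = {"node_id": 0}
out = torch.relu(t)
# out is a TracedTensor and still carries tensor_meta
```

A real implementation would also have to decide how to merge metadata when several traced inputs are present, and how to handle ops returning tuples of tensors; the isinstance guard above simply skips non-tensor results.
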

vshampor avatar Oct 29 '20 07:10 vshampor

Option 3 seems preferable, but it will likely require substantial effort, I guess. Anyway, I vote for it.

AlexKoff88 avatar Oct 29 '20 11:10 AlexKoff88

@vshampor, it's a very old issue without any follow-up. I suggest closing it. If you want to proceed with it, let's handle it internally.

MaximProshin avatar Jun 20 '23 08:06 MaximProshin