Recursion depth exceeded with custom `__getattr__` on `torch.nn.Module`
Bug description
The Python runtime raises a `RecursionError` when Fabric wraps a `torch.nn.Module` whose `__getattr__` is overridden.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
import torch
from typing import Union, Any

from tests_fabric.helpers.models import BoringFabric


def wrapped__getattr__(self, name: str) -> Union[torch.Tensor, torch.nn.Module]:
    # Delegate to the original lookup first, then run the instrumentation.
    result = self.original__get_attr__(name)
    try:
        print(f"Inside custom getattr on torch module for : {name}")
        self._custom_attr = "Custom"
    except Exception as e:
        print(f"Exception occurred: {e}")
    return result


# Patch __getattr__ globally, i.e. on every nn.Module subclass.
torch.nn.modules.Module.original__get_attr__ = torch.nn.modules.Module.__getattr__
torch.nn.modules.Module.__getattr__ = wrapped__getattr__


def test_wrapper():
    fabric = BoringFabric()
    fabric.expected_dtype = "bf16-mixed"
    fabric.run()
Error messages and logs
strategies/test_single_device.py::test_wrapper FAILED
=============================================================================== FAILURES ===============================================================================
_____________________________________________________________________________ test_wrapper _____________________________________________________________________________
def test_wrapper():
fabric = BoringFabric()
fabric.expected_dtype="bf16-mixed"
> fabric.run()
strategies/test_single_device.py:230:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../../venv/lib/python3.10/site-packages/lightning/fabric/fabric.py:925: in _wrap_and_launch
return to_run(*args, **kwargs)
../../../../venv/lib/python3.10/site-packages/lightning/fabric/fabric.py:930: in _wrap_with_setup
return to_run(*args, **kwargs)
helpers/models.py:56: in run
model = self.get_model()
helpers/models.py:36: in get_model
return nn.Linear(32, 2)
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:96: in __init__
self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1715: in __setattr__
self.register_parameter(name, value)
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:577: in register_parameter
elif hasattr(self, name) and name not in self._parameters:
strategies/test_single_device.py:214: in wrapped__getattr__
result = self.original__get_attr__(name)
strategies/test_single_device.py:173: in wrapped__getattr__
result = self.original__get_attr__(name)
strategies/test_single_device.py:173: in wrapped__getattr__
result = self.original__get_attr__(name)
E RecursionError: maximum recursion depth exceeded
!!! Recursion detected (same locals & position)
======================================================================= short test summary info ========================================================================
FAILED strategies/test_single_device.py::test_wrapper - RecursionError: maximum recursion depth exceeded
Environment
* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - lightning: 2.1.3
    - lightning-cloud: 0.5.57
    - lightning-habana: 1.3.0
    - lightning-utilities: 0.10.0
    - pytorch-lightning: 2.1.3
cc @carmocca @justusschock @awaelchli
@jyothisambolu This happens because of these two lines:
torch.nn.modules.Module.original__get_attr__ = torch.nn.modules.Module.__getattr__
torch.nn.modules.Module.__getattr__ = wrapped__getattr__
They override `__getattr__` on every module, including the FabricModule. Inside your custom `__getattr__`, the assignment `self._custom_attr = "Custom"` calls `__setattr__` on the FabricModule, which calls `__getattr__` again, and that creates the loop.
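The loop can be reproduced without torch or Lightning. Below is a minimal sketch, assuming a hypothetical stand-in class `Wrapper` (not Lightning's actual code) whose `__setattr__` reads a flag through `getattr`, combined with a `__getattr__` that assigns an attribute as a side effect. The names `Wrapper`, `_ready`, and `triggers_recursion` are made up for illustration.

```python
class Wrapper:
    """Hypothetical stand-in for a wrapper whose __setattr__ reads an attribute first."""

    def __setattr__(self, name, value):
        # "_ready" is never set, so this lookup falls through to __getattr__.
        getattr(self, "_ready", False)
        object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Side effect inside __getattr__: the assignment re-enters __setattr__,
        # which calls getattr again, which lands back here, and so on.
        self._custom_attr = "Custom"
        raise AttributeError(name)


def triggers_recursion() -> bool:
    """Return True if a plain attribute lookup ends in a RecursionError."""
    try:
        Wrapper().anything
    except RecursionError:
        return True
    except AttributeError:
        return False
    return False
```

`getattr` with a default only swallows `AttributeError`, so the `RecursionError` propagates all the way up, much like the traceback in the report.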
My suggestion is to avoid overriding `__getattr__` for all `nn.Module`s and to override it only on the models you care about. Example:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)
    ...

MyModel.original__get_attr__ = MyModel.__getattr__
MyModel.__getattr__ = wrapped__getattr__
Hey @jyothisambolu can you take a look at my reply?
Hi @awaelchli, thanks for the solution. It may work for model-specific customizations, but if we want to use a custom `__getattr__` across all modules (for module debugging, analysis, or customization), we will still hit the issue.
I don't know a solution to this at the moment. The implementations of `__getattr__` and `__setattr__` on the FabricModule are quite essential, and unfortunately I don't know how to change them to support your use case.
@jyothisambolu I think the right way to go here is to check `isinstance(self, nn.Module) and not isinstance(self, FabricModule)` inside your override: if the check passes, apply the instrumentation; otherwise, fall back to the standard `__getattr__`.
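A rough sketch of that check, using pure-Python stand-ins so it runs without torch installed. `Module` and `FabricModule` here are hypothetical placeholders for `torch.nn.Module` and Fabric's wrapper class (which, as far as I know, is exposed as `_FabricModule` in `lightning.fabric.wrappers`; please verify against your installed version). The `_initialized` flag is likewise made up for illustration.

```python
class Module:
    """Stand-in for torch.nn.Module (assumption: only what this sketch needs)."""

    def __getattr__(self, name):
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")


class FabricModule(Module):
    """Stand-in for Fabric's wrapper: its __setattr__ reads a flag via getattr."""

    def __setattr__(self, name, value):
        # The flag is never set here, so this lookup goes through __getattr__.
        getattr(self, "_initialized", False)
        object.__setattr__(self, name, value)


# Keep a reference to the original lookup before patching.
original_getattr = Module.__getattr__


def guarded_getattr(self, name):
    # Instrument plain modules only; leave the wrapper type untouched so its
    # __setattr__/__getattr__ pair cannot re-enter the instrumentation.
    if isinstance(self, Module) and not isinstance(self, FabricModule):
        self._custom_attr = "Custom"
    return original_getattr(self, name)


# Patch globally, as in the original report, but with the guard in place.
Module.__getattr__ = guarded_getattr
```

With the guard in place, a lookup on a plain module still runs the instrumentation, while a lookup or assignment on the wrapper goes straight to the original `__getattr__` and no longer recurses. Whether this is enough for a real Fabric run depends on what else the instrumentation does, so treat it as a starting point.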