rl
rl copied to clipboard
[BUG] torch geometric layers not working in policy network
Describe the bug
I am trying to integrate PyTorch Geometric (`torch_geometric`) layers into my policy network, and I think I am running into a variant of https://github.com/pytorch/rl/issues/1613.
To Reproduce
Include any PyTorch Geometric layer (I used `torch_geometric.nn.Linear`) in a policy and try to collect data from it with a `SyncDataCollector`. The error is raised while the collector is being constructed, before any data is collected.
from torch_geometric.nn import Linear
from torch import nn
class GNNMessageNet(nn.Module):
    """Minimal policy module wrapping a single PyTorch Geometric layer.

    Used to reproduce the failure where ``SyncDataCollector`` cannot
    deep-copy a policy containing ``torch_geometric.nn.Linear``: the
    layer's ``__deepcopy__`` assigns the copied weight back via
    ``out.weight = ...``, which ``nn.Module.__setattr__`` rejects when
    the copy is a plain tensor instead of an ``nn.Parameter``.
    """

    def __init__(self):
        # Equivalent to the original explicit ``nn.Module.__init__(self)``
        # for single inheritance, but cooperates with the MRO.
        super().__init__()
        # torch_geometric.nn.Linear(in_channels, out_channels): 10 -> 32.
        self.conv = Linear(10, 32)

    def forward(self, x):
        """Apply the wrapped layer to ``x``.

        Without a ``forward`` the module would raise ``NotImplementedError``
        when called, so the collector could never actually run it.
        Assumes the last dimension of ``x`` is 10 (the layer's
        ``in_channels``) — TODO confirm against the environment's specs.
        """
        return self.conv(x)
Traceback (most recent call last):
File "/home/rerz/.local/share/JetBrains/IntelliJIdea2024.2/python-ce/helpers/pydev/pydevd.py", line 1570, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rerz/.local/share/JetBrains/IntelliJIdea2024.2/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "<project>/packages/gridworks/src/gridworks/task/train.py", line 101, in <module>
collector = SyncDataCollector(
^^^^^^^^^^^^^^^^^^
File "<project>/.venv/lib/python3.12/site-packages/torchrl/collectors/collectors.py", line 654, in __init__
(self.policy, self.get_weights_fn,) = self._get_policy_and_device(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project>/.venv/lib/python3.12/site-packages/torchrl/collectors/collectors.py", line 236, in _get_policy_and_device
policy = deepcopy(policy)
^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 162, in deepcopy
y = _reconstruct(x, memo, *rv)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 259, in _reconstruct
state = deepcopy(state, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
y = copier(x, memo)
^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/copy.py", line 143, in deepcopy
y = copier(memo)
^^^^^^^^^^^^
File "<project>/.venv/lib/python3.12/site-packages/torch_geometric/nn/dense/linear.py", line 129, in __deepcopy__
out.weight = copy.deepcopy(self.weight, memo)
^^^^^^^^^^
File "<project>/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1959, in __setattr__
raise TypeError(
TypeError: cannot assign 'torch.meta.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
File ...
Expected behavior
I can use PyTorch Geometric layers in my policy just like any other PyTorch layer.
System info
Describe the characteristics of your environment:
- Python 3.12
- torchrl/tensordict main branch
- torch_geometric 2.6.1
Additional context
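The issue seems to occur while `SyncDataCollector` deep-copies the policy (`_get_policy_and_device` calls `deepcopy(policy)`). Inside `torch_geometric.nn.Linear.__deepcopy__`, the line `out.weight = copy.deepcopy(self.weight, memo)` receives a plain `torch.meta.FloatTensor` rather than a `torch.nn.Parameter`, so `nn.Module.__setattr__` raises a `TypeError`.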
Checklist
- [x] I have checked that there is no similar issue in the repo (required)
- [x] I have read the documentation (required)
- [x] I have provided a minimal working example to reproduce the bug (required)