torchexplorer
Add support for module input of type list
Wonderful project! I tried some baseline models and it worked well. However, it does not seem to support modules that take a single input argument of type list. For instance:
def forward(self, feats: list[Tensor]):
    assert len(feats) == len(self.in_channels)
    proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
    ...
When I pass this model to torchexplorer, I get the following error:
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/hook/hook.py", line 329, in process_tensor
    return tensor + dummy_tensor if torch.is_floating_point(tensor) else tensor
TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not list
I tried modifying the hooks, but that did not work.
Any suggestions? @spfrommer
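The `TypeError` arises because the hook hands the module's argument straight to `torch.is_floating_point`, which accepts only a tensor. As a minimal sketch of one possible direction (not torchexplorer's code: `process_input` is a hypothetical helper, and `process_tensor` here only mirrors the failing line from `hook.py`), a hook could recurse into lists and tuples and apply the per-tensor logic at the leaves:

```python
import torch


def process_tensor(tensor: torch.Tensor, dummy: torch.Tensor) -> torch.Tensor:
    # Mirrors the failing line in hook.py: only floating-point tensors
    # get the dummy tensor added; other dtypes pass through unchanged.
    return tensor + dummy if torch.is_floating_point(tensor) else tensor


def process_input(value, dummy: torch.Tensor):
    # Hypothetical helper: apply process_tensor to every tensor inside
    # (possibly nested) lists and tuples; leave other values untouched.
    if isinstance(value, torch.Tensor):
        return process_tensor(value, dummy)
    if isinstance(value, list):
        return [process_input(v, dummy) for v in value]
    if isinstance(value, tuple):
        # Note: namedtuples come back as plain tuples here.
        return tuple(process_input(v, dummy) for v in value)
    return value


dummy = torch.zeros(1, requires_grad=True)
feats = [torch.ones(2), torch.arange(3)]
out = process_input(feats, dummy)
# out[0] now requires grad; out[1] (an integer tensor) is returned as-is
```

This would only get past the `TypeError`. Per the maintainer's reply, torchexplorer currently expects a consistent count and shape of input and output tensors, so proper list support likely involves more than this.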
By the way, if I split the list into separate arguments:
def forward(self, feat1, feat2, feat3):
    feats = [feat1, feat2, feat3]
    proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
    ...
I get the following AssertionError:
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 166, in _inner_recurse
    assert metadata['module'] == current_module
AssertionError
^C
Exception ignored in: <module 'threading' from '/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py'>
Traceback (most recent call last):
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py", line 1567, in _shutdown
    lock.acquire()
KeyboardInterrupt:
P.S. This error disappears when I remove the transformer encoder.
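The repeated `_inner_recurse` frames in the traceback come from torchexplorer recursing over each autograd node's `next_functions`. For readers unfamiliar with that structure, here is a minimal, independent sketch of such a walk using only public PyTorch attributes (`grad_fn`, `next_functions`). It is not torchexplorer's implementation, and `collect_nodes` is a hypothetical name; it does not model the `metadata['module']` check that fails in the traceback.

```python
import torch


def collect_nodes(grad_fn, seen=None):
    # Depth-first walk over the autograd graph starting at grad_fn.
    # Each entry of next_functions is a (node, input_index) pair, and
    # node is None for inputs that do not require grad.
    if seen is None:
        seen = []
    if grad_fn is None or any(node is grad_fn for node in seen):
        return seen
    seen.append(grad_fn)
    for next_fn, _ in grad_fn.next_functions:
        collect_nodes(next_fn, seen)
    return seen


a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = (a * b + a).sum()

nodes = collect_nodes(y.grad_fn)
names = [type(node).__name__ for node in nodes]
print(names)
```

Because `a` feeds both the multiplication and the addition, its gradient-accumulator node is reached twice but recorded once; the identity check is what prevents that double count.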
Modules that take a list as an argument aren't currently supported; right now, the number and shapes of input and output tensors are expected to be consistent.
The second example you gave should work, though. Would you mind putting together a minimal working example (MWE)?
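Along the lines of the split-argument version above, a thin wrapper can expose a list-taking module as one that takes tensors positionally, without editing its `forward`. A minimal sketch, where `ToyEncoder` and `ListInputAdapter` are hypothetical names and the 1x1 convolutions are an assumed stand-in for `input_proj`:

```python
import torch
from torch import Tensor, nn


class ToyEncoder(nn.Module):
    # Hypothetical stand-in for a module whose forward takes a list of tensors.
    def __init__(self, in_channels: list[int], hidden: int = 8):
        super().__init__()
        self.in_channels = in_channels
        self.input_proj = nn.ModuleList(
            nn.Conv2d(c, hidden, kernel_size=1) for c in in_channels
        )

    def forward(self, feats: list[Tensor]) -> list[Tensor]:
        assert len(feats) == len(self.in_channels)
        return [self.input_proj[i](feat) for i, feat in enumerate(feats)]


class ListInputAdapter(nn.Module):
    # Hypothetical wrapper: accepts tensors as separate positional
    # arguments and repacks them into the list the inner module expects.
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *feats: Tensor):
        return self.inner(list(feats))


model = ListInputAdapter(ToyEncoder([16, 32, 64]))
x = [torch.randn(1, c, 8, 8) for c in (16, 32, 64)]
outs = model(*x)
```

The inner module still returns a list, and this does nothing about the transformer-encoder `AssertionError`; it only avoids rewriting the original `forward` signature.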