
Add support for module input of type list

Open jacksonsc007 opened this issue 10 months ago • 2 comments

Wonderful project! I tried some baseline models and it worked well. However, it seems that torchexplorer does not support modules that take a single input argument of type list. For instance:

    def forward(self, feats: list[Tensor]):
        assert len(feats) == len(self.in_channels)
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        ....
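For context, a stripped-down module with the same list-input signature reproduces the pattern (a hypothetical sketch: the layer names mirror the snippet above, but the channel counts and shapes are made up):

```python
import torch
from torch import nn, Tensor

class ListInputModel(nn.Module):
    """Minimal module whose forward takes a single list of tensors."""

    def __init__(self, in_channels: list[int], hidden: int = 8):
        super().__init__()
        self.in_channels = in_channels
        # One 1x1 projection per input feature map, as in the snippet above.
        self.input_proj = nn.ModuleList(
            nn.Conv2d(c, hidden, kernel_size=1) for c in in_channels
        )

    def forward(self, feats: list[Tensor]) -> list[Tensor]:
        assert len(feats) == len(self.in_channels)
        return [self.input_proj[i](feat) for i, feat in enumerate(feats)]

model = ListInputModel([4, 8])
feats = [torch.randn(1, 4, 16, 16), torch.randn(1, 8, 16, 16)]
outs = model(feats)
print([o.shape for o in outs])
# -> [torch.Size([1, 8, 16, 16]), torch.Size([1, 8, 16, 16])]
```

A plain forward pass like this runs fine; the error below only appears once torchexplorer's hooks are attached.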

And when I pass this model to torchexplorer, I get the following error:

  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/hook/hook.py", line 329, in process_tensor
    return tensor + dummy_tensor if torch.is_floating_point(tensor) else tensor
TypeError: is_floating_point(): argument 'input' (position 1) must be Tensor, not list

I tried to modify the hooks, but it did not work out.
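One direction I looked at (a hypothetical sketch, not a working patch, since the real fix would have to live inside torchexplorer's hook code) is to make the tensor-processing step recurse into lists and tuples before applying the per-tensor logic:

```python
def process_value(value, process_tensor):
    """Apply process_tensor to a value that may be a tensor or a
    (possibly nested) list/tuple of tensors."""
    if isinstance(value, (list, tuple)):
        # Rebuild the same container type with each element processed.
        return type(value)(process_value(v, process_tensor) for v in value)
    # Leaf case: hand the single value to the original per-tensor logic.
    return process_tensor(value)

# Toy stand-in for the per-tensor processing done in hook.py.
double = lambda x: x * 2
print(process_value([1, 2, [3, 4]], double))
# -> [2, 4, [6, 8]]
```

This only fixes the `is_floating_point` crash at the leaf level; the structure-tracing code would presumably also need to understand list inputs.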

Any suggestions? @spfrommer

jacksonsc007 avatar Dec 25 '24 06:12 jacksonsc007

Btw, if I split the list into separate arguments:

    def forward(self, feat1, feat2, feat3):
        feats = [feat1, feat2, feat3]
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        ....

I get the following assertion error:

  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 200, in _inner_recurse
    upstreams = _flatten([
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 201, in <listcomp>
    self._inner_recurse(current_struct, n[0], n[1]) for n in next_functions
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/site-packages/torchexplorer/structure/structure.py", line 166, in _inner_recurse
    assert metadata['module'] == current_module
AssertionError
^CException ignored in: <module 'threading' from '/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py'>
Traceback (most recent call last):
  File "/root/miniconda3/envs/rtdetr/lib/python3.10/threading.py", line 1567, in _shutdown
    lock.acquire()
KeyboardInterrupt: 

P.S. This error disappears after I remove the transformer encoder.

jacksonsc007 avatar Dec 25 '24 07:12 jacksonsc007

Modules accepting a list of tensors aren't currently supported -- right now, the input and output tensor shapes/counts are expected to be consistent.

The second example you gave should work, though. Would you mind putting together a minimal working example (MWE)?

spfrommer avatar Dec 26 '24 13:12 spfrommer