Torchscript compatibility

Open ssandler-cat opened this issue 1 year ago • 7 comments

While making the PyTorch TAPIR model compatible with TorchScript tracing is easy by changing TAPIR.forward() in https://github.com/google-deepmind/tapnet/blob/main/torch/tapir_model.py#L196-L209 from

    out = dict(
        occlusion=torch.mean(
            torch.stack(trajectories['occlusion'][p::p]), dim=0
        ),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(
            torch.stack(trajectories['expected_dist'][p::p]), dim=0
        ),
        unrefined_occlusion=trajectories['occlusion'][:-1],
        unrefined_tracks=trajectories['tracks'][:-1],
        unrefined_expected_dist=trajectories['expected_dist'][:-1],
    )

    return out

to

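    # Defined at module level (requires `from typing import NamedTuple`).
    # torch.Tensor is the tensor type; torch.tensor is a factory function.
    class Output(NamedTuple):
        occlusion: torch.Tensor
        tracks: torch.Tensor
        expected_dist: torch.Tensor

    # At the end of TAPIR.forward():
    out = Output(
        occlusion=torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(torch.stack(trajectories['expected_dist'][p::p]), dim=0),
    )

    return out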

(assuming it is OK to drop the unrefined_* entries from the output), so that

    import torch
    from tapnet.torch import tapir_model

    model = tapir_model.TAPIR(pyramid_level=1)
    model.load_state_dict(torch.load('bootstapir_checkpoint.pt'))
    model = model.to(torch.device('cpu'))
    model.eval()
    # Dummy inputs: one video of 32 RGB frames at 256x256, and 20 query points.
    dummy_input_frames = torch.randn(1, 32, 256, 256, 3, dtype=torch.float32, device=torch.device('cpu'))
    dummy_input_query_points = torch.randn(1, 20, 3, dtype=torch.float32, device=torch.device('cpu'))
    scriptModule = torch.jit.trace(model, (dummy_input_frames, dummy_input_query_points))
    torch.jit.save(scriptModule, 'bootstapir_checkpoint.ptc')

succeeds, it is not as easy to make it compatible with TorchScript scripting.
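The output-structure change can be tried in isolation. The sketch below is hypothetical (`ToyTracker`, its shapes, and its fake refinement step are invented for illustration, not TAPIR code): it averages every `p`-th refinement in the style of the `[p::p]` slicing and returns a NamedTuple annotated with the `torch.Tensor` type, which tracing handles like an ordinary tuple.

```python
from typing import NamedTuple

import torch
from torch import nn


class Output(NamedTuple):
    occlusion: torch.Tensor
    tracks: torch.Tensor


class ToyTracker(nn.Module):
    """Hypothetical stand-in for TAPIR that 'refines' predictions a few times."""

    def __init__(self, num_iters: int = 4, p: int = 2):
        super().__init__()
        self.num_iters = num_iters
        self.p = p

    def forward(self, query_points: torch.Tensor) -> Output:
        occlusions = []
        tracks = []
        x = query_points
        for _ in range(self.num_iters):
            x = x + 1.0  # placeholder for one refinement step
            occlusions.append(x[..., 0])
            tracks.append(x[..., 1:])
        p = self.p
        # Average every p-th refinement, mirroring the [p::p] slicing.
        return Output(
            occlusion=torch.mean(torch.stack(occlusions[p::p]), dim=0),
            tracks=torch.mean(torch.stack(tracks[p::p]), dim=0),
        )


model = ToyTracker().eval()
example = torch.zeros(1, 20, 3)
traced = torch.jit.trace(model, (example,))
out = traced(example)
print(out[0].shape, out[1].shape)  # torch.Size([1, 20]) torch.Size([1, 20, 2])
```

Indexing the traced output by position works whether or not the field names survive tracing.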

    scriptModule = torch.jit.script(model)

fails with

Module 'BlockV2' has no attribute 'proj_conv' :
  File "C:\tapnet\tapnet\torch\nets.py", line 278
    x = torch.relu(x)
    if self.use_projection:
      shortcut = self.proj_conv(x)
                 ~~~~~~~~~~~~~~ <--- HERE

How can the model be made compatible with TorchScript scripting?

ssandler-cat, Mar 01 '24 04:03

It seems possible to get past the error above by adding an else clause to BlockV2.__init__ after

    if self.use_projection:
      self.proj_conv = nn.Conv2d(
          in_channels=channels_in,
          out_channels=channels_out,
          kernel_size=1,
          stride=stride,
          padding=0,
          bias=False,
      )

in https://github.com/google-deepmind/tapnet/blob/main/torch/nets.py#L225-L233, so it looks like

    if self.use_projection:
      self.proj_conv = nn.Conv2d(...)
    else:
      self.proj_conv = DummyModel()

where DummyModel is a placeholder class:

    class DummyModel:
        # Placeholder that ignores its input and returns a zero tensor.
        def __init__(self):
            pass

        def forward(self):
            return torch.tensor(0)

        def __call__(self, input):
            return self.forward()
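A possible alternative to a hand-written placeholder is `nn.Identity`, a standard PyTorch module whose `forward(x)` returns `x` unchanged. Because it is an `nn.Module` with a tensor-in, tensor-out `forward`, TorchScript can type-check `self.proj_conv(x)` in both branches of `if self.use_projection:`. This is a minimal sketch under that assumption; `ToyBlock` is hypothetical and simplified, not the real `BlockV2`.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical pre-activation residual block with an optional projection."""

    def __init__(self, channels_in: int, channels_out: int, stride: int = 1):
        super().__init__()
        self.use_projection = channels_in != channels_out or stride != 1
        if self.use_projection:
            self.proj_conv = nn.Conv2d(
                channels_in, channels_out, kernel_size=1, stride=stride,
                padding=0, bias=False,
            )
        else:
            # Always define the attribute so TorchScript can compile
            # `self.proj_conv(x)`, even in the branch that never runs.
            self.proj_conv = nn.Identity()
        self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=3,
                              stride=stride, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = torch.relu(x)
        if self.use_projection:
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


for c_in, c_out in [(8, 8), (8, 16)]:
    block = ToyBlock(c_in, c_out).eval()
    scripted = torch.jit.script(block)
    x = torch.randn(1, c_in, 16, 16)
    print(c_in, c_out, torch.allclose(block(x), scripted(x)))
```

Another pattern, used for example by torchvision's ResNet blocks, is to store `None` for the missing submodule and guard the call with `if self.downsample is not None:`, which TorchScript can refine statically.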

But then torch.jit.script(model) fails with

Arguments for call are not valid.
The following variants are available:
  
  aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
  Keyword argument axis unknown.
  
  aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
  Argument dim not provided.
  
  aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
  Argument dim not provided.
  
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
  Argument out not provided.

The original call is:
  File "C:\tapnet\tapnet\torch\nets.py", line 61
    prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
    next_frame = torch.cat([x[1:], x[-1:]], dim=0)
    resid = torch.cat([x, prev_frame, next_frame], axis=1) 
            ~~~~~~~~~ <--- HERE

This can be resolved by replacing resid = torch.cat([x, prev_frame, next_frame], axis=1) with resid = torch.cat([x, prev_frame, next_frame], dim=1). I'd like to know why the original line does not raise an "unexpected keyword argument 'axis'" error in eager mode. The next error is the following:

Unknown type constructor Mapping:
  File "C:\tapnet\tapnet\torch\tapir_model.py", line 145
      get_query_feats: bool = False,
      refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
  ) -> Mapping[str, torch.Tensor]:
       ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
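Both fixes can be sketched on a small scripted function. A likely explanation for the `axis` question: in eager mode, PyTorch's Python argument parser accepts a few NumPy-style aliases, including `axis` for `dim`, so `torch.cat(..., axis=1)` runs; TorchScript instead matches calls against the `aten::cat` schemas listed in the error, which only name `dim`. Similarly, TorchScript understands `typing.Dict` but not `typing.Mapping`, so annotating the return type as `Dict[str, torch.Tensor]` is one way around the last error. `temporal_resid` is a hypothetical helper modeled on the nets.py snippet, not the repository code.

```python
from typing import Dict

import torch


@torch.jit.script
def temporal_resid(x: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Neighbouring frames along dim 0, repeating the first/last frame at the edges.
    prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
    next_frame = torch.cat([x[1:], x[-1:]], dim=0)
    # Use `dim=`; TorchScript does not accept the `axis=` alias.
    resid = torch.cat([x, prev_frame, next_frame], dim=1)
    return {"resid": resid}


x = torch.randn(5, 4, 8, 8)
out = temporal_resid(x)
print(out["resid"].shape)  # torch.Size([5, 12, 8, 8])

# Eager mode accepts the NumPy-style alias and produces the same tensor.
eager = torch.cat([x, torch.cat([x[0:1], x[:-1]], dim=0),
                   torch.cat([x[1:], x[-1:]], dim=0)], axis=1)
print(torch.equal(out["resid"], eager))  # True
```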

ssandler-cat, Mar 01 '24 06:03

Hi,

Thanks for raising the issue. Prior to release we were also able to trace the model using the same method you described, but in testing it gave very little performance improvement. Can I ask what the use case for scripting is here? Thanks

sgjheywa, Mar 06 '24 16:03

@sgjheywa, scripting (torch.jit.script) makes it possible to save a model with dynamic dimensions, whereas tracing records only the operations executed for the example inputs, so it effectively supports only static dimensions. Achieving JIT compatibility required many code changes; please review https://github.com/google-deepmind/tapnet/pull/85.
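A toy illustration of the difference, independent of TAPIR (`pad_to_even` is hypothetical): tracing evaluates shape-dependent Python control flow once and records only the branch taken for the example input, while scripting compiles the `if` itself.

```python
import torch


def pad_to_even(x: torch.Tensor) -> torch.Tensor:
    # Shape-dependent control flow: repeat the last element only if the length is odd.
    if x.shape[0] % 2 == 1:
        return torch.cat([x, x[-1:]], dim=0)
    return x


# Tracing with an odd-length example freezes the "pad" branch (a TracerWarning may be printed).
traced = torch.jit.trace(pad_to_even, (torch.ones(3),))
# Scripting keeps the branch, so it is re-evaluated for every input.
scripted = torch.jit.script(pad_to_even)

even = torch.ones(4)
print(traced(even).shape)    # torch.Size([5]): padded even though the length was already even
print(scripted(even).shape)  # torch.Size([4])
```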

ssandler-cat, Mar 08 '24 01:03

Sorry, I am familiar with scripting; I'm just trying to figure out the use case here. Since the model is compatible with torch.compile, scripting seems unnecessary. Thanks

sgjheywa, Mar 12 '24 11:03

@sgjheywa, the use case is LibTorch integration in C++. The model can be compiled with torch.compile, but that does not help here, because a torch.compile'd model cannot be saved with torch.jit.save. Am I missing something? Thank you.
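For the LibTorch use case, the usual round trip is to serialize a scripted module with `torch.jit.save` and load it in C++ with `torch::jit::load`. The sketch below checks such a round trip in Python, using `torch.jit.load` as a stand-in for the C++ loader; `TinyHead` and the file name are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


class TinyHead(nn.Module):
    """Hypothetical module standing in for a scripted TAPIR."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        return self.linear(query_points)


model = TinyHead().eval()
scripted = torch.jit.script(model)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_head.pt")
    torch.jit.save(scripted, path)
    # In C++ (LibTorch), the equivalent call is torch::jit::load(path).
    loaded = torch.jit.load(path)

    # No example inputs were needed, so any number of query points works.
    for n in (5, 20):
        q = torch.randn(1, n, 3)
        print(n, torch.allclose(model(q), loaded(q)))
```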

ssandler-cat, Mar 12 '24 23:03

Hello! May I ask whether you have implemented model training for the Python version of TAPIR?

pubyLu, May 20 '24 08:05