
Any plan to release TAPNext training code?

Open YJ-142150 opened this issue 8 months ago • 17 comments

Thank you for your recent work on TAPNext! I want to fine-tune TAPNext, but the training code does not seem to be released yet. When do you plan to release it?

Also, when will the TrecViT code be released?

YJ-142150 • Apr 15 '25

Hi @YJ-142150!

We don't plan to release the training code. Instead, we have released PyTorch code that computes the TAPNext loss and its gradient. The code is here:

https://github.com/google-deepmind/tapnet/blob/main/tapnet/tapnext/torch_losses.py
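For readers who want something to experiment with alongside `torch_losses.py`, here is a minimal, hypothetical sketch of a point-tracking loss and its gradient in PyTorch. It is not the contents of that file: the function name `tracking_loss`, the tensor layouts, and the `huber_delta` value are assumptions. It combines a Huber loss on the coordinates of visible points with a binary cross-entropy on visibility logits, a common combination in point-tracking models.

```python
import torch
import torch.nn.functional as F


def tracking_loss(pred_tracks, pred_vis_logits, gt_tracks, gt_visible, huber_delta=4.0):
    """Hypothetical tracking loss (not the repo's API).

    pred_tracks, gt_tracks: [b, t, q, 2] pixel coordinates.
    pred_vis_logits: [b, t, q] visibility logits.
    gt_visible: [b, t, q] boolean visibility labels.
    huber_delta: assumed threshold (in pixels) between quadratic and linear regimes.
    """
    vis = gt_visible.float()
    # Per-point Huber loss, summed over (x, y); only visible points contribute.
    coord = F.huber_loss(pred_tracks, gt_tracks, reduction="none", delta=huber_delta).sum(-1)
    coord_loss = (coord * vis).sum() / vis.sum().clamp(min=1.0)
    # Visibility is supervised everywhere, including occluded points.
    vis_loss = F.binary_cross_entropy_with_logits(pred_vis_logits, vis)
    return coord_loss + vis_loss


# Usage: compute the loss and backpropagate to get gradients w.r.t. the predictions.
torch.manual_seed(0)
gt = torch.rand(2, 4, 3, 2) * 128
visible = torch.rand(2, 4, 3) > 0.3
pred = (gt + torch.randn_like(gt)).requires_grad_()
logits = torch.zeros(2, 4, 3, requires_grad=True)
loss = tracking_loss(pred, logits, gt, visible)
loss.backward()  # pred.grad and logits.grad are now populated
```

Because the coordinate term is multiplied by the visibility mask, occluded points receive no coordinate gradient, while their visibility logits are still trained.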

artemZholus • Apr 21 '25

@artemZholus Do you plan to fine-tune on sequences longer than 48 frames?

bhack • May 04 '25

Hi @YJ-142150, the TrecViT code is now released at https://github.com/google-deepmind/trecvit; please take a look. Thanks.

yangyi02 • May 05 '25

@yangyi02 Thank you!

YJ-142150 • May 07 '25

Hi @bhack, I plan to work on this, but not as part of Google. This will likely take a couple of weeks.

artemZholus • May 07 '25

Thanks @artemZholus! I'd be curious to see a JAX fine-tuning example too.

jscholz • May 08 '25

Hello, and thanks for sharing this great project!

I’m trying to run the TAPNext PyTorch model, but I hit the following error with batch size 2 (batch size 1 works fine). Is there anything I should change in the code?

```python
import torch
from tapnet.tapnext.tapnext_torch import TAPNext

image_size = (128, 128)
device = "cuda:0"
model = TAPNext(
    image_size=image_size,
    patch_size=(8, 8),
    width=768,
    depth=12,
    num_heads=12,
    lru_width=768,
    use_checkpointing=False
).to(device)
query_points = torch.randn(2, 10, 3).to(device)  # [batch, num_queries, 3]
video = torch.randn(2, 16, 128, 128, 3).to(device)  # [batch, frames, height, width, channels]
model(video, query_points)  # fails with batch size 2, works with batch size 1
```
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 18
     16 query_points = torch.randn(2, 10, 3).to(device)
     17 video = torch.randn(2, 16, 128, 128, 3).to(device)
---> 18 model(video, query_points)

File ~/miniconda3/envs/tapvid/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/tapvid/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/work/tapnet/tapnet/tapnext/tapnext_torch.py:273, in TAPNext.forward(self, video, query_points, state)
    271   step = 0
    272 print('t, query_points.shape', t, query_points.shape)
--> 273 point_tokens = self.embed_queries(t, query_points)  # [b t Q c]
    274 x = torch.cat([video_tokens, point_tokens], dim=2)  # [b t (h * w + Q) c]
    275 ssm_cache = []

File ~/work/tapnet/tapnet/tapnext/tapnext_torch.py:208, in TAPNext.embed_queries(self, timesteps, query_points)
    206   queries_are_late = query_timesteps >= t
    207   queries_are_early = query_timesteps < 0
--> 208   mask_and_query_tokens = mask_tokens.scatter(
    209       dim=1,
    210       index=query_timesteps.long().clamp(0, t - 1).repeat(1, 1, 1, c),
    211       src=point_query_tokens,
    212   )
    213   mask_and_query_tokens = torch.where(
    214       (queries_are_late | queries_are_early).unsqueeze(1),
    215       mask_tokens,
    216       mask_and_query_tokens,
    217   )
    218 is_unknown_token = torch.arange(t, device=query_points.device)[
    219     None, :, None, None
    220 ] < query_timesteps.unsqueeze(1)

RuntimeError: Expected index [1, 2, 10, 768] to be no larger than self [2, 16, 10, 768] apart from dimension 1 and to be no larger size than src [2, 1, 10, 768]
```

Any suggestions on how to fix this would be greatly appreciated. Thanks!

whikwon • May 16 '25

Hi @whikwon, this is a known issue and it has already been fixed. Pull the latest version of the code and it should work with batch size > 1.
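A plausible reading of the traceback, for anyone hitting the same error on an older checkout: `query_timesteps` appears to have shape `[b, Q, 1]`, and calling `.repeat(1, 1, 1, c)` on a 3-D tensor treats it as if a new leading axis had been added, giving an index of shape `[1, b, Q, c]` instead of `[b, 1, Q, c]`. With `b = 1` the two layouts coincide, which would explain why batch size 1 works. The toy sketch below reproduces the shape mismatch with small sizes; the inferred shape and the `unsqueeze(1)` fix are illustrative assumptions, not necessarily the upstream patch.

```python
import torch

b, t, q, c = 2, 16, 10, 8  # small channel count for illustration
mask_tokens = torch.zeros(b, t, q, c)
point_query_tokens = torch.ones(b, 1, q, c)
# Assumed layout, inferred from the traceback: one query time per point, [b, q, 1].
query_timesteps = torch.tensor([[[3.0]] * q, [[5.0]] * q])

# Buggy index: repeat() with more sizes than dims prepends an axis -> [1, b, q, c].
bad_index = query_timesteps.long().clamp(0, t - 1).repeat(1, 1, 1, c)
print(bad_index.shape)  # torch.Size([1, 2, 10, 8])

# Illustrative fix: insert the time axis explicitly -> [b, 1, q, c].
good_index = query_timesteps.long().clamp(0, t - 1).unsqueeze(1).repeat(1, 1, 1, c)
# scatter along time: out[i, index[i, 0, k, l], k, l] = src[i, 0, k, l]
out = mask_tokens.scatter(dim=1, index=good_index, src=point_query_tokens)
```

Each query token is written into its own batch element at that query's time step, which is what the batched layout needs.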

artemZholus • May 16 '25

@artemZholus Ah, sorry to bother you. Thank you for the quick response!

whikwon • May 16 '25

> Hi @bhack, I plan to work on this, but not as part of Google. This will likely take a couple of weeks.

@artemZholus Any news on this? I've tried a few workarounds based on state and query management, but without proper fine-tuning they add a lot of complexity for results that are not very appealing.

bhack • May 20 '25

Hi @bhack, yes, I am still working on this. The main challenge right now is that we need a special type of data parallelism called sequence parallelism, where the data is split along the sequence-length dimension across multiple GPUs. This way we can scale the training sequence length arbitrarily by adding GPUs (e.g., fine-tuning on sequences of 1000+ frames). I've made most of the progress on this and hope to finish this week. After that I will switch to fine-tuning.
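A hypothetical, single-process sketch of why a recurrent model can be split along time. The `lru_width` argument in the constructor suggests the backbone contains a linear recurrent unit; for a diagonal linear recurrence `h_t = a_t * h_{t-1} + x_t`, each chunk of frames can be scanned independently from a zero state, and only one state vector per chunk boundary has to cross devices. Here every chunk stands in for one GPU and the recurrence is a toy; this is not the implementation described above, and it covers the forward pass only.

```python
import torch


def sequential_scan(a, x, h0):
    """Reference recurrence h_t = a_t * h_{t-1} + x_t. a, x: [T, D]; h0: [D]."""
    states = []
    h = h0
    for t in range(a.shape[0]):
        h = a[t] * h + x[t]
        states.append(h)
    return torch.stack(states)


def chunked_scan(a, x, num_chunks):
    """Simulated sequence parallelism: each chunk plays the role of one GPU."""
    a_chunks = a.chunk(num_chunks, dim=0)
    x_chunks = x.chunk(num_chunks, dim=0)
    zeros = torch.zeros(a.shape[1], dtype=a.dtype)
    # Step 1 (independent per chunk): scan from a zero state, and record the
    # running product of decays, i.e. how an incoming state would propagate.
    local = [sequential_scan(ac, xc, zeros) for ac, xc in zip(a_chunks, x_chunks)]
    decay = [torch.cumprod(ac, dim=0) for ac in a_chunks]
    # Step 2 (only cross-chunk dependency): hand each chunk's true final state
    # to the next chunk. On real hardware this is the small inter-GPU message.
    carries = [zeros]
    for i in range(len(a_chunks) - 1):
        carries.append(local[i][-1] + decay[i][-1] * carries[i])
    # Step 3 (independent per chunk): correct local states with the incoming carry.
    return torch.cat([l + d * c for l, d, c in zip(local, decay, carries)])


torch.manual_seed(0)
a = torch.rand(64, 8)   # decays in (0, 1), as in a stable recurrence
x = torch.randn(64, 8)
full = sequential_scan(a, x, torch.zeros(8))
split = chunked_scan(a, x, num_chunks=4)
```

The expensive per-frame work in steps 1 and 3 runs in parallel, so adding devices lets the sequence grow while each device only exchanges a `[D]` state with its neighbour.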

artemZholus • May 21 '25

@artemZholus Thanks, keep us updated.

I think this is essential for TAPNext to be minimally usable. Inference-only state/query workarounds are overcomplicated, have many edge cases, and give questionable performance.

> The main challenge right now is that we need to use a special type of data parallelism called sequence parallelism, when the data is split along the sequence length dimension over multiple GPUs.

Are you doing this in PyTorch or JAX? If it's PyTorch, then beyond your custom implementation it could be a generally interesting use case for https://github.com/pytorch/data. It would be nice to have a ticket there.

bhack • May 21 '25

@bhack I am doing this in PyTorch. FYI, the main challenge isn't loading the data properly; it's distributed backpropagation, i.e., some activations and gradients need to be communicated back and forth between GPUs.
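To make the backward-pass communication concrete, here is a hypothetical single-process sketch with a toy nonlinear recurrence split across two simulated devices. Device 1 receives device 0's last hidden state as a detached leaf, backpropagates its own loss, and "sends" the gradient with respect to that state back to device 0, which finishes backpropagation with `torch.autograd.backward`. In a real multi-GPU setup the hand-offs would go through something like `torch.distributed.send`/`recv`, and parameter gradients would also be summed across replicas; here both chunks share one parameter tensor, so its `.grad` simply accumulates both contributions. The recurrence and loss are illustrative stand-ins, not TAPNext's.

```python
import torch


def recurrent_chunk(x, h0, w):
    """Toy recurrence h_t = tanh(w * h_{t-1} + x_t). x: [T, D]; h0, w: [D]."""
    states = []
    h = h0
    for t in range(x.shape[0]):
        h = torch.tanh(w * h + x[t])
        states.append(h)
    return torch.stack(states)


def toy_loss(states):
    return states.pow(2).sum()


def reference_grad(x, w):
    """dL/dw with the whole sequence on a single device."""
    w = w.detach().requires_grad_()
    loss = toy_loss(recurrent_chunk(x, torch.zeros(x.shape[1]), w))
    (grad,) = torch.autograd.grad(loss, w)
    return grad


def split_grad(x, w):
    """Same gradient, with the sequence split across two simulated devices."""
    w = w.detach().requires_grad_()
    x0, x1 = x.chunk(2, dim=0)
    # Device 0 forward; its last state is "sent" to device 1 as a fresh leaf.
    h0 = recurrent_chunk(x0, torch.zeros(x.shape[1]), w)
    boundary = h0[-1].detach().requires_grad_()
    # Device 1 forward + backward: yields its share of dL/dw and dL/d(boundary).
    h1 = recurrent_chunk(x1, boundary, w)
    toy_loss(h1).backward()
    # boundary.grad is "sent" back; device 0 backpropagates its own loss term
    # plus the incoming gradient on the state it handed over.
    torch.autograd.backward([toy_loss(h0), h0[-1]], [torch.ones(()), boundary.grad])
    return w.grad


torch.manual_seed(0)
x = torch.randn(8, 4)
w = torch.randn(4)
```

The key point is the second hand-off: without sending `boundary.grad` back, device 0 would miss every gradient path that runs through later frames.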

artemZholus • May 21 '25

Oh, OK, so it's more on the tensor-parallelism side (DTensor).

bhack • May 21 '25

@artemZholus Do you plan to release the fine-tuned model in the near future?

YJ-142150 • May 27 '25

Hi @bhack @whikwon @jscholz @YJ-142150!

Just a quick update: the fine-tuning code for TAPNext on long videos is still in the works.

TAPNext was accepted at ICCV 2025 (yay!), and the camera-ready deadline is coming up soon. I plan to release everything around that time, so please stay tuned.

artemZholus • Jul 18 '25

@artemZholus Congratulations on ICCV 2025! Will the "TAPNext long video model" be TAPNextv2? Is there a planned release date? Looking forward to it!

YJ-142150 • Jul 30 '25