Any plan to release TAPNext training code?
Thank you for your recent work on TAPNext! I want to fine-tune TAPNext, but it seems the training code has not been released yet. When do you plan to release it?
Also, when will the TrecViT code be released?
Hi @YJ-142150!
We don't plan to release the training code, but we have instead released the PyTorch code to compute the TAPNext loss and its gradient. The code is pushed here:
https://github.com/google-deepmind/tapnet/blob/main/tapnet/tapnext/torch_losses.py
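For anyone planning to fine-tune against it, a minimal sketch of the training-step plumbing could look like the following; `loss_fn` stands in for whatever entry point `torch_losses.py` actually exposes, and the target format is an assumption, so check the file for the real names and signatures:

```python
import torch
from tapnet.tapnext.tapnext_torch import TAPNext

device = "cuda:0"
model = TAPNext(image_size=(128, 128), patch_size=(8, 8), width=768,
                depth=12, num_heads=12, lru_width=768).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def train_step(loss_fn, video, query_points, targets):
    """One fine-tuning step. `loss_fn` is the loss from torch_losses.py;
    `targets` is the matching ground truth (e.g. tracks/visibility)."""
    outputs = model(video, query_points)  # predictions for the queried points
    loss = loss_fn(outputs, targets)      # released, differentiable loss
    optimizer.zero_grad()
    loss.backward()                       # gradient of the TAPNext loss
    optimizer.step()
    return loss.detach()
```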
@artemZholus Do you plan to fine-tune on more than 48 frames?
Hi @YJ-142150, the TrecViT code is now released at https://github.com/google-deepmind/trecvit, please take a look. Thanks.
@yangyi02 Thank you!
Hi @bhack, I plan to work on this, but not as part of Google. This will likely take a couple of weeks.
Thanks @artemZholus! I'd be curious to see a JAX fine-tuning example too.
Hello, and thanks for sharing this great project!
I’m trying to run the TAPNext training code, but I’m hitting the following error when my batch size is 2 (it works fine with batch size = 1). Is there anything I should change in the code?
```python
import torch
from tapnet.tapnext.tapnext_torch import TAPNext

image_size = (128, 128)
device = "cuda:0"
model = TAPNext(
    image_size=image_size,
    patch_size=(8, 8),
    width=768,
    depth=12,
    num_heads=12,
    lru_width=768,
    use_checkpointing=False,
).to(device)

query_points = torch.randn(2, 10, 3).to(device)
video = torch.randn(2, 16, 128, 128, 3).to(device)
model(video, query_points)
```
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 18
     16 query_points = torch.randn(2, 10, 3).to(device)
     17 video = torch.randn(2, 16, 128, 128, 3).to(device)
---> 18 model(video, query_points)

File ~/miniconda3/envs/tapvid/lib/python3.10/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/tapvid/lib/python3.10/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/work/tapnet/tapnet/tapnext/tapnext_torch.py:273, in TAPNext.forward(self, video, query_points, state)
    271 step = 0
    272 print('t, query_points.shape', t, query_points.shape)
--> 273 point_tokens = self.embed_queries(t, query_points)  # [b t Q c]
    274 x = torch.cat([video_tokens, point_tokens], dim=2)  # [b t (h * w + Q) c]
    275 ssm_cache = []

File ~/work/tapnet/tapnet/tapnext/tapnext_torch.py:208, in TAPNext.embed_queries(self, timesteps, query_points)
    206 queries_are_late = query_timesteps >= t
    207 queries_are_early = query_timesteps < 0
--> 208 mask_and_query_tokens = mask_tokens.scatter(
    209     dim=1,
    210     index=query_timesteps.long().clamp(0, t - 1).repeat(1, 1, 1, c),
    211     src=point_query_tokens,
    212 )
    213 mask_and_query_tokens = torch.where(
    214     (queries_are_late | queries_are_early).unsqueeze(1),
    215     mask_tokens,
    216     mask_and_query_tokens,
    217 )
    218 is_unknown_token = torch.arange(t, device=query_points.device)[
    219     None, :, None, None
    220 ] < query_timesteps.unsqueeze(1)

RuntimeError: Expected index [1, 2, 10, 768] to be no larger than self [2, 16, 10, 768] apart from dimension 1 and to be no larger size than src [2, 1, 10, 768]
```
Any suggestions on how to fix this would be greatly appreciated. Thanks!
Hi @whikwon, this is a known issue and it has already been fixed. Just pull the latest version of the code and it should work with batch size > 1.
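For context, the error in the old code came from `torch.Tensor.scatter`'s shape rules: `index` may be no larger than `self` in every dimension except `dim`, and no larger than `src` in every dimension. A standalone repro of the old shape mismatch (illustration only, not the repo's actual fix):

```python
import torch

self_t = torch.zeros(2, 16, 10, 768)                  # mask tokens, [b t Q c]
src = torch.zeros(2, 1, 10, 768)                      # point query tokens
index = torch.zeros(1, 2, 10, 768, dtype=torch.long)  # mis-shaped index
# index is larger than src along dim 1 (2 > 1), which raises the same
# RuntimeError as above:
# self_t.scatter(dim=1, index=index, src=src)
```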
@artemZholus Ah, sorry to bother you. Thank you for the quick response!
> Hi @bhack, I plan to work on this, but not as part of Google. This will likely take a couple of weeks.
@artemZholus Any news on this? I've tried a few workarounds with the state and query management, but without proper fine-tuning they create a lot of complexity for not very appealing results.
Hi @bhack, yes, I am still working on this. The main challenge right now is that we need a special type of data parallelism called sequence parallelism, where the data is split along the sequence-length dimension over multiple GPUs. This way we can scale the training sequence length arbitrarily by adding GPUs (e.g. fine-tuning on 1000+ frames). I have already made most of the progress on this and hopefully will finish this week. After that I will switch to fine-tuning.
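Roughly, the layout looks like this (a sketch only, assuming an already-initialized process group; the `state` argument is the one visible in `TAPNext.forward`'s signature in the traceback above):

```python
import torch
import torch.distributed as dist

# Sketch: split one long clip along the time axis across ranks.
rank, world = dist.get_rank(), dist.get_world_size()
video = torch.randn(1, 1024, 128, 128, 3)  # [b, t, h, w, c], a 1000+ frame clip
chunk = video.chunk(world, dim=1)[rank]    # this rank's contiguous frames
# Each rank runs TAPNext on its own chunk. The recurrent (SSM) state at
# the chunk boundary is received from rank - 1 and passed on to rank + 1,
# and its gradient travels the opposite way during backprop.
```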
@artemZholus Thanks, keep us updated.
I think this is required for a minimum of usability in TAPNext. State-query workarounds at inference time only are really overcomplicated, with a lot of edge cases and questionable performance.
> The main challenge right now is that we need a special type of data parallelism called sequence parallelism, where the data is split along the sequence-length dimension over multiple GPUs.
Are you doing this in PyTorch or JAX? If it is torch, beyond your custom implementation this could be an interesting general use case for https://github.com/pytorch/data. It would be nice to open a ticket there.
@bhack I am doing this in PyTorch. FYI, the main challenge isn't loading the data properly; it is doing distributed backpropagation, i.e. some activations/gradients need to be communicated back and forth between GPUs.
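A minimal sketch of that communication pattern with point-to-point ops (an illustration of the idea, not the actual implementation): the boundary SSM state is shipped to the next rank in the forward pass, and its gradient is shipped back in the backward pass.

```python
import torch
import torch.distributed as dist

class SendState(torch.autograd.Function):
    """Send the boundary state to the next rank in forward; receive the
    gradient that rank computed for it in backward."""

    @staticmethod
    def forward(ctx, state, dst):
        ctx.dst = dst
        dist.send(state.contiguous(), dst=dst)
        return state  # returned so local ops can still use it if needed

    @staticmethod
    def backward(ctx, grad_out):
        grad_from_next = torch.empty_like(grad_out)
        dist.recv(grad_from_next, src=ctx.dst)
        return grad_out + grad_from_next, None  # local grad + remote grad

class RecvState(torch.autograd.Function):
    """Receive the boundary state from the previous rank in forward; send
    the state's gradient back to it in backward."""

    @staticmethod
    def forward(ctx, template, src):
        ctx.src = src
        state = torch.empty_like(template)
        dist.recv(state, src=src)
        return state

    @staticmethod
    def backward(ctx, grad_state):
        dist.send(grad_state.contiguous(), dst=ctx.src)
        return None, None

# NB: blocking point-to-point ops like these must be scheduled so that
# every send has a matching recv; real implementations interleave them
# (or use batched isend/irecv) to avoid deadlocks.
```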
Oh, OK, so it is more on the tensor parallelism side (DTensor).
@artemZholus Will the fine-tuned model be released in the near future? Is there any plan to release the model?
Hi @bhack @whikwon @jscholz @YJ-142150!
Just a quick sync: this (the fine-tuning code for TAPNext on long videos) is still in the works.
TAPNext got accepted at ICCV 2025 (yay) and the camera-ready deadline is soon. I plan to release everything around that time, so please stay tuned.
@artemZholus Congratulations on ICCV 2025! Will the "TAPNext long video model" be TAPNext v2? Is there a planned release date? Looking forward to it!