
`plot_tracks_v2` has a bug when plotting with the `trackgroup` argument

Open · chandlj opened this issue 10 months ago · 2 comments

I am running this notebook for RoboTAP clustering. After computing the clusters, I am running the following cell:

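```python
import mediapy as media
import numpy as np

# `clustered`, `demo_videos`, `demo_episode_ids` and `viz_utils` come from
# earlier cells of the RoboTAP clustering notebook.
separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

pointtrack_video = viz_utils.plot_tracks_v2(
    demo_videos[demo_episode_ids[0]].astype(np.uint8),
    separation_tracks_trim[demo_episode_ids[0]],
    1.0 - separation_visibility_trim[demo_episode_ids[0]],  # invert visibility to get occlusion
    trackgroup=clustered['classes'],
)
media.show_video(pointtrack_video, fps=20)
```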

However, the plot shows only about 10 points no matter how many points I track, and no clusters are visible. If I comment out `trackgroup`, the plotting code works correctly and I can see the full set of points (although they are not colored by cluster ID). I can also verify that the clusters are computed correctly by plotting individual frames like so:

```python
import matplotlib.pyplot as plt

separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

frame = 35
# Tracks are indexed [point, frame, xy]; color each point by its cluster ID.
plt.scatter(
    separation_tracks_trim["dummy_id"][:, frame, 0],
    separation_tracks_trim["dummy_id"][:, frame, 1],
    c=clustered["classes"],
    cmap="viridis",
)
plt.imshow(video[frame])
```

The problem appears only when `trackgroup` is specified. Any ideas on how to fix it?

chandlj · Apr 24 '24 20:04
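One cheap first check is whether the three arrays passed to `plot_tracks_v2` agree on the number of tracks. The sketch below is a hypothetical helper (`check_track_inputs` and `group_sizes` are not part of tapnet). It assumes the layout that the second snippet indexes with: tracks as `[num_points, num_frames, 2]`, one occlusion value per point and frame, and one cluster ID per track.

```python
import numpy as np


def check_track_inputs(tracks, occluded, trackgroup):
    """Hypothetical helper: return a list of shape problems (empty if none).

    Assumes tracks is [num_points, num_frames, 2], occluded is
    [num_points, num_frames], and trackgroup holds one ID per track.
    """
    tracks, occluded, trackgroup = (
        np.asarray(a) for a in (tracks, occluded, trackgroup))
    if tracks.ndim != 3 or tracks.shape[-1] != 2:
        return [f"tracks should be [points, frames, 2], got {tracks.shape}"]
    num_points, num_frames, _ = tracks.shape
    problems = []
    if occluded.shape != (num_points, num_frames):
        problems.append(
            f"occluded should be {(num_points, num_frames)}, got {occluded.shape}")
    if trackgroup.shape != (num_points,):
        problems.append(
            f"trackgroup should be ({num_points},), got {trackgroup.shape}")
    return problems


def group_sizes(trackgroup):
    """Count how many tracks fall into each cluster ID."""
    ids, counts = np.unique(np.asarray(trackgroup), return_counts=True)
    return dict(zip(ids.tolist(), counts.tolist()))
```

In the notebook this would be called as `check_track_inputs(separation_tracks_trim[demo_episode_ids[0]], 1.0 - separation_visibility_trim[demo_episode_ids[0]], clustered['classes'])`. An empty list means the shapes line up. `group_sizes(clustered['classes'])` shows whether the clusters themselves look reasonable.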

Now that `tapir_clustering.py` is fixed, I've run the colab at head and verified that the code plots more than 20 tracks. Your snippets above look correct to me; I don't see why they wouldn't plot the full set the way the colab does. Maybe set a breakpoint at https://github.com/google-deepmind/tapnet/blob/main/utils/viz_utils.py#L193 and check what's being passed to `plt.scatter`?

cdoersch · Apr 28 '24 12:04
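If stepping through with a debugger is awkward in a notebook, a similar check can be done by temporarily wrapping `plt.scatter` and logging the shapes it receives. This is a minimal sketch using `unittest.mock.patch.object`. `draw_points` is a hypothetical stand-in for the plotting code. The sketch assumes the code under test calls `scatter` through the `matplotlib.pyplot` module, as the linked line suggests, rather than through an `Axes` object. Only the `x`, `y` and `c` keyword arguments are logged.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from unittest import mock


def draw_points(points, colors):
    """Hypothetical stand-in for plotting code that calls plt.scatter."""
    fig = plt.figure()
    plt.scatter(points[:, 0], points[:, 1], c=colors)
    plt.close(fig)


def record_scatter_calls(fn, *args, **kwargs):
    """Run fn(*args, **kwargs), logging the shapes passed to plt.scatter."""
    log = []
    real_scatter = plt.scatter

    def spy(x, y, *scatter_args, **scatter_kwargs):
        c = scatter_kwargs.get("c")
        log.append({
            "x": np.shape(x),
            "y": np.shape(y),
            "c": None if c is None else np.shape(c),
        })
        # Forward to the real function so the plot is still drawn.
        return real_scatter(x, y, *scatter_args, **scatter_kwargs)

    # Replace matplotlib.pyplot.scatter only for the duration of the call.
    with mock.patch.object(plt, "scatter", spy):
        fn(*args, **kwargs)
    return log


log = record_scatter_calls(draw_points, np.zeros((30, 2)), np.full((30, 4), 0.5))
print(log)
```

In the notebook, `fn` would be a lambda that calls `viz_utils.plot_tracks_v2(...)` with and without `trackgroup`. Comparing the logged `x` and `c` shapes against the number of tracks would show where points are being dropped.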

@cdoersch The most recent code pushed for `tapir_clustering` has a bug and did not work for me in the notebook. Looking at the commit here, it looks like changing `len` to `np.prod` on line 574 is causing the problem: `jax.tree_map(lambda x: np.prod(x.shape), query_features)` returns 1 for the `resolutions` array, not 0.

chandlj · Apr 29 '24 18:04
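The difference comes from how the two expressions treat a 0-d array. `len(x.shape)` counts dimensions, while `np.prod(x.shape)` counts elements, and the product over an empty shape `()` is 1 (the empty product). The sketch below is a minimal illustration. It assumes the original expression was `len(x.shape)` and that `resolutions` is a 0-d array. The dict is a hypothetical stand-in for `query_features`, and a plain dict comprehension replaces `jax.tree_map` so the sketch runs without JAX. Mapping over the same leaves with `jax.tree_util.tree_map` would give the same per-leaf values.

```python
import numpy as np

# Hypothetical stand-in for query_features; keys and shapes are illustrative.
query_features = {
    "features": np.zeros((4, 16, 8)),
    "resolutions": np.array(256),  # 0-d array: shape == ()
}

# len(x.shape) counts dimensions; np.prod(x.shape) counts elements.
ndims = {k: len(v.shape) for k, v in query_features.items()}
sizes = {k: np.prod(v.shape) for k, v in query_features.items()}

print(ndims)  # the 0-d entry has 0 dimensions
print(sizes)  # but 1 element, because the product of an empty tuple is 1
```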