Is there a way to access similarity maps?

Open FlorianSaby opened this issue 1 year ago • 3 comments

Hey, thank you so much for your library! I’m nearly done preparing a practical session for my master's students using it. I was wondering if there's an easy way to display similarity maps between a query and an image. Have a nice day, Flo

FlorianSaby · Jan 04 '25 11:01

Here's a hacky thing I spliced together for this from d6349d29. It uses an old version, `colpali-engine==0.2.2`.

Something like this:

```diff
+class CustomEvaluatorPatchScore(CustomEvaluator):
+    def evaluate(self, qs, ps) -> Tuple[torch.Tensor, torch.Tensor]:
+        scores, patch_scores = self.evaluate_colbert(qs, ps)
+        return scores, patch_scores
+
+    def evaluate_colbert(self, qs, ps, batch_size=128) -> Tuple[torch.Tensor, torch.Tensor]:
+        print(f"evaluate_colbert called with qs shape: {qs[0].shape}, ps length: {len(ps)}")
+        scores = []
+        patch_scores = []
+        for i in range(0, len(qs), batch_size):
+            print(f"Processing batch {i//batch_size + 1}")
+            scores_batch = []
+            patch_scores_batch = []
+            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i : i + batch_size], batch_first=True, padding_value=0).to(
+                self.device
+            )
+            print(f"qs_batch shape: {qs_batch.shape}")
+            for j in range(0, len(ps), batch_size):
+                ps_batch = torch.nn.utils.rnn.pad_sequence(
+                    ps[j : j + batch_size], batch_first=True, padding_value=0
+                ).to(self.device)
+                print(f"ps_batch shape: {ps_batch.shape}")
+                # Every query-token x patch dot product: (b, c, n, s)
+                batch_patch_scores = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
+                print(f"batch_patch_scores shape: {batch_patch_scores.shape}")
+                # Regular MaxSim score: best patch per query token, summed over tokens -> (b, c).
+                # Zero-padded patches score 0 and can win the max if all real scores are negative.
+                scores_batch.append(batch_patch_scores.max(dim=3)[0].sum(dim=2))
+                # Patch scores: summed over query tokens, every patch kept -> (b, c, 1, s)
+                patch_scores_batch.append(batch_patch_scores.sum(dim=2, keepdim=True))
+
+            # Add check to ensure scores_batch is not empty
+            if scores_batch:
+                scores_batch = torch.cat(scores_batch, dim=1).cpu()
+                # Concatenating along dim=1 requires every ps batch to have the same padded length s
+                patch_scores_batch = torch.cat(patch_scores_batch, dim=1).cpu()
+                print(f"scores_batch shape: {scores_batch.shape}")
+                print(f"patch_scores_batch shape: {patch_scores_batch.shape}")
+                scores.append(scores_batch)
+                patch_scores.append(patch_scores_batch)
+            else:
+                print("Warning: scores_batch is empty. Skipping this batch.")
+
+        # Add check to ensure scores is not empty
+        if scores:
+            scores = torch.cat(scores, dim=0)
+            patch_scores = torch.cat(patch_scores, dim=0)
+            print(f"Final scores shape: {scores.shape}")
+            print(f"Final patch_scores shape: {patch_scores.shape}")
+        else:
+            print("Error: No valid scores were generated. Returning empty tensors.")
+            scores = torch.tensor([])
+            patch_scores = torch.tensor([])
+
+        return scores, patch_scores
```
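If you want to check what this patch computes without installing the old colpali-engine, here is a minimal NumPy sketch of the same scoring, with the batching and debug prints removed. It assumes `qs` is a list of per-query token embeddings of shape `(n_i, d)` and `ps` a list of per-page patch embeddings of shape `(s_j, d)`; the helpers `pad` and `maxsim_with_patch_scores` are hypothetical names, not part of byaldi or colpali-engine. Unlike the patch, it masks padded patches out of the max, since a zero-padded patch would otherwise beat a negative real similarity.

```python
import numpy as np


def pad(seqs):
    """Zero-pad a list of (len_i, d) arrays to (batch, max_len, d), plus a validity mask."""
    max_len = max(x.shape[0] for x in seqs)
    d = seqs[0].shape[1]
    out = np.zeros((len(seqs), max_len, d), dtype=seqs[0].dtype)
    mask = np.zeros((len(seqs), max_len), dtype=bool)
    for i, x in enumerate(seqs):
        out[i, : x.shape[0]] = x
        mask[i, : x.shape[0]] = True
    return out, mask


def maxsim_with_patch_scores(qs, ps):
    """Late-interaction (MaxSim) scores plus per-patch scores.

    qs: list of (n_i, d) query token embeddings
    ps: list of (s_j, d) page/image patch embeddings
    Returns scores (n_queries, n_pages) and patch_scores (n_queries, n_pages, max_s).
    """
    q, q_mask = pad(qs)  # (b, n, d)
    p, p_mask = pad(ps)  # (c, s, d)
    # Dot product of every query token with every patch: (b, c, n, s)
    sim = np.einsum("bnd,csd->bcns", q, p)
    # Exclude padded patches from the max
    masked = np.where(p_mask[None, :, None, :], sim, -np.inf)
    token_max = masked.max(axis=3)  # (b, c, n): best patch per query token
    token_max = np.where(q_mask[:, None, :], token_max, 0.0)  # ignore padded query tokens
    scores = token_max.sum(axis=2)  # (b, c)
    # Per-patch scores summed over query tokens; padded rows and columns contribute 0
    patch_scores = sim.sum(axis=2)  # (b, c, s)
    return scores, patch_scores
```

The same einsum string works unchanged with `torch.einsum`, so porting this back to torch tensors is mostly a matter of swapping `np.where`/`max(axis=...)` for their torch equivalents.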

reidsanders · Mar 31 '25 18:03

Thanks a lot for the answer! However, I’m not quite sure how to apply this to my Byaldi code. What would `qs` and `ps` be in this case?

FlorianSaby · Apr 01 '25 12:04

I did this on a very early version, and the code structure has changed a lot since, so unfortunately it isn't going to be directly compatible. There used to be a `CustomEvaluator` class that did the score calculation, and this is a simple modification of it. `qs` holds the query token embeddings and `ps` the page embeddings (one vector per image patch). The regular scores take a `max(dim=3)` over patches first and then sum over query tokens, which throws away where each token matched; the patch scores instead sum over query tokens and keep one value per patch.

```python
batch_patch_scores = torch.einsum("bnd,csd->bcns", qs_batch, ps_batch)
scores_batch.append(batch_patch_scores.max(dim=3)[0].sum(dim=2))
patch_scores_batch.append(batch_patch_scores.sum(dim=2, keepdim=True))
```
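To actually display patch scores as a similarity map, the flat per-patch vector for one (query, page) pair has to be reshaped to the image's patch grid and stretched to the image size. This is a minimal sketch, assuming the vector contains image-patch scores only, in row-major order, and that you know the grid size. For ColPali's 448×448 input with 14-pixel patches that would be 32×32, but the page embeddings can also contain a few non-image tokens that must be dropped first, so check this for your model. `patch_scores_to_heatmap` is a hypothetical helper, not part of byaldi or colpali-engine.

```python
import numpy as np


def patch_scores_to_heatmap(patch_scores, grid_h, grid_w, image_h, image_w):
    """Reshape a flat vector of per-patch scores into an image-sized heatmap in [0, 1].

    Assumes patch_scores holds exactly grid_h * grid_w image-patch scores in
    row-major order; slice off any non-image tokens before calling this.
    """
    scores = np.asarray(patch_scores, dtype=float)
    if scores.size != grid_h * grid_w:
        raise ValueError(f"expected {grid_h * grid_w} patch scores, got {scores.size}")
    grid = scores.reshape(grid_h, grid_w)
    # Min-max normalise so the map can be drawn with a fixed [0, 1] colour range
    lo, hi = grid.min(), grid.max()
    grid = (grid - lo) / (hi - lo) if hi > lo else np.zeros_like(grid)
    # Nearest-neighbour upsampling: each pixel takes the value of the patch it falls in
    rows = (np.arange(image_h) * grid_h) // image_h
    cols = (np.arange(image_w) * grid_w) // image_w
    return grid[np.ix_(rows, cols)]
```

With matplotlib, `plt.imshow(image)` followed by `plt.imshow(heat, cmap="jet", alpha=0.5)` overlays the map on the page, as long as `heat` was built with the image's height and width.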

I noticed that colpali-engine now has a utility that is likely easier to use: https://github.com/illuin-tech/colpali/blob/main/colpali_engine/interpretability/similarity_maps.py
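The hack above sums over query tokens, which gives a single map per page. As far as I can tell, the linked colpali-engine module instead keeps one map per query token, which is usually more informative because you can see which word matches which region. The core of that idea fits in a few lines of NumPy. This is a sketch under the same row-major, image-patches-only assumption as before, not the module's actual API, and `per_token_similarity_maps` is a hypothetical name.

```python
import numpy as np


def per_token_similarity_maps(query_emb, patch_emb, grid_h, grid_w):
    """One similarity map per query token.

    query_emb: (n_tokens, d) embeddings for a single query
    patch_emb: (grid_h * grid_w, d) image-patch embeddings for a single page,
               row-major (assumed layout: strip any non-image tokens first)
    Returns an array of shape (n_tokens, grid_h, grid_w).
    """
    sim = np.einsum("nd,sd->ns", query_emb, patch_emb)  # token-by-patch dot products
    return sim.reshape(query_emb.shape[0], grid_h, grid_w)
```

Summing the result over its first axis gives back the aggregate map from the hack, and `maps.max(axis=(1, 2)).sum()` gives the usual MaxSim score for that query and page.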

reidsanders · Apr 01 '25 14:04