`MultiSortingComparison` and `MultiTemplateComparison` optimal assignment
Hi!
I have been looking at the literature on Multidimensional Assignment Problems / Entity Matching to understand multiple sorting or template assignments, and I realized that the current method does not seem to always return the optimal matching.
To check my hypothesis, I modified the BaseMultiComparison class to create a minimal working example with the potential issue (example from https://arxiv.org/pdf/2112.03346 page 3), it can be run as a script:
from collections import OrderedDict
from copy import deepcopy
import numpy as np
class BaseMultiComparison:
    """
    Base class for graph-based multi comparison classes.

    It handles graph operations, comparisons, and agreements.

    NOTE(review): this is a minimal working example — `_do_comparison`
    hard-codes the pairwise agreement scores of the counter-example from
    https://arxiv.org/pdf/2112.03346 (page 3) instead of computing them
    from actual sortings/templates.
    """

    def __init__(self):
        # Hard-coded stand-ins for the real constructor arguments
        # (object_list, name_list, match_score, chance_score, verbose).
        self.name_list = ['a', 'b', 'c']  # sorter/session labels
        self.object_list = ['1', '2', '3']  # unit ids present in each sorter
        self._verbose = True
        self.graph = None  # full agreement graph (networkx.Graph)
        self.subgraphs = None  # connected components of the clean graph
        self.clean_graph = None  # graph after duplicate-edge/node removal

    def _compute_all(self):
        """Run the full pipeline: compare, build graph, clean graph, extract agreement."""
        self._do_comparison()
        self._do_graph()
        self._clean_graph()
        self._do_agreement()

    def _populate_nodes(self):
        # One node per (sorter_name, unit_id) pair.
        for name in self.name_list:
            for unit_id in self.object_list:
                self.graph.add_node((name, unit_id))

    @property
    def units(self):
        # Deep copy so callers cannot mutate the internal agreement dict.
        return deepcopy(self._new_units)

    def compute_subgraphs(self):
        """
        Computes subgraphs of connected components.

        Returns
        -------
        sg_object_names: list
            List of sorter names for each node in the connected component subgraph
        sg_units: list
            List of unit ids for each node in the connected component subgraph
        """
        import networkx as nx

        # Prefer the cleaned graph when it has already been computed.
        g = self.clean_graph if self.clean_graph is not None else self.graph
        subgraphs = (g.subgraph(c).copy() for c in nx.connected_components(g))
        sg_object_names = []
        sg_units = []
        for sg in subgraphs:
            # Each node is a (sorter_name, unit_id) tuple.
            sg_object_names.append([node[0] for node in sg.nodes])
            sg_units.append([node[1] for node in sg.nodes])
        return sg_object_names, sg_units

    def _do_comparison(self):
        """Hard-coded pairwise Hungarian matches: unit -> (matched unit, agreement score)."""
        if self._verbose:
            print("Multicomparison step 1: pairwise comparison")
        self.comparisons = {
            ('a', 'b'): {'1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)},
            ('b', 'c'): {'1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)},
            ('a', 'c'): {'1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)},
        }

    def _do_graph(self):
        """Build the agreement graph: one weighted edge per pairwise match."""
        import networkx as nx

        if self._verbose:
            print("Multicomparison step 2: make graph")
        # nx.Graph is already undirected, so the original trailing
        # `self.graph = self.graph.to_undirected()` was a no-op copy — dropped.
        self.graph = nx.Graph()
        self._populate_nodes()
        for comp_name, comp in self.comparisons.items():
            name_1, name_2 = comp_name
            for u1, (u2, score) in comp.items():
                # -1 marks "no match" in the real comparison output
                if u2 != -1:
                    self.graph.add_edge((name_1, u1), (name_2, u2), weight=score)

    def _clean_graph(self):
        """
        Resolve duplicate assignments so each connected component contains at
        most one unit per sorter: greedily remove the lowest-weight edges of
        the duplicated sorter, then drop the nodes those edges disconnected.
        """
        import networkx as nx

        if self._verbose:
            print("Multicomparison step 3: clean graph")
        clean_graph = self.graph.copy()
        subgraphs = (clean_graph.subgraph(c).copy() for c in nx.connected_components(clean_graph))
        removed_nodes = 0
        for sg in subgraphs:
            object_names = [node[0] for node in sg.nodes]
            sorters, counts = np.unique(object_names, return_counts=True)
            if np.any(counts > 1):
                for sorter in sorters[counts > 1]:
                    # fix: compare the sorter-name field explicitly; the original
                    # `sorter in n` tuple-membership test would also match a
                    # unit id that happens to equal a sorter name
                    nodes_duplicate = [n for n in sg.nodes if n[0] == sorter]
                    # collect every edge touching a duplicated node
                    edges_duplicates = []
                    weights_duplicates = []
                    for n in nodes_duplicate:
                        for e in sg.edges(n, data=True):
                            edges_duplicates.append(e)
                            weights_duplicates.append(e[2]["weight"])
                    # remove the weakest edges, keeping one node of the duplicate set
                    n_edges_to_remove = len(nodes_duplicate) - 1
                    remove_idxs = np.argsort(weights_duplicates)[:n_edges_to_remove]
                    edges_to_remove = np.array(edges_duplicates, dtype=object)[remove_idxs]
                    for edge_to_remove in edges_to_remove:
                        clean_graph.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        sg.remove_edge(edge_to_remove[0], edge_to_remove[1])
                        if self._verbose:
                            print(f"Removed edge: {edge_to_remove}")
                    # remove extra nodes (as a second step to not affect edge removal);
                    # NOTE: only removed from the subgraph copy, clean_graph keeps
                    # them as isolated nodes (original behavior, preserved)
                    for edge_to_remove in edges_to_remove:
                        if edge_to_remove[0] in nodes_duplicate:
                            node_to_remove = edge_to_remove[0]
                        else:
                            node_to_remove = edge_to_remove[1]
                        if node_to_remove in sg.nodes:
                            sg.remove_node(node_to_remove)
                            # fix: gate this print on verbosity like the others
                            if self._verbose:
                                print(f"Removed node: {node_to_remove}")
                            removed_nodes += 1
        if self._verbose:
            print(f"Removed {removed_nodes} duplicate nodes")
        self.clean_graph = clean_graph

    def _do_agreement(self):
        """
        Extract merged units from the clean graph: each connected component
        becomes one unit whose agreement is the mean of its edge weights.
        """
        import networkx as nx

        if self._verbose:
            print("Multicomparison step 4: extract agreement from graph")
        self._new_units = {}
        self.subgraphs = [
            self.clean_graph.subgraph(c).copy() for c in nx.connected_components(self.clean_graph)
        ]
        for new_unit, sg in enumerate(self.subgraphs):
            edges = list(sg.edges(data=True))
            # isolated nodes (no edges) get zero agreement
            avg_agr = np.mean([d["weight"] for _, _, d in edges]) if edges else 0
            object_unit_ids = {node[0]: node[1] for node in sg.nodes}
            # sort dict based on name list so unit ids appear in sorter order
            sorted_object_unit_ids = OrderedDict(
                (name, object_unit_ids[name]) for name in self.name_list if name in object_unit_ids
            )
            self._new_units[new_unit] = {
                "avg_agreement": avg_agr,
                "unit_ids": sorted_object_unit_ids,
                "agreement_number": len(sg.nodes),
            }
if __name__ == "__main__":
    # Run the MWE end-to-end (requires networkx) and print the merged units;
    # the guard keeps the pipeline from running when the module is imported.
    b = BaseMultiComparison()
    b._compute_all()
    print(b._new_units)
So, in your opinion, is my MWE correctly adapted from the literature to the spikeinterface framework? If so, have you already considered other methods, or should we think about this further to solve the issue?
Thanks!
Florent
@florian6973,
we will take a look at this soon. We are in the middle of a spikeinterface hackathon, but I'm super curious about this. It is a little hard for me to read the code (without a nice diff view). Could you also post the same code with comments on the lines you changed, to make the comparison a bit easier? If we haven't responded by next week, please ping us again!
Thanks for your reply!
Sure, here are some more details:
- my goal is to replicate the example from the paper to see if spikeinterface correctly solves the multiple assignment problem. If we consider A, B, C as three different sessions or sorters, (a1, a2, a3, ...) as the corresponding templates / units, and sim the similarity measure (cosine or any other), we would like to know which is the best matching between the templates/units across sessions/sorters.
We can rewrite the table as three agreement matrices computed by 2-by-2 comparisons. The matches from the current Hungarian method in spikeinterface are shown in bold.
$$\begin{array}{c|ccc} & b_1 & b_2 & b_3 \\ \hline a_1 & 0.4 & \textbf{0.6} & 0.6 \\ a_2 & \textbf{0.6} & 0.6 & 0.6 \\ a_3 & 0.6 & 0.6 & \textbf{1} \end{array}$$
$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline b_1 & \textbf{1} & 0.1 & 0.1 \\ b_2 & 0.1 & \textbf{1} & 0.1 \\ b_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$
$$\begin{array}{c|ccc} & c_1 & c_2 & c_3 \\ \hline a_1 & \textbf{1} & 0.1 & 0.1 \\ a_2 & 0.1 & \textbf{1} & 0.1 \\ a_3 & 0.1 & 0.1 & \textbf{1} \end{array}$$
- from there, I adapted the `BaseMultiComparison` to reflect this particular situation, and checked whether we obtain the true optimal matching $(a_1, b_1, c_1)$, $(a_2, b_2, c_2)$, $(a_3, b_3, c_3)$ in the end. Please find the diff below. Note that I do not need the whole comparison matrix given the way the graph is built, and I assume the match_score is low enough:
def __init__(self):
self.name_list = ['a', 'b', 'c']
self.object_list = ['1', '2', '3']
# def _compare_ij(self, i, j):
# raise NotImplementedError
# def _populate_nodes(self):
# raise NotImplementedError
def _populate_nodes(self):
for name in self.name_list:
for unit_id in self.object_list:
self.graph.add_node((name, unit_id))
# def _do_comparison(
# self,
# ):
# # do pairwise matching
# if self._verbose:
# print("Multicomparison step 1: pairwise comparison")
# self.comparisons = {}
# for i in range(len(self.object_list)):
# for j in range(i + 1, len(self.object_list)):
# if self.name_list is not None:
# name_i = self.name_list[i]
# name_j = self.name_list[j]
# else:
# name_i = "object i"
# name_j = "object j"
# if self._verbose:
# print(f" Comparing: {name_i} and {name_j}")
# comp = self._compare_ij(i, j)
# self.comparisons[(name_i, name_j)] = comp
def _do_comparison(
self,
):
# do pairwise matching
if self._verbose:
print("Multicomparison step 1: pairwise comparison")
self.comparisons = {
('a', 'b'): {
'1': ('2', 0.6), '2': ('1', 0.6), '3': ('3', 1.0)
},
('b', 'c'): {
'1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
},
('a', 'c'): {
'1': ('1', 1.0), '2': ('2', 1.0), '3': ('3', 1.0)
}
}
def _do_graph(self):
# ...
# for comp_name, comp in self.comparisons.items():
# for u1 in comp.hungarian_match_12.index.values:
# u2 = comp.hungarian_match_12[u1]
# if u2 != -1:
# name_1, name_2 = comp_name
# node1 = name_1, u1
# node2 = name_2, u2
# score = comp.agreement_scores.loc[u1, u2]
# self.graph.add_edge(node1, node2, weight=score)
for comp_name, comp in self.comparisons.items():
for u1 in comp.keys():
u2 = comp[u1][0]
if u2 != -1:
name_1, name_2 = comp_name
node1 = name_1, u1
node2 = name_2, u2
score = comp[u1][1]
self.graph.add_edge(node1, node2, weight=score)
- but we obtain $(a_1, b_1, c_1)$, $(a_2)$, $(b_2, c_2)$ and $(a_3, b_3, c_3)$ with the spikeinterface code. It is not even what would be expected $(a_1, b_2, c_2)$, $(a_2, b_1, c_1)$ and $(a_3, b_3, c_3)$.
I hope this is clearer. I am not sure if I am fully correct, but I was trying to properly understand the multiple comparison module, so that's why I am asking.
Have a good hackathon :)
By the way, if you are in Boston at some point we could discuss it in person if needed :)
Hey @florian6973,
thanks for the well wishes. We could definitely meet at some point. If you're on the slack just send me a message. But I think @alejoe91 is better for looking over this one. I didn't work on the initial code so he would know it way better.
Thanks a lot for this @florian6973 — super interesting and detailed investigation. I will definitely look into this while working on #2626; please feel free to share any feedback and thoughts on the plan I posted there.