scholar
scholar copied to clipboard
Hierarchical clustering improvements
- [ ] Ward linkage (currently broken)
- [ ] Optimize single linkage with the SLINK algorithm (https://sites.cs.ucsb.edu/~veronika/MAE/summary_SLINK_Sibson72.pdf)
- [ ] Support median and centroid linkage (https://github.com/elixir-nx/scholar/pull/187#issuecomment-1794254332)
Here is a small patch for the first item (Ward linkage). Unfortunately, it does not fix the problem on its own, so something else is probably wrong as well:
diff --git a/lib/scholar/cluster/hierarchical.ex b/lib/scholar/cluster/hierarchical.ex
index 481b6cd..562367f 100644
--- a/lib/scholar/cluster/hierarchical.ex
+++ b/lib/scholar/cluster/hierarchical.ex
@@ -196,10 +196,11 @@ defmodule Scholar.Cluster.Hierarchical do
clades = Nx.broadcast(-1, {n - 1, 2})
sizes = Nx.broadcast(1, {2 * n - 1})
pointers = Nx.broadcast(-1, {2 * n - 2})
+ n_sizes = Nx.broadcast(1, {n})
diss = Nx.tensor(:infinity, type: Nx.type(pairwise)) |> Nx.broadcast({n - 1})
- {{clades, diss, sizes}, _} =
- while {{clades, diss, sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
+ {{clades, diss, sizes, n_sizes}, _} =
+ while {{clades, diss, sizes, n_sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
# Indexes of who I am nearest to
nearest = Nx.argmin(pairwise, axis: 1)
@@ -213,10 +214,21 @@ defmodule Scholar.Cluster.Hierarchical do
# They are bidirectional but let's keep only one side.
links = Nx.select(clades_selector and nearest > nearest_of_nearest, nearest, n)
- {clades, count, pointers, pairwise, diss, sizes} =
- merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun)
-
- {{clades, diss, sizes}, {count, pointers, pairwise}}
+ {clades, count, pointers, pairwise, diss, sizes, n_sizes} =
+ merge_clades(
+ clades,
+ count,
+ pointers,
+ pairwise,
+ diss,
+ sizes,
+ n_sizes,
+ links,
+ n,
+ update_fun
+ )
+
+ {{clades, diss, sizes, n_sizes}, {count, pointers, pairwise}}
end
sizes = sizes[n..(2 * n - 2)]
@@ -224,16 +236,27 @@ defmodule Scholar.Cluster.Hierarchical do
{clades[perm], diss[perm], sizes[perm]}
end
- defnp merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun) do
- {{clades, count, pointers, pairwise, diss, sizes}, _} =
- while {{clades, count, pointers, pairwise, diss, sizes}, links},
+ defnp merge_clades(
+ clades,
+ count,
+ pointers,
+ pairwise,
+ diss,
+ sizes,
+ n_sizes,
+ links,
+ n,
+ update_fun
+ ) do
+ {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, _} =
+ while {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links},
i <- 0..(Nx.size(links) - 1) do
# i < j because of how links is formed.
# i will become the new clade index and we "infinity-out" j.
j = links[i]
if j == n do
- {{clades, count, pointers, pairwise, diss, sizes}, links}
+ {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links}
else
# Clades a and b (i and j of pairwise) are being merged into c.
indices = [i, j] |> Nx.stack() |> Nx.new_axis(-1)
@@ -251,6 +274,9 @@ defmodule Scholar.Cluster.Hierarchical do
sc = sa + sb
sizes = Nx.indexed_put(sizes, Nx.stack([i, c]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))
+ n_sizes =
+ Nx.indexed_put(n_sizes, Nx.stack([i, j]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))
+
# Update dissimilarities
diss = Nx.indexed_put(diss, Nx.stack([count]), pairwise[i][j])
@@ -259,7 +285,7 @@ defmodule Scholar.Cluster.Hierarchical do
# Update pairwise
updates =
- update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, sc)
+ update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, n_sizes)
|> Nx.indexed_put(indices, Nx.broadcast(:infinity, {2}))
pairwise =
@@ -269,11 +295,11 @@ defmodule Scholar.Cluster.Hierarchical do
|> Nx.put_slice([j, 0], Nx.broadcast(:infinity, {1, n}))
|> Nx.put_slice([0, j], Nx.broadcast(:infinity, {n, 1}))
- {{clades, count + 1, pointers, pairwise, diss, sizes}, links}
+ {{clades, count + 1, pointers, pairwise, diss, sizes, n_sizes}, links}
end
end
- {clades, count, pointers, pairwise, diss, sizes}
+ {clades, count, pointers, pairwise, diss, sizes, n_sizes}
end
defnp find_clade(pointers, i) do
diff --git a/test/scholar/cluster/hierarchical_test.exs b/test/scholar/cluster/hierarchical_test.exs
index 6c4e5d5..4511252 100644
--- a/test/scholar/cluster/hierarchical_test.exs
+++ b/test/scholar/cluster/hierarchical_test.exs
@@ -127,7 +127,6 @@ defmodule Scholar.Cluster.HierarchicalTest do
assert model.dissimilarities == Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0])
end
- @tag :skip
test "ward", %{data: data} do
model = Hierarchical.fit(data, linkage: :ward)
I have commented out Ward linkage for now; see commit 6845727ee7889d085a9d79cec948dcf3c94ed2bc.