Ripserer.jl
On keeping track of merges?
Hi Matija,
I've been trying to add a feature that keeps track of which component merged with which. I have some preliminary (messy) code that does this. I guess this only works for dim=0 homology, since that is the only dimension where such merges make sense.
1) Do you have any suggestions on how to improve this? 2) Would it be worth adding support for this in Ripserer, as well as for constructing/plotting the corresponding merge tree?
I haven't found anything out there that supports this, so it could be useful.
The code below is a simple modification of Ripserer's `zeroth_intervals` function:
"""
    zeroth_intervals_with_merges(filtration, cutoff=0.0; verbose=true)

Compute the 0-dimensional persistence intervals of `filtration` and record, for every
kept finite interval, which component it merged into.

This is a modification of `zeroth_intervals`. Edges are processed in sorted order with a
union-find structure over the vertices. When an edge joins two components, the one with
the later birth dies (elder rule).

# Arguments
- `filtration`: the filtration to process.
- `cutoff`: passed to `interval`. Intervals for which `interval` returns `nothing`
  (presumably those with persistence below `cutoff` — confirm) are dropped, and their
  merges are not recorded. The default is a literal `0.0`, so the static parameter `T`
  is bound when the argument is omitted.
- `verbose`: show a progress bar.

# Returns
A tuple `(merged_vertices, diagram, persistence_merged_vertices)`:
- `merged_vertices` maps the birth vertex of each dying component to the birth vertex of
  the component it merged into.
- `diagram` is the post-processed 0-dimensional `PersistenceDiagram`.
- `persistence_merged_vertices` maps the same birth vertices to the persistence of their
  interval.
"""
function zeroth_intervals_with_merges(
    filtration, cutoff::T=0.0; verbose=true
) where {T<:AbstractFloat}
    F = Mod
    reps = false
    V = simplex_type(filtration, 0)
    dset = DisjointSetsWithBirth(vertices(filtration), births(filtration))
    intervals = PersistenceInterval[]
    sorted_edges = sort!(edges(filtration))
    merged_vertices = Dict{V,V}()
    persistence_merged_vertices = Dict{V,T}()
    if verbose
        progbar = Progress(
            length(sorted_edges) + nv(filtration); desc="Computing 0d intervals... "
        )
    end
    for edge in sorted_edges
        u, v = vertices(edge)
        i = find_root!(dset, u)
        j = find_root!(dset, v)
        if i ≠ j
            # According to the elder rule, the vertex with the higher birth will die first.
            last_vertex, first_vertex = birth(dset, i) > birth(dset, j) ? (i, j) : (j, i)
            int = interval(dset, filtration, last_vertex, edge, cutoff, reps)
            if !isnothing(int)
                push!(intervals, int)
                # `birth(dset, root)` is destructured as (birth time, birth vertex).
                _, bvertex_last = birth(dset, last_vertex)
                _, bvertex_first = birth(dset, first_vertex)
                dying = simplex(filtration, Val(0), (bvertex_last,))
                merged_vertices[dying] = simplex(filtration, Val(0), (bvertex_first,))
                persistence_merged_vertices[dying] = persistence(int)
            end
            union!(dset, i, j)
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    # Every remaining root that is present in the filtration is an essential class.
    for v in vertices(filtration)
        if find_root!(dset, v) == v && !isnothing(simplex(filtration, Val(0), (v,)))
            push!(intervals, interval(dset, filtration, v, nothing, cutoff, reps))
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    thresh = T(threshold(filtration))
    diagram = PersistenceDiagram(
        sort!(intervals; by=persistence);
        threshold=thresh,
        dim=0,
        field=F,
        filtration=filtration,
    )
    return merged_vertices,
    Ripserer.postprocess_diagram(filtration, diagram),
    persistence_merged_vertices
end
I can also share the code to construct and plot the corresponding merge tree if you think it's worthwhile.
Hey, sorry for the inactivity. I've been very busy lately. Thanks for the implementation! I'll try it out and see if I can incorporate it into the code. If the plotting code can be written as a Plots recipe, I think it would be nice to fit it in somehow.
Hi! Thanks for getting back to me, and no worries. Let me know what you think about the implementation. For the merge tree, I think it would be better to store the merge information in a cleaner way than I did above. I'm currently plotting the tree with PyPlot, since I always have issues with plotting in Julia, but I will work on a plot recipe for this and share it here. Thanks once again!
Hi again. Apologies for the delay in sharing the plot recipe...
This can surely be polished a lot more (efficiency, code design, naming, ...), but for now I just wanted to share something functional. If you give it a try, let me know whether it works for you.
using Ripserer
using PersistenceDiagrams
import Plots: @recipe, @series
# Compute the 0-dimensional persistence diagram and record the merges #
"""
    zeroth_intervals_with_merges(filtration, cutoff; verbose=true)

Compute the 0-dimensional persistence diagram of `filtration` and record which component
every dying component merged into.

Edges are processed in sorted order with a union-find structure over the vertices. When
an edge joins two components, the one with the later birth dies (elder rule).

# Arguments
- `filtration`: any `Ripserer.AbstractFiltration`.
- `cutoff`: passed to `Ripserer.interval`. Intervals for which it returns `nothing`
  (presumably those with persistence below `cutoff` — confirm in Ripserer) are dropped,
  and their merges are not recorded.
- `verbose`: show a progress bar.

# Returns
`(merged_vertices, diagram, persistence_merged_vertices)`:
- `merged_vertices` maps the birth vertex of each dying component to the birth vertex of
  the component it merged into.
- `diagram` is a `PersistenceDiagram` of the intervals, sorted by persistence.
- `persistence_merged_vertices` maps the same birth vertices to the persistence of their
  interval.
"""
function zeroth_intervals_with_merges(
    filtration::Ripserer.AbstractFiltration,
    cutoff::T;
    verbose=true
) where T <: Real
    reps = false
    V = Ripserer.simplex_type(filtration, 0)
    dset = Ripserer.DisjointSetsWithBirth(vertices(filtration), Ripserer.births(filtration))
    intervals = Ripserer.PersistenceInterval[]
    sorted_edges = sort!(Ripserer.edges(filtration))
    merged_vertices = Dict{V, V}()
    # `float(T)` rather than `T`: persistence values are generally fractional, so an
    # integer `cutoff` (allowed by `T <: Real`) must not force an integer value type,
    # which would throw an InexactError on insertion.
    persistence_merged_vertices = Dict{V, float(T)}()
    # NOTE(review): ProgressMeter is not imported by this script, so `Progress`/`next!`
    # are reached through Ripserer. This assumes Ripserer loads ProgressMeter — confirm.
    if verbose
        progbar = Ripserer.Progress(
            length(sorted_edges) + Ripserer.nv(filtration); desc="Computing 0d intervals... "
        )
    end
    for edge in sorted_edges
        u, v = Ripserer.vertices(edge)
        i = Ripserer.find_root!(dset, u)
        j = Ripserer.find_root!(dset, v)
        if i ≠ j
            # According to the elder rule, the vertex with the higher birth will die first.
            last_vertex, first_vertex =
                Ripserer.birth(dset, i) > Ripserer.birth(dset, j) ? (i, j) : (j, i)
            int = Ripserer.interval(dset, filtration, last_vertex, edge, cutoff, reps)
            if !isnothing(int)
                push!(intervals, int)
                # `birth(dset, root)` is destructured as (birth time, birth vertex).
                _, bvertex = Ripserer.birth(dset, last_vertex)
                _, bvertex_first = Ripserer.birth(dset, first_vertex)
                dying = simplex(filtration, Val(0), (bvertex,))
                merged_vertices[dying] = simplex(filtration, Val(0), (bvertex_first,))
                persistence_merged_vertices[dying] = Ripserer.persistence(int)
            end
            Ripserer.union!(dset, i, j)
        end
        verbose && Ripserer.next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    # Every remaining root that is present in the filtration is an essential class.
    for v in vertices(filtration)
        if Ripserer.find_root!(dset, v) == v &&
           !isnothing(Ripserer.simplex(filtration, Val(0), (v,)))
            push!(intervals, Ripserer.interval(dset, filtration, v, nothing, cutoff, reps))
        end
        verbose && Ripserer.next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    sort!(intervals; by=Ripserer.persistence)
    return merged_vertices, PersistenceDiagram(intervals), persistence_merged_vertices
end
# some utilities to sort the bars #
"""
    get_count_of_merges(data)

Invert the merge map `data`, which maps each dying cell to the cell it merged into.

Returns `(dict_of_merges, data_for_hist)`:
- `dict_of_merges[parent]` lists every cell that merged into `parent`, in the iteration
  order of `data`.
- `data_for_hist` holds the `birth` field of each merge target, one entry per element of
  `data`, in the same iteration order.
"""
function get_count_of_merges(data::Dict{K, K}) where K
    dict_of_merges = Dict{K, Vector{K}}()
    for (child, parent) in data
        # A new interval that dies into the interval given by `parent`.
        push!(get!(Vector{K}, dict_of_merges, parent), child)
    end
    # A comprehension infers a concrete element type instead of `Vector{Any}`.
    # `values` iterates in the same order as the loop above.
    data_for_hist = [parent.birth for parent in values(data)]
    return dict_of_merges, data_for_hist
end
"""
    sort_intervals!(dict, k, birthsimplex_to_death, data)

Append `k` and, transitively, every cell that merged into it to `data`, in depth-first
preorder.

Each node appears before the cells that merged into it. The children of a node are
visited in order of increasing death (`birthsimplex_to_death`); ties keep the order of
`dict[node]`.

- `dict`: the inverted merge map (parent => children).
- A node that is not a key of `dict` is treated as having no children. This includes `k`
  itself, so a root without merges no longer throws a `KeyError`.

Implemented with an explicit stack, so long merge chains cannot overflow the call stack.
Returns `data`.
"""
function sort_intervals!(
    dict::Dict{K, Vector{K}}, k::K, birthsimplex_to_death::Dict{K, T}, data::AbstractVector{K}
) where {K, T}
    stack = K[k]
    while !isempty(stack)
        node = pop!(stack)
        push!(data, node)
        children = get(dict, node, nothing)
        isnothing(children) && continue
        # Sort ascending by death (sortperm is stable), then reverse before pushing so the
        # earliest-dying child is popped first. This reproduces the recursive order.
        deaths = [birthsimplex_to_death[child] for child in children]
        append!(stack, reverse!(children[sortperm(deaths)]))
    end
    return data
end
"""
    get_ordered_bars(data_intervals, data_merges)

Order the bars of a 0-dimensional diagram by a depth-first traversal of the merge tree.

The traversal is rooted at the bar with the largest death, i.e. the infinite bar. When
several bars tie, the last one is used, which matches `data_intervals[end]` for a diagram
sorted by persistence.

Returns `(sorted_bars, dict_birth_to_death)`:
- `sorted_bars` lists the birth simplices in traversal order.
- `dict_birth_to_death` maps each interval's birth simplex to its death.

Bars not reachable from the root (e.g. further infinite bars of a disconnected
filtration) are not included. Throws `ArgumentError` for an empty diagram.
"""
function get_ordered_bars(data_intervals::PersistenceDiagram, data_merges::Dict{K, K}) where K
    isempty(data_intervals) &&
        throw(ArgumentError("cannot order the bars of an empty persistence diagram"))
    dict_counts, _ = get_count_of_merges(data_merges)
    dict_birth_to_death = Dict(birth_simplex(x) => x.death for x ∈ data_intervals)
    # Choose the root by its death instead of its position, so the diagram need not be
    # sorted for this to find the infinite bar.
    deaths = [x.death for x in data_intervals]
    root_index = findlast(==(maximum(deaths)), deaths)
    infinite_bar = birth_simplex(data_intervals[root_index])
    sorted_bars = K[]
    sort_intervals!(dict_counts, infinite_bar, dict_birth_to_death, sorted_bars)
    return sorted_bars, dict_birth_to_death
end
"""
    BarCoord{T}

Plot coordinates of one bar of a merge tree.

# Fields
- `xcoord`: horizontal position of the bar's vertical segment.
- `with_inf_bar`: `true` when the bar merges directly into the infinite bar.
- `xmerge`: horizontal position of the bar it merges into, or `nothing` for a bar with no
  merge (the infinite bar).
- `ylim`: vertical extent `(bottom, top)` of the bar.
"""
mutable struct BarCoord{T}
    xcoord::T
    with_inf_bar::Bool
    xmerge::Union{Nothing, T}
    ylim::Tuple{T, T}
end

# Convenience constructor for a bar with no merge information: it is placed at x = 0,
# `with_inf_bar` is `false`, and `xmerge` is `nothing`. All four fields are passed so the
# default constructor actually matches.
BarCoord(ylim::Tuple{T, T}) where T = BarCoord{T}(zero(T), false, nothing, ylim)
"""
    get_bar_coords(data_bars, merges, simplex_to_death, max_val)

Assign plot coordinates to every bar of the merge tree.

# Arguments
- `data_bars`: the bars in an order where each bar comes after the bar it merges into,
  as produced by `get_ordered_bars`.
- `merges`: maps each dying bar to the bar it merges into.
- `simplex_to_death`: maps each bar to its death.
- `max_val`: top of the infinite bar. If `nothing`, the largest finite death is used.

# Layout
- The infinite bar is drawn at x = 0, from its birth up to `max_val`.
- Bars that merge into the infinite bar are placed alternately left and right of it.
- Every other bar goes on the same side as the bar it merges into, one step of `0.1`
  further out.

Returns a `Dict` mapping each bar to its `BarCoord`.
"""
function get_bar_coords(
    data_bars::Vector{S},
    merges::Dict{S, S},
    simplex_to_death::Dict{S, T},
    max_val::Union{T, Nothing}
) where {S, T}
    δx = 0.1
    x_neg = 0.0
    x_pos = 0.0
    # Side for the next bar that merges into the infinite bar. A Bool toggle replaces the
    # two-argument in-place `circshift!`, which is unavailable on older Julia releases.
    place_left = true
    bar_to_coords = Dict{S, BarCoord{T}}()
    if isnothing(max_val)
        max_val = maximum(filter(isfinite, collect(values(simplex_to_death))))
    end
    for bar in data_bars
        death = simplex_to_death[bar]
        if death == T(Inf)
            # The infinite bar: centered, no merge connector.
            bar_to_coords[bar] = BarCoord{T}(T(0), false, nothing, (birth(bar), max_val))
            continue
        end
        parent = merges[bar]
        if simplex_to_death[parent] == Inf
            # Merges into the infinite bar: alternate sides to spread the tree out.
            goes_left = place_left
            place_left = !place_left
            with_inf = true
            xmerge = T(0)
        else
            # Merges into another bar, which was already placed because of the ordering.
            xmerge = bar_to_coords[parent].xcoord
            goes_left = xmerge < 0
            with_inf = false
        end
        if goes_left
            x_neg -= δx
            x = x_neg
        else
            x_pos += δx
            x = x_pos
        end
        bar_to_coords[bar] = BarCoord{T}(x, with_inf, xmerge, (birth(bar), death))
    end
    return bar_to_coords
end
# Plot recipe for a merge tree (disconnectivity graph) built from a 0-dimensional diagram
# and its merge map.
#
# Each bar becomes a vertical segment spanning its `ylim`. Every bar that merges into
# another also gets a horizontal connector at its top, joining it to the merge target's
# x position. `max_death_time` caps the infinite bar (`nothing` = largest finite death).
@recipe function f(
    diagrams::PersistenceDiagram,
    merges::Dict{S, S},
    max_death_time::Union{Nothing, T}
) where {T, S<:Ripserer.AbstractCell}
    ordered_bars, death_of = get_ordered_bars(diagrams, merges)
    layout = get_bar_coords(ordered_bars, merges, death_of, max_death_time)
    color --> :black
    xticks --> []
    for coord in values(layout)
        xbar = coord.xcoord
        bottom, top = coord.ylim
        # Vertical segment: the bar's lifetime.
        @series begin
            seriestype := :line
            linewidth --> 1
            label --> nothing
            [xbar, xbar], [bottom, top]
        end
        # Horizontal connector; the infinite bar (xmerge === nothing) has none.
        if !isnothing(coord.xmerge)
            @series begin
                seriestype := :line
                linewidth --> 1
                label --> nothing
                sort([xbar, coord.xmerge]), [top, top]
            end
        end
    end
end
With this, we can plot a disconnectivity graph (merge tree) as follows:
using Ripserer
using Plots
# A random 4-dimensional image. Its 0-dimensional merge tree is drawn with the infinite
# bar extending up to the largest pixel value.
data = rand(4, 4, 4, 4);
max_val = maximum(data)
filtration = Cubical(data)
merges, diagrams, _ = zeroth_intervals_with_merges(filtration, 0.0)
plot(diagrams, merges, max_val)