Ripserer.jl

On keeping track of merges?

Open RaimelMedina opened this issue 1 year ago • 3 comments

Hi Matija,

I've been trying to work out how to add a feature that keeps track of which component merges with which. I have some preliminary (messy) code that does this. I guess this would only work for dim=0 homology, since merges only make sense there.

1.) Do you perhaps have any suggestions on what to improve here, and how?
2.) Would it be worth adding support for this in Ripserer, as well as for constructing/plotting the corresponding merge tree?

I haven't found anything out there that supports this, so perhaps it could be useful.

The code below is a simple modification of the zeroth_intervals function:

function zeroth_intervals_with_merges(filtration, cutoff::T=0.0; verbose=true) where T <: AbstractFloat
    F = Mod
    reps = false
    V = simplex_type(filtration, 0)
    CE = chain_element_type(V, F)
    dset = DisjointSetsWithBirth(vertices(filtration), births(filtration))
    intervals = PersistenceInterval[]
    
    to_skip   = simplex_type(filtration, 1)[]
    to_reduce = simplex_type(filtration, 1)[]
    simplices = sort!(edges(filtration))
    

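    # Birth vertex of the dying component => birth vertex of the component it merges into.
    merged_vertices = Dict()
    # Birth vertex of the dying component => persistence of its interval.
    persistence_merged_vertices = Dict()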

    if verbose
        progbar = Progress(
            length(simplices) + nv(filtration); desc="Computing 0d intervals... "
        )
    end
    for edge in simplices
        u, v = vertices(edge)
        
        i = find_root!(dset, u)
        j = find_root!(dset, v)
        if i ≠ j
            # According to the elder rule, the vertex with the higher birth will die first.
            last_vertex = birth(dset, i) > birth(dset, j) ? i : j
            int = interval(dset, filtration, last_vertex, edge, cutoff, reps)
            btime_last, bvertex_last = birth(dset, last_vertex)
        
            if !isnothing(int)
                push!(intervals, int)
                if last_vertex == i
                    btime_first, bvertex_first = birth(dset, j)
                else
                    btime_first, bvertex_first = birth(dset, i)
                end
                
                merged_vertices[simplex(filtration, Val(0), (bvertex_last,))] = simplex(filtration, Val(0), (bvertex_first,))
                persistence_merged_vertices[simplex(filtration, Val(0), (bvertex_last,))] = persistence(int)
            end
            union!(dset, i, j)
            push!(to_skip, edge)
        else
            push!(to_reduce, edge)
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    for v in vertices(filtration)
        if find_root!(dset, v) == v && !isnothing(simplex(filtration, Val(0), (v,)))
            int = interval(dset, filtration, v, nothing, cutoff, reps)
            push!(intervals, int)
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    reverse!(to_reduce)

    thresh = T(threshold(filtration))
    diagram = PersistenceDiagram(
        sort!(intervals; by=persistence);
        threshold=thresh,
        dim=0,
        field=F,
        filtration=filtration,
    )
    return merged_vertices, Ripserer.postprocess_diagram(filtration, diagram), persistence_merged_vertices
end
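
To make the elder-rule bookkeeping above easier to follow, here is a minimal self-contained sketch of the same idea on a toy weighted graph. It does not use Ripserer at all: the names toy_find_root!, merges_by_elder_rule, toy_births and toy_edges are made up for this illustration, and the cutoff handling of the real function is left out.

```julia
# Toy elder-rule merge tracking, independent of Ripserer (all names are hypothetical).

# Find the root of v's component, shortening the path along the way.
function toy_find_root!(parent::Vector{Int}, v::Int)
    while parent[v] != v
        parent[v] = parent[parent[v]]
        v = parent[v]
    end
    return v
end

# births[v] is the birth time of vertex v; edges holds (weight, u, v) tuples.
function merges_by_elder_rule(births::Vector{Float64}, edges::Vector{Tuple{Float64,Int,Int}})
    n = length(births)
    parent = collect(1:n)
    oldest = collect(1:n)          # oldest[r]: earliest-born vertex in the component rooted at r
    merged_into = Dict{Int,Int}()  # dying birth vertex => surviving birth vertex
    death = Dict{Int,Float64}()    # dying birth vertex => death time
    for (w, u, v) in sort(edges)
        i = toy_find_root!(parent, u)
        j = toy_find_root!(parent, v)
        i == j && continue         # edge closes a cycle, nothing merges
        a, b = oldest[i], oldest[j]
        # Elder rule: the component born later dies (ties broken by vertex index).
        younger, elder = (births[a], a) > (births[b], b) ? (a, b) : (b, a)
        merged_into[younger] = elder
        death[younger] = w
        parent[i] = j
        oldest[j] = elder
    end
    return merged_into, death
end

toy_births = [0.0, 1.0, 0.5, 2.0]
toy_edges = [(1.5, 1, 2), (2.2, 3, 4), (3.0, 2, 3)]
merged_into, death = merges_by_elder_rule(toy_births, toy_edges)
# merged_into == Dict(2 => 1, 4 => 3, 3 => 1)
```

Here vertex 4 dies into the component born at vertex 3, which in turn dies into the component born at vertex 1. That nesting is the information the two dictionaries in the function above are meant to capture.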

I can also share the code to construct and plot the corresponding merge tree if you think it's worthwhile or necessary.

RaimelMedina • Apr 19 '24 14:04

Hey, sorry for the inactivity. I've been very busy lately. Thanks for the implementation! I'll try it out and see if I can incorporate it into the code. If the plotting code can be written as a Plots recipe, I think it would be nice to fit it in somehow.

mtsch • Jun 03 '24 23:06

Hi! Thanks for getting back, and no worries about that. Let me know what you think about the implementation. For the merge tree, I think it would be good to store things in a better way than I did above. I'm currently plotting the tree with PyPlot, since I always have issues with plotting in Julia, but I will work on a plot recipe for this and share it here. Thanks once again.

RaimelMedina • Jun 07 '24 09:06

Hi again. Apologies for the delay in sharing the plot recipe...

This can surely be polished a lot more (efficiency, code design, naming...), but for now I just wanted to share something functional. If you give it a try, let me know whether it works for you.

using Ripserer
using PersistenceDiagrams
using ProgressMeter  # provides Progress, next! and @showprogress used below
import Plots: @recipe, @series

# Compute the 0-dimensional persistence diagram and record the merges.
function zeroth_intervals_with_merges(
    filtration::Ripserer.AbstractFiltration, 
    cutoff::T; 
    verbose=true
    ) where T <: Real

    F = Ripserer.Mod
    reps = false
    V = Ripserer.simplex_type(filtration, 0)
    CE = Ripserer.chain_element_type(V, F)
    dset = Ripserer.DisjointSetsWithBirth(vertices(filtration), Ripserer.births(filtration))
    intervals = Ripserer.PersistenceInterval[]
    
    to_skip   = Ripserer.simplex_type(filtration, 1)[]
    to_reduce = Ripserer.simplex_type(filtration, 1)[]
    simplices = sort!(Ripserer.edges(filtration))
    
    
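    # Birth vertex of the dying component => birth vertex of the component it merges into.
    merged_vertices = Dict{V, V}()
    # Birth vertex of the dying component => persistence of its interval.
    persistence_merged_vertices = Dict{V, T}()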

    if verbose
        progbar = Progress(
            length(simplices) + Ripserer.nv(filtration); desc="Computing 0d intervals... "
        )
    end
    for edge in simplices
        u, v = Ripserer.vertices(edge)
        
        i = Ripserer.find_root!(dset, u)
        j = Ripserer.find_root!(dset, v)
        if i ≠ j
            # According to the elder rule, the vertex with the higher birth will die first.
            last_vertex = Ripserer.birth(dset, i) > Ripserer.birth(dset, j) ? i : j
            int = Ripserer.interval(dset, filtration, last_vertex, edge, cutoff, reps)
            btime, bvertex = Ripserer.birth(dset, last_vertex)
        
            if !isnothing(int)
                push!(intervals, int)
                if last_vertex == i
                    btime_first, bvertex_first = Ripserer.birth(dset, j)
                else
                    btime_first, bvertex_first = Ripserer.birth(dset, i)
                end
                
                merged_vertices[simplex(filtration, Val(0), (bvertex,))] = simplex(filtration, Val(0), (bvertex_first,))
                persistence_merged_vertices[simplex(filtration, Val(0), (bvertex,))] = Ripserer.persistence(int)
            end
            Ripserer.union!(dset, i, j)
            push!(to_skip, edge)
        else
            push!(to_reduce, edge)
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    for v in vertices(filtration)
        if Ripserer.find_root!(dset, v) == v && !isnothing(Ripserer.simplex(filtration, Val(0), (v,)))
            int = Ripserer.interval(dset, filtration, v, nothing, cutoff, reps)
            push!(intervals, int)
        end
        verbose && next!(progbar; showvalues=((:n_intervals, length(intervals)),))
    end
    reverse!(to_reduce)

    sort!(intervals; by=Ripserer.persistence)
    return merged_vertices, PersistenceDiagram(intervals), persistence_merged_vertices
end

# Some utilities to sort the bars.
function get_count_of_merges(data::Dict{K, K}) where K
    dict_of_merges = Dict{K, Vector{K}}()
    data_for_hist = []
    @showprogress for k in keys(data)
        knew = data[k]
        push!(data_for_hist, knew.birth)
        if haskey(dict_of_merges, knew)
            # a new interval that dies into the interval given by knew
            push!(dict_of_merges[knew], k)
        else
            dict_of_merges[knew] = [k]
        end
    end
    return dict_of_merges, data_for_hist
end

function sort_intervals!(dict::Dict{K, Vector{K}}, k::K, birthsimplex_to_death::Dict{K, T}, data::AbstractVector{K}) where {K, T}
    push!(data, k)
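    # Bars that merge into k (empty if nothing merges into it).
    temp = get(dict, k, K[])

    # Process them in order of increasing death time.
    death_times = [birthsimplex_to_death[x] for x in temp]
    temp = temp[sortperm(death_times)]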
    
    for kk ∈ temp
        if haskey(dict, kk)
            sort_intervals!(dict, kk, birthsimplex_to_death, data)
        else
            push!(data, kk)
        end
    end
end

function get_ordered_bars(data_intervals::PersistenceDiagram, data_merges::Dict{K, K}) where K
    dict_counts, _ = get_count_of_merges(data_merges)
    dict_birth_to_death = Dict(birth_simplex(x) => x.death for x ∈ data_intervals)
    
    infinite_bar = birth_simplex(data_intervals[end])
    
    sorted_bars = K[]
    sort_intervals!(dict_counts, infinite_bar, dict_birth_to_death, sorted_bars)
    
    return sorted_bars, dict_birth_to_death
end

mutable struct BarCoord{T}
    xcoord::T
    with_inf_bar::Bool
    xmerge::Union{Nothing, T}
    ylim::Tuple{T, T}
end

# Convenience constructor: a bar at x = 0 that merges into nothing.
BarCoord(ylim::Tuple{T, T}) where T = BarCoord{T}(zero(T), false, nothing, ylim)

function get_bar_coords(
    data_bars::Vector{S}, 
    merges::Dict{S, S},
    simplex_to_death::Dict{S, T}, 
    max_val::Union{T, Nothing}
    ) where {S, T}
    
    δx = 0.1
    _x_neg = 0.
    _x_pos = 0.

    pos = [:left, :right]

    bar_to_coords = Dict{S, BarCoord{T}}()
    if isnothing(max_val)
        max_val = filter(isfinite, collect(values(simplex_to_death))) |> maximum
    end

    for bar::S in data_bars
        if simplex_to_death[bar] == T(Inf)
            bar_to_coords[bar] = BarCoord{T}(
                T(0),
                false,
                nothing,
                (birth(bar), max_val)
            )
        elseif simplex_to_death[merges[bar]] == Inf # it merges with the infinite bar
            if pos[1] == :left
                _x_neg -= δx
                bar_to_coords[bar] = BarCoord{T}(_x_neg, true, T(0), (birth(bar), simplex_to_death[bar]))
                circshift!(pos, 1)
            else
                _x_pos += δx
                bar_to_coords[bar] = BarCoord{T}(_x_pos, true, T(0), (birth(bar), simplex_to_death[bar]))
                circshift!(pos, 1)
            end
        else # it merges with another bar different from the infinite bar
            merges_with = merges[bar]
            if bar_to_coords[merges_with].xcoord < 0
                _x_neg -= δx
                bar_to_coords[bar] = BarCoord{T}(
                    _x_neg, 
                    false, 
                    bar_to_coords[merges_with].xcoord,
                    (birth(bar), simplex_to_death[bar])
                )
            else
                _x_pos += δx
                bar_to_coords[bar] = BarCoord{T}(
                    _x_pos, 
                    false, 
                    bar_to_coords[merges_with].xcoord,
                    (birth(bar), simplex_to_death[bar])
                )
            end
        end
    end
    return bar_to_coords
end


@recipe function f(
    diagrams::PersistenceDiagram, 
    merges::Dict{S, S}, 
    max_death_time::Union{Nothing, T}
    ) where {T, S<:Ripserer.AbstractCell}

    data_bars, from_simplex_to_death = get_ordered_bars(diagrams, merges)
    bar_coordinates = get_bar_coords(data_bars, merges, from_simplex_to_death, max_death_time)
    color --> :black
    xticks --> []

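    # Draw each bar as a vertical segment; every finite bar also gets a horizontal
    # connector, at its death time, to the bar it merges into.
    for k in keys(bar_coordinates)
        if isnothing(bar_coordinates[k].xmerge) # the infinite bar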
            @series begin
                seriestype := :line
                linewidth --> 1
                label --> nothing
                xs = [bar_coordinates[k].xcoord, bar_coordinates[k].xcoord]
                ys = collect(bar_coordinates[k].ylim)
                xs, ys
            end
        else
            @series begin
                seriestype := :line
                linewidth --> 1
                label --> nothing
                xs = [bar_coordinates[k].xcoord, bar_coordinates[k].xcoord]
                ys = collect(bar_coordinates[k].ylim)
                xs, ys
            end
            @series begin
                seriestype := :line
                linewidth --> 1
                label --> nothing
                xs = sort([bar_coordinates[k].xcoord, bar_coordinates[k].xmerge])
                ys = [bar_coordinates[k].ylim[2], bar_coordinates[k].ylim[2]]
                xs, ys
            end
        end
    end
end

With this we can plot a disconnectivity graph/merge tree as follows:

using Ripserer
using Plots
data = rand(4, 4, 4, 4);
max_val = maximum(data)

merges, diagrams, _ = zeroth_intervals_with_merges(Cubical(data), 0.)
Plots.plot(diagrams, merges, max_val)
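
For anyone who wants to follow the ordering and layout logic without running Ripserer, here is a small sketch that mimics what sort_intervals! and get_bar_coords do, using integers in place of birth simplices. The toy data and the names children_of, order_bars! and assign_x are made up for this illustration, and the sketch assumes a single infinite bar (bar 1).

```julia
# Toy merge tree; integers stand in for birth simplices (hypothetical data).
toy_merges = Dict(2 => 1, 3 => 1, 4 => 3)                  # bar => bar it merges into
toy_death = Dict(1 => Inf, 2 => 1.5, 3 => 3.0, 4 => 2.2)   # bar => death time

# Invert the child => parent map into parent => children.
function children_of(merges::Dict{Int,Int})
    children = Dict{Int,Vector{Int}}()
    for (child, parent) in merges
        push!(get!(children, parent, Int[]), child)
    end
    return children
end

# Depth-first ordering: a bar first, then the bars merging into it,
# visited by increasing death time.
function order_bars!(out::Vector{Int}, children, death, k::Int)
    push!(out, k)
    for c in sort(get(children, k, Int[]); by=c -> death[c])
        order_bars!(out, children, death, c)
    end
    return out
end

# Horizontal layout: the root sits at x = 0, bars merging into the root
# alternate left/right, and every other bar goes on the side of its parent.
function assign_x(order, merges, root; δx=0.1)
    x = Dict(root => 0.0)
    left, right = 0.0, 0.0
    go_left = true
    for k in order
        k == root && continue
        parent = merges[k]
        if parent == root
            on_left = go_left
            go_left = !go_left
        else
            on_left = x[parent] < 0   # parents always precede children in `order`
        end
        if on_left
            left -= δx
            x[k] = left
        else
            right += δx
            x[k] = right
        end
    end
    return x
end

root = 1
order = order_bars!(Int[], children_of(toy_merges), toy_death, root)
xs = assign_x(order, toy_merges, root)
```

On this toy input the bars are visited as 1, 2, 3, 4: bar 2 goes to the left of the infinite bar, bar 3 to the right, and bar 4 is placed further right because it merges into bar 3.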

RaimelMedina • Aug 27 '24 11:08