
Returning node positions from `xgi.draw`

Open • tlarock opened this issue 1 year ago • 3 comments

I want to use the same node positions to draw a few different hypergraphs on the same set of nodes. My lazy approach was to draw the first hypergraph with `xgi.draw`, get the node positions from its return value, and reuse them in the subsequent `draw` calls. However, getting the positions was surprisingly difficult, even though `xgi.draw` computes them and returns something called `node_collection`. The difficulty is that this object is a matplotlib `PathCollection`, so it doesn't hold `pos` with the same semantics (at least as far as I can tell). It might be possible to recover them with `PathCollection.get_offsets()` or similar; I'm not sure.
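For what it's worth, `get_offsets()` does work for this in plain matplotlib: a scatter plot returns a `PathCollection` whose offsets are the marker centers in data coordinates. The sketch below is matplotlib-only and uses made-up positions. It assumes, without checking xgi's source, that the node collection contains one marker per node in the order of `H.nodes`, so the i-th offset can be paired with the i-th node.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical positions standing in for a layout computed by xgi.
nodes = [1, 2, 3, 4]
pos = {1: (0.0, 0.0), 2: (1.0, 0.5), 3: (-0.5, 1.0), 4: (0.25, -1.0)}

fig, ax = plt.subplots()
xy = np.array([pos[n] for n in nodes])
node_collection = ax.scatter(xy[:, 0], xy[:, 1])  # returns a PathCollection

# get_offsets() gives an (N, 2) array of marker centers in data coordinates.
offsets = node_collection.get_offsets()

# ASSUMPTION: markers were drawn in the same order as `nodes`.
recovered = {n: (float(x), float(y)) for n, (x, y) in zip(nodes, offsets)}
plt.close(fig)
```

If that ordering assumption does not hold for xgi's node collection, pre-computing `pos` and passing it to every call avoids the question entirely.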

This is not urgent at all, because really I should compute the layout first and then use the same positions throughout, which is also what the example in the documentation points towards. I'm not trying to encourage bad design, but it might still be nice to return `pos` explicitly for convenience. I think the main question is whether adding another return value would break other code or be otherwise inconvenient.

May relate to #280 in the future.

tlarock · Feb 15 '24 15:02

Thanks Tim! Yeah, my first thought is exactly what you wrote: if you need to reuse the positions, it's best to pre-compute them outside of `draw` and pass them as an argument. The layout functions even take a random seed, so you could have the same positions across different scripts.
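A minimal, library-free sketch of why a seed makes positions shareable: the hypothetical `toy_layout` below stands in for a seeded layout function such as `xgi.barycenter_spring_layout(H, seed=...)`. The layout is computed once over the union of all nodes, checked to cover every hypergraph, and reproduced exactly by a second call with the same seed, which is what would happen in a different script.

```python
import random


def toy_layout(nodes, seed=None):
    """Hypothetical stand-in for a seeded layout: random points in [0, 1)^2."""
    rng = random.Random(seed)  # private generator; the global RNG is untouched
    # Sorting fixes which random draw goes to which node (assumes sortable nodes).
    return {n: (rng.random(), rng.random()) for n in sorted(nodes)}


# Several hypergraphs (as edge lists) on overlapping node sets.
hypergraphs = {
    "H1": [[1, 2, 3], [3, 4]],
    "H2": [[1, 4], [2, 3, 4, 5]],
    "H3": [[5, 1]],
}

# Lay out the union of all nodes once...
all_nodes = {n for edges in hypergraphs.values() for e in edges for n in e}
pos = toy_layout(all_nodes, seed=2)

# ...and confirm every hypergraph's nodes have a position before reusing `pos`.
missing = {
    name: {n for e in edges for n in e} - pos.keys()
    for name, edges in hypergraphs.items()
}

# Same seed, same positions: this is what makes them reproducible across scripts.
same_again = toy_layout(all_nodes, seed=2)
```

With xgi, the analogous steps would be a seeded layout call on a hypergraph containing all the nodes, followed by `xgi.draw(..., pos=pos, ax=ax)` for each hypergraph.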

About returning the positions, we can certainly think about it. I'm hesitant because we already return many things (1 axis and 3 collections), though we could certainly add it without breaking existing code. Let's talk with the others.
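One way to add the positions without breaking existing callers is an opt-in flag, so the default return shape is unchanged. The sketch below is hypothetical: `draw_sketch`, `return_pos`, and the placeholder layout and "collections" are illustrations only, not xgi's API, and the default return is modeled loosely as an axis plus a tuple of three collections, following the comment above.

```python
def default_layout(nodes):
    """Placeholder layout: nodes evenly spaced along the x-axis."""
    return {n: (float(i), 0.0) for i, n in enumerate(sorted(nodes))}


def draw_sketch(nodes, pos=None, ax="ax", return_pos=False):
    """Hypothetical draw() with an opt-in `return_pos` flag.

    The "collections" are placeholder strings; a real implementation
    would return matplotlib artists.
    """
    if pos is None:
        pos = default_layout(nodes)  # layout computed internally when not given
    collections = ("node_collection", "dyad_collection", "edge_collection")
    if return_pos:
        return ax, collections, pos  # new shape, only when asked for
    return ax, collections  # default shape, unchanged


# Existing call sites keep unpacking two values...
ax, collections = draw_sketch([3, 1, 2])
# ...while callers who want the layout request it explicitly.
ax, collections, pos = draw_sketch([3, 1, 2], return_pos=True)
```

The cost is one more parameter to document and test, which is the clutter concern raised in the next comment.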

maximelucas · Feb 16 '24 09:02

I'm against returning positions. I think it would unnecessarily clutter the code. Maybe we can make a recipe for this? Note that the code for my recent paper does exactly what is described here.

nwlandry · Apr 03 '24 19:04

What about this recipe lifted from the example I mentioned? My worry is that it is too big for a recipe.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import xgi
from matplotlib import gridspec

links = [[1, 2], [1, 3], [5, 6], [1, 7]]
triangles = [[3, 5, 7], [2, 7, 1], [6, 10, 15]]
squares = [[7, 8, 9, 10]]
pentagons = [[1, 11, 12, 13, 14]]
edges = links + triangles + squares + pentagons

H = xgi.Hypergraph(edges)

# Compute the layout once, on the full hypergraph; every panel below reuses it,
# so each node stays in the same place across all drawings.
pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=3))

link_color = "#000000"
triangle_color = "#648FFF"
square_color = "#785EF0"
pentagon_color = "#DC267F"
colors = [link_color, triangle_color, square_color, pentagon_color]


def color_edges(H):
    # Face colors for hyperedges of order > 1 (size >= 3), indexed by size - 2.
    return [colors[i - 2] for i in H.edges.filterby("order", 1, "gt").size.aslist()]


# Every edge size present in H (here k = 2, 3, 4, 5).
filtering_parameters = np.arange(
    H.edges.size.min(), H.edges.size.max() + 1, 1, dtype=int
)

# One row per filtering: keep the edges whose size is ==, >=, <=, or != k.
uniform_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "eq")).copy()
    for k in filtering_parameters
]
geq_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "geq")).copy()
    for k in filtering_parameters
]
leq_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "leq")).copy()
    for k in filtering_parameters
]
exclusion_filtering = [
    xgi.subhypergraph(H, edges=H.edges.filterby("size", k, "neq")).copy()
    for k in filtering_parameters
]
filterings = [uniform_filtering, geq_filtering, leq_filtering, exclusion_filtering]

fig = plt.figure(layout="constrained", figsize=(8, 4))

# Left third: the full hypergraph; right two thirds: a 4 x 4 grid of filtered panels.
gs_leftright = gridspec.GridSpec(1, 3, figure=fig, wspace=0.075)
gs_panels = gridspec.GridSpecFromSubplotSpec(4, 4, subplot_spec=gs_leftright[1:])

ax_left = fig.add_subplot(gs_leftright[0])
xgi.draw(
    H,
    pos=pos,
    ax=ax_left,
    edge_fc=color_edges(H),
    node_size=7,
    node_lw=0.5,
    dyad_lw=0.75,
    alpha=1,
)

labels = [r"$H_{(=, k)}$", r"$H_{(\geq, k)}$", r"$H_{(\leq, k)}$", r"$H_{(\neq, k)}$"]

for i, j in itertools.product(range(4), repeat=2):
    ax = fig.add_subplot(gs_panels[i, j])
    # Same `pos` as the full hypergraph, so nodes line up across all panels.
    xgi.draw(
        filterings[i][j],
        pos=pos,
        ax=ax,
        node_size=4,
        dyad_lw=0.75,
        node_lw=0.5,
        edge_fc=color_edges(filterings[i][j]),
        alpha=1,
    )
    if i == 0:
        ax.set_title(rf"$k={j + 2}$")
    if j == 0:
        # Row label; (-3.5, 0) is in data coordinates, so it depends on the layout's extent.
        ax.text(-3.5, 0, labels[i], fontsize=16)

plt.show()
```

nwlandry · Apr 09 '24 15:04
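To make the 4 x 4 grid in the recipe above concrete, here is a plain-Python sketch (no xgi) of how many edges each panel keeps, mirroring the `"eq"`, `"geq"`, `"leq"`, and `"neq"` modes passed to `H.edges.filterby("size", k, mode)`. The `filter_by_size` helper is hypothetical and only reproduces those size comparisons on the recipe's edge list.

```python
import operator

links = [[1, 2], [1, 3], [5, 6], [1, 7]]
triangles = [[3, 5, 7], [2, 7, 1], [6, 10, 15]]
squares = [[7, 8, 9, 10]]
pentagons = [[1, 11, 12, 13, 14]]
edges = links + triangles + squares + pentagons

# Comparison used by each row of the grid (same row order as `filterings`).
modes = {"eq": operator.eq, "geq": operator.ge, "leq": operator.le, "neq": operator.ne}


def filter_by_size(edges, k, mode):
    """Hypothetical helper: keep the edges whose size compares to k as `mode` says."""
    compare = modes[mode]
    return [e for e in edges if compare(len(e), k)]


# Columns of the grid: every edge size present (k = 2, 3, 4, 5).
sizes = range(min(map(len, edges)), max(map(len, edges)) + 1)

# Number of edges drawn in each panel: rows are modes, columns are k.
grid = {mode: [len(filter_by_size(edges, k, mode)) for k in sizes] for mode in modes}
```

For example, the top-left panel (`"eq"`, k = 2) shows only the four links, and the `"geq"` row drops from all 9 edges at k = 2 to just the pentagon at k = 5.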