Arviz plots not being displayed

Open lborcard opened this issue 10 months ago • 5 comments

Describe the bug

I am using the ArviZ library with PyMC, and the plots are not being displayed. All I see is the axes info, not the plots themselves.

Environment

{
  "marimo": "0.3.7",
  "OS": "Darwin",
  "OS Version": "23.3.0",
  "Processor": "i386",
  "Python Version": "3.8.3",
  "Binaries": {
    "Browser": "123.0.6312.87",
    "Node": "--"
  },
  "Requirements": {
    "click": "8.1.7",
    "importlib-resources": "6.1.1",
    "jedi": "0.19.1",
    "markdown": "3.6",
    "pymdown-extensions": "10.7.1",
    "pygments": "2.17.2",
    "tomlkit": "0.12.4",
    "uvicorn": "0.29.0",
    "starlette": "0.37.2",
    "websocket": "missing",
    "typing-extensions": "4.8.0",
    "black": "24.3.0"
  }
}

Code to reproduce

import marimo

__generated_with = "0.3.7"
app = marimo.App(width="medium")


@app.cell
def __():
    import marimo as mo
    return mo,


@app.cell
def __():
    import pymc as pm
    import numpy as np
    import pandas as pd
    return np, pd, pm


@app.cell
def __(np):
    seed = 134
    rng = np.random.default_rng(seed=seed)
    return rng, seed


@app.cell
def __(mo):
    mo.md(
        "We create a series of kmers that could be used as preferential cutting sites"
    )
    return


@app.cell
def __():
    from random import choices

    n = 0
    N = 4**4

    kmers = set()

    while n < N:
        kmer = "".join(choices(k=4, population="ATCG"))
        kmers.add(kmer)
        n = len(kmers)
    return N, choices, kmer, kmers, n


@app.cell
async def __(choices, kmers, np, rng):
    k = 7
    subkmers = choices(population=list(kmers), k=k)

    frac = np.array([0.2, 0.1, 0.1, 0.1, 0.4, 0.05, 0.05])
    # compare with a tolerance: a sum of floats is rarely exactly 1
    if np.isclose(np.sum(frac), 1) and k == len(frac):
        print("OK")
    else:
        print("Error")

    true_c = 3

    # number of reads
    readsLen = 10
    reads = np.arange(readsLen)

    true_p = rng.dirichlet(size=readsLen, alpha=frac * true_c)

    ## get the counts
    totalcount = 34
    obs_counts = np.vstack(
        [rng.multinomial(n=totalcount, pvals=pi) for pi in true_p]
    )
    obs_counts
    return (
        frac,
        k,
        obs_counts,
        reads,
        readsLen,
        subkmers,
        totalcount,
        true_c,
        true_p,
    )


@app.cell
def __(obs_counts, true_p):
    print(
        f"shape of true_p {true_p.shape} \nshape of obs_count {obs_counts.shape}"
    )
    obs_counts.shape
    return


@app.cell
def __(k, np, obs_counts, pm, reads, subkmers, totalcount):
    coords = {"reads": reads, "subkmers": subkmers}

    with pm.Model(coords=coords) as multidirimarg:
        fraction = pm.Dirichlet("frac", a=np.ones(k), dims="subkmers")
        conc = pm.Lognormal("conc", mu=1, sigma=1)
        counts = pm.DirichletMultinomial(
            "counts",
            n=totalcount,
            a=fraction * conc,
            observed=obs_counts,
            dims=("reads", "subkmers"),
        )

    pm.model_to_graphviz(multidirimarg)
    with multidirimarg:
        trace_dm_marginalized = pm.sample(chains=4)
    return (
        conc,
        coords,
        counts,
        fraction,
        multidirimarg,
        trace_dm_marginalized,
    )


@app.cell
def __(trace_dm_marginalized):
    import arviz as az

    trace_plot = az.plot_trace(
        data=trace_dm_marginalized, var_names=["frac", "conc"]
    )
    trace_plot
    return az, trace_plot


@app.cell
def __():
    import matplotlib.pyplot as plt

    plt.plot([1, 2])
    # plt.gca() gets the current `Axes`
    plt.gca()
    return plt,


if __name__ == "__main__":
    app.run()

lborcard avatar Mar 30 '24 10:03 lborcard

Adding support for arviz would involve adding a formatter, here: https://github.com/marimo-team/marimo/tree/main/marimo/_output/formatters

This is a good first issue. For anyone interested, you can look at the other formatters for examples.

akshayka avatar Mar 30 '24 16:03 akshayka

Working on this. I will try to resolve it, using the formatters linked above as a reference.

Haleshot avatar Jun 19 '24 09:06 Haleshot

awesome, thanks @Haleshot!

mscolnick avatar Jun 19 '24 16:06 mscolnick

I'm trying out different variations. I looked at how seaborn's plotting functions are handled, and plotly's and leafmap's too. @lborcard I'm currently using the code you provided above to understand the return types. I assigned the result of the plot_trace call to a variable, and it turns out to be of type numpy.ndarray.

I will try different approaches to narrow down a solution.
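
To make the return type above concrete, here is a minimal sketch that does not need ArviZ or a sampled trace. It assumes that plt.subplots is a fair stand-in for az.plot_trace, since both hand back a grid of matplotlib Axes as a numpy.ndarray; the variable names and the Agg backend are choices made for this illustration only.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np

# Stand-in for az.plot_trace (assumption): a 2x2 grid of Axes.
fig, axes = plt.subplots(nrows=2, ncols=2)

# The grid is a plain object-dtype ndarray, not a Figure or a single Axes.
print(type(axes).__name__, axes.dtype, axes.shape)

# Every Axes in the grid shares one parent Figure, which can be
# recovered from any element of the array.
parent = axes.ravel()[0].figure
print(parent is fig)
```

Because the array itself carries no plot, a notebook that only knows how to render a Figure or a single Axes has nothing to draw when the array is the last expression of a cell. That is consistent with the symptom reported in this issue.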

Haleshot avatar Jun 20 '24 09:06 Haleshot

After going through the library's docs and other resources, I've found a simple fix for the ArviZ display issue: adding plt.show() after any ArviZ plot function call makes the plot appear. This works because ArviZ uses matplotlib as its default backend (unless Bokeh is requested through a function's parameters), as explained here.

The ArviZ documentation uses plt.show() this way in its examples, which is where I came across the fix. (FYI @lborcard)

The solution is straightforward: append plt.show() to the end of the cell that contains the ArviZ plot function call (the trace_plot cell in the reproduction code above).

Given this, I'm considering proposing plt.show() as a simpler alternative to implementing a complex formatter. My current formatter implementation applies to every np.ndarray, which causes unintended effects on other libraries, particularly matplotlib (I posted about this in a thread in #contributing on the Discord server).
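
One way to keep a formatter from touching every np.ndarray is to match only arrays whose elements are all matplotlib Axes. The sketch below is an illustration under that assumption, not marimo's formatter API: is_axes_grid is a hypothetical helper name, and how such a check would be registered with marimo is left open.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def is_axes_grid(obj) -> bool:
    """Hypothetical helper: True only for a non-empty ndarray made up entirely of Axes."""
    return (
        isinstance(obj, np.ndarray)
        and obj.dtype == object  # numeric arrays are rejected here
        and obj.size > 0
        and all(isinstance(item, Axes) for item in obj.flat)
    )


# A grid of Axes, like the one a multi-panel plot returns.
_, grid = plt.subplots(2, 2)

# Ordinary arrays that a formatter should leave alone.
numbers = np.arange(4)
mixed = np.array([1, "a"], dtype=object)

print(is_axes_grid(grid), is_axes_grid(numbers), is_axes_grid(mixed))
```

With a check like this, plain numeric arrays and arrays of arbitrary objects fall through to their usual display, and only grids of Axes would be routed to plot rendering.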

I'm planning to update my current local changes to reflect this simpler approach.

@akshayka @mscolnick (apologies for the ping) Do you think it's still worthwhile to create a formatter specifically for ArviZ that calls plt.show() itself, so that users don't have to add it to their cells? Or should we opt for the simpler, more explicit solution of recommending that users add plt.show() after ArviZ plot calls (and only for the ones that return an np.ndarray, like plot_trace)?

Your thoughts on this approach would be greatly appreciated.

Haleshot avatar Jun 24 '24 06:06 Haleshot