marimo
Arviz plots not being displayed
Describe the bug
I am using the ArviZ library with PyMC, and the plots are not displayed. All I see is the axes information, not the plots themselves.
Environment
```json
{
  "marimo": "0.3.7",
  "OS": "Darwin",
  "OS Version": "23.3.0",
  "Processor": "i386",
  "Python Version": "3.8.3",
  "Binaries": {
    "Browser": "123.0.6312.87",
    "Node": "--"
  },
  "Requirements": {
    "click": "8.1.7",
    "importlib-resources": "6.1.1",
    "jedi": "0.19.1",
    "markdown": "3.6",
    "pymdown-extensions": "10.7.1",
    "pygments": "2.17.2",
    "tomlkit": "0.12.4",
    "uvicorn": "0.29.0",
    "starlette": "0.37.2",
    "websocket": "missing",
    "typing-extensions": "4.8.0",
    "black": "24.3.0"
  }
}
```
Code to reproduce
```python
import marimo

__generated_with = "0.3.7"
app = marimo.App(width="medium")


@app.cell
def __():
    import marimo as mo
    return mo,


@app.cell
def __():
    import pymc as pm
    import numpy as np
    import pandas as pd
    return np, pd, pm


@app.cell
def __(np):
    seed = 134
    rng = np.random.default_rng(seed=seed)
    return rng, seed


@app.cell
def __(mo):
    mo.md(
        "We create a series of kmers that could be used as preferential cutting sites"
    )
    return


@app.cell
def __():
    from random import choices

    # sample random 4-mers until all 4**4 possible ones have been seen
    n = 0
    N = 4**4
    kmers = set()
    while n < N:
        kmer = "".join(choices(k=4, population="ATCG"))
        kmers.add(kmer)
        n = len(kmers)
    return N, choices, kmer, kmers, n


@app.cell
async def __(choices, kmers, np, rng):
    k = 7
    subkmers = choices(population=list(kmers), k=k)
    frac = np.array([0.2, 0.1, 0.1, 0.1, 0.4, 0.05, 0.05])
    # compare with a tolerance: a float sum is not guaranteed to be exactly 1
    if np.isclose(np.sum(frac), 1) and k == len(frac):
        print("OK")
    else:
        print("Error")
    true_c = 3
    # number of reads
    readsLen = 10
    reads = np.arange(readsLen)
    true_p = rng.dirichlet(size=readsLen, alpha=frac * true_c)
    ## get the counts
    totalcount = 34
    obs_counts = np.vstack(
        [rng.multinomial(n=totalcount, pvals=pi) for pi in true_p]
    )
    obs_counts
    return (
        frac,
        k,
        obs_counts,
        reads,
        readsLen,
        subkmers,
        totalcount,
        true_c,
        true_p,
    )


@app.cell
def __(obs_counts, true_p):
    print(
        f"shape of true_p {true_p.shape} \nshape of obs_count {obs_counts.shape}"
    )
    obs_counts.shape
    return


@app.cell
def __(k, np, obs_counts, pm, reads, subkmers, totalcount):
    coords = {"reads": reads, "subkmers": subkmers}
    with pm.Model(coords=coords) as multidirimarg:
        fraction = pm.Dirichlet("frac", a=np.ones(k), dims="subkmers")
        conc = pm.Lognormal("conc", mu=1, sigma=1)
        counts = pm.DirichletMultinomial(
            "counts",
            n=totalcount,
            a=fraction * conc,
            observed=obs_counts,
            dims=("reads", "subkmers"),
        )
    pm.model_to_graphviz(multidirimarg)
    with multidirimarg:
        trace_dm_marginalized = pm.sample(chains=4)
    return (
        conc,
        coords,
        counts,
        fraction,
        multidirimarg,
        trace_dm_marginalized,
    )


@app.cell
def __(trace_dm_marginalized):
    import arviz as az

    trace_plot = az.plot_trace(
        data=trace_dm_marginalized, var_names=["frac", "conc"]
    )
    trace_plot
    return az, trace_plot


@app.cell
def __():
    import matplotlib.pyplot as plt

    plt.plot([1, 2])
    # plt.gca() gets the current `Axes`
    plt.gca()
    return plt,


if __name__ == "__main__":
    app.run()
```
Adding support for ArviZ would involve adding a formatter here: https://github.com/marimo-team/marimo/tree/main/marimo/_output/formatters
This is a good first issue. For anyone interested, you can look at the other formatters for examples.
Working on this. I'll try to resolve it, using the formatters linked above as a reference.
awesome, thanks @Haleshot!
Trying out different variations; I looked at how seaborn handles its plotting functions, and at plotly and leafmap too.
@lborcard I'm currently using the code you provided above to understand the return types. I assigned the result of the `plot_trace` method to a variable, and it appears to be a `numpy.ndarray`.
I'll try different approaches to narrow down a solution.
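To see why marimo shows only axes information, it helps to look at what kind of object comes back. The sketch below is a dependency-light stand-in, not ArviZ itself: it assumes that `plt.subplots(2, 2)` produces the same kind of return value as `az.plot_trace` with two variables, namely a 2-D NumPy array of matplotlib `Axes` that all belong to one `Figure`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

# Stand-in for az.plot_trace(...) (assumption): a 2 x 2 grid of subplots,
# returned by matplotlib as a 2-D NumPy array of Axes.
fig, axes = plt.subplots(2, 2)

print(type(axes))              # <class 'numpy.ndarray'>
print(axes.shape, axes.dtype)  # (2, 2) object

# Each element is an Axes handle, and all of them belong to the same Figure,
# so "displaying the array" really means displaying that one Figure.
assert all(isinstance(ax, Axes) for ax in axes.flat)
assert all(ax.figure is fig for ax in axes.flat)
```

Because the array's elements are just handles into a single `Figure`, rendering the array element by element (or as a plain array) cannot reproduce the plot; something has to render the parent figure.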
After digging through the library's docs and other resources, I found a simple fix for the ArviZ display issue: adding `plt.show()` after any ArviZ plot function call resolves the display problem. This works because ArviZ uses matplotlib as its default backend (unless a plotting function is told to use Bokeh through its `backend` parameter), as explained here.
For example, the ArviZ documentation uses `plt.show()` in its examples, which is where I came across this fix. (FYI @lborcard)
The solution is straightforward: append `plt.show()` to the end of the cell containing the ArviZ plot function call (the `trace_plot` cell in the reproduction code above).
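A sketch of the `trace_plot` cell with the workaround applied. To keep it runnable without PyMC sampling, it assumes synthetic posterior draws built with `az.from_dict` (4 chains of 500 draws, shaped like the repro's 7-component `frac` and scalar `conc`) in place of `trace_dm_marginalized`; only the final `plt.show()` line is the actual fix described above.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

# Assumption: synthetic draws stand in for trace_dm_marginalized.
rng = np.random.default_rng(134)
idata = az.from_dict(
    posterior={
        "frac": rng.dirichlet(np.ones(7), size=(4, 500)),  # shape (4, 500, 7)
        "conc": rng.lognormal(mean=1, sigma=1, size=(4, 500)),  # shape (4, 500)
    }
)

# With the default matplotlib backend, plot_trace returns a NumPy array of
# Axes (one row per variable, two columns), not a Figure.
trace_plot = az.plot_trace(idata, var_names=["frac", "conc"])

# The workaround: explicitly show the current figure at the end of the cell.
plt.show()
```

The cell still returns `trace_plot`, so any later cells that use the array keep working; `plt.show()` only adds the explicit display step.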
Given this, I'm considering proposing this as a simpler alternative to implementing a complex formatter. My current formatter implementation, which applies to all `np.ndarray` outputs, has unintended effects on other libraries, particularly matplotlib (I posted about this on the Discord server as a thread in #contributing).
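One way a formatter could avoid touching every `np.ndarray`, sketched under assumptions: check first that the array is a non-empty object array whose elements are all matplotlib `Axes`, and leave every other array to its default rendering. `is_axes_array` is a hypothetical helper, not part of marimo or ArviZ.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the usage example
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def is_axes_array(obj) -> bool:
    """Hypothetical check: True only for non-empty object arrays of Axes.

    Numeric arrays, empty arrays, and mixed object arrays return False,
    so they would keep their normal (non-plot) rendering.
    """
    return (
        isinstance(obj, np.ndarray)
        and obj.dtype == object
        and obj.size > 0
        and all(isinstance(item, Axes) for item in obj.flat)
    )


# Usage: a subplot grid qualifies, a plain numeric array does not.
fig, axes = plt.subplots(1, 2)
print(is_axes_array(axes))          # True
print(is_axes_array(np.arange(3)))  # False
```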
I'm planning to update my current local changes to reflect this simpler approach.
@akshayka @mscolnick (apologies for the ping)
Do you think it's still worthwhile to create a formatter specifically for ArviZ that handles `plt.show()` by itself, so users don't have to add it to their cells when using ArviZ? Or should we opt for the simpler, more explicit solution of recommending that users add `plt.show()` after ArviZ plot calls (and only for those that return an `np.ndarray`, like `plot_trace`)?
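For comparison, the core step of a dedicated formatter would roughly take the place of the user's `plt.show()`: take the array of `Axes`, find the `Figure` behind it, and render that figure. The sketch below assumes PNG output and uses a hypothetical `axes_array_to_png` helper; it deliberately does not use marimo's formatter registration API, whose details would need to be taken from the existing formatters in the directory linked earlier in this thread.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so rendering works without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def axes_array_to_png(axes: np.ndarray) -> bytes:
    """Hypothetical formatter body: render the Figure behind an array of Axes.

    Assumes a non-empty array whose elements all share one Figure, which is
    what ArviZ's matplotlib backend returns from functions like plot_trace.
    """
    first = next(iter(axes.flat))
    if not isinstance(first, Axes):
        raise TypeError("expected an array of matplotlib Axes")
    buf = io.BytesIO()
    first.figure.savefig(buf, format="png")  # render the whole parent Figure
    return buf.getvalue()


# Usage with a subplot grid standing in for plot_trace's return value.
fig, axes = plt.subplots(2, 2)
png = axes_array_to_png(axes)
print(png[:8])  # b'\x89PNG\r\n\x1a\n'
```

The trade-off is the one raised above: a formatter like this removes the need for `plt.show()`, but it must be narrow enough (for example, gated on an `Axes`-only check) not to change how other arrays are displayed.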
Your thoughts on this approach would be greatly appreciated.