numpyro

[FR] Render models to a LaTeX formula

Open • opened by ordabayevy • 0 comments

Similar to `render_model`, but render the model to a LaTeX formula instead of a graph. Example:

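```python
import numpyro
import numpyro.distributions as dist

def model(data):
    m = numpyro.sample("m", dist.Normal(0, 1))
    sd = numpyro.sample("sd", dist.LogNormal(m, 1))
    with numpyro.plate("N", len(data)):
        numpyro.sample("obs", dist.Normal(m, sd), obs=data)

# Proposed function (not yet part of numpyro); `data` is the observed array.
render_to_latex_formula(model, model_args=(data,))
```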
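For reference, the full joint density that this model defines, with every distribution parameter written out, is shown below ($N$ is `len(data)`). Both proposed outputs are abbreviations of this expression: they drop the fixed parameters and keep only the dependency structure.

$$
p(m, sd, obs) = \mathrm{Normal}(m \mid 0, 1)\,\mathrm{LogNormal}(sd \mid m, 1)\prod_{n=1}^{N} \mathrm{Normal}(obs_n \mid m, sd)
$$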

Possible outputs:

  1. Show the factorization of the joint density: $p(m) p(sd \mid m) \prod_{N} p(obs=data \mid m, sd)$
  2. Also include the distribution names: $\mathrm{Normal}(m) \mathrm{LogNormal}(sd \mid m) \prod_{N} \mathrm{Normal}(obs=data \mid m, sd)$
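A minimal sketch of the string-building half of this request, in plain Python. It assumes the per-site information (distribution name, upstream sample sites, enclosing plates, and a label for the observed value) has already been collected; `Site` and `latex_factorization` are hypothetical names, not numpyro API. In numpyro that information could plausibly come from a model trace, or from `numpyro.contrib.render.get_model_relations`, which as far as we know is what `render_model` builds on; we have not checked that its output maps one-to-one onto these fields.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Site:
    """Hypothetical record for one sample site (not a numpyro type)."""
    name: str
    dist: str                                          # distribution class name, e.g. "Normal"
    parents: List[str] = field(default_factory=list)   # upstream sample sites
    plates: List[str] = field(default_factory=list)    # enclosing plates, outermost first
    obs: Optional[str] = None                          # label for an observed value, e.g. "data"


def latex_factorization(sites: List[Site], with_dist_names: bool = False) -> str:
    """Render one factor per site, in execution order, as a LaTeX string."""

    def factor(site: Site) -> str:
        # Use p(...) for output 1, or the distribution name for output 2.
        head = rf"\mathrm{{{site.dist}}}" if with_dist_names else "p"
        var = site.name if site.obs is None else f"{site.name}={site.obs}"
        if site.parents:
            var += r" \mid " + ", ".join(site.parents)
        return f"{head}({var})"

    terms = []
    for site in sites:
        # Prefix one \prod per enclosing plate.
        prods = "".join(rf"\prod_{{{p}}} " for p in site.plates)
        terms.append(prods + factor(site))
    return " ".join(terms)


# The model from the issue, described by hand.
sites = [
    Site("m", "Normal"),
    Site("sd", "LogNormal", parents=["m"]),
    Site("obs", "Normal", parents=["m", "sd"], plates=["N"], obs="data"),
]
print(latex_factorization(sites))
# p(m) p(sd \mid m) \prod_{N} p(obs=data \mid m, sd)
print(latex_factorization(sites, with_dist_names=True))
# \mathrm{Normal}(m) \mathrm{LogNormal}(sd \mid m) \prod_{N} \mathrm{Normal}(obs=data \mid m, sd)
```

Emitting one `\prod` per site, rather than grouping sites that share a plate, keeps the sketch short. The result is still correct, because $\prod_N a \prod_N b = \prod_N ab$.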

ordabayevy, Mar 26 '23 03:03