numpyro
[FR] Render models to a LaTeX formula
Similar to `render_model`, but render the model to a LaTeX formula instead of a graph. Example:
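```python
import numpyro as pyro
import numpyro.distributions as dist

def model(data):
    m = pyro.sample("m", dist.Normal(0, 1))
    sd = pyro.sample("sd", dist.LogNormal(m, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(m, sd), obs=data)

# Proposed API (not in numpyro yet), mirroring render_model's arguments
render_to_latex_formula(model, model_args=(data,))
```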
Possible outputs:
- Show the factorization of the joint density: $p(m) p(sd \mid m) \prod_{N} p(obs=data \mid m, sd)$
- Also show the distribution names: $\mathrm{Normal}(m) \mathrm{LogNormal}(sd \mid m) \prod_{N} \mathrm{Normal}(obs=data \mid m, sd)$
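A minimal sketch of the string-building step, assuming the model structure is already available as a dict shaped roughly like the output of `numpyro.contrib.render.get_model_relations` (keys `sample_sample`, `sample_dist`, `plate_sample`, `observed`). The dict below is written by hand rather than produced by numpyro, and the function name `relations_to_latex` and its `obs_values` parameter (the LaTeX name to show for each observed site's data) are hypothetical. Nested plates are not handled.

```python
def relations_to_latex(relations, with_dist_names=False, obs_values=None):
    """Render model relations as a LaTeX product of (conditional) densities.

    Hypothetical helper: `relations` is assumed to resemble the dict returned by
    numpyro.contrib.render.get_model_relations. Only one level of plates is handled.
    """
    parents = relations["sample_sample"]            # site -> upstream sample sites
    dist_names = relations.get("sample_dist", {})   # site -> distribution class name
    observed = set(relations.get("observed", []))
    obs_values = obs_values or {}                   # observed site -> LaTeX data name

    # Map each site to its enclosing plate (nested plates are ignored here).
    site_plate = {}
    for plate, sites in relations.get("plate_sample", {}).items():
        for site in sites:
            site_plate[site] = plate

    def factor(site):
        head = site
        if site in observed and site in obs_values:
            head += "=" + obs_values[site]
        functor = rf"\mathrm{{{dist_names[site]}}}" if with_dist_names else "p"
        cond = ", ".join(parents.get(site, []))
        body = rf"{head} \mid {cond}" if cond else head
        return f"{functor}({body})"

    terms, emitted_plates = [], set()
    for site in parents:  # dicts preserve insertion (model) order
        plate = site_plate.get(site)
        if plate is None:
            terms.append(factor(site))
        elif plate not in emitted_plates:
            # Emit every site of this plate under a single product symbol.
            emitted_plates.add(plate)
            inner = [factor(s) for s in parents if site_plate.get(s) == plate]
            prod = " ".join(inner)
            if len(inner) > 1:
                prod = rf"\left[{prod}\right]"
            terms.append(rf"\prod_{{{plate}}} {prod}")
    return " ".join(terms)


# Hand-written relations for the example model above.
relations = {
    "sample_sample": {"m": [], "sd": ["m"], "obs": ["m", "sd"]},
    "sample_dist": {"m": "Normal", "sd": "LogNormal", "obs": "Normal"},
    "plate_sample": {"N": ["obs"]},
    "observed": ["obs"],
}
print(relations_to_latex(relations, obs_values={"obs": "data"}))
# p(m) p(sd \mid m) \prod_{N} p(obs=data \mid m, sd)
```

Passing `with_dist_names=True` swaps each `p` for the distribution name, giving the second output format. Keeping the renderer a pure function of the relations dict separates it from model tracing, so it could be tested without running JAX.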