pymc
ENH: Improve debug docs and function helpers
Before
```python
from pytensor.printing import Print
x = pm.Normal("x")  # inside a `with pm.Model():` block
x_print = Print("x")(x)
```
After
```python
x = pm.Normal("x", debug=True)  # proposed kwarg, not an existing PyMC argument
```
Context for the issue:
On https://www.pymc.io/projects/examples/en/latest/howto/howto_debugging.html we show how `pytensor.printing.Print` can be used to debug-print RVs, which is helpful for debugging model issues. This could be done with a nicer API if we add a `debug` kwarg that would wrap the RV internally with a `Print` Op.
It seems like the notebook is doing its job?
I don't think we should add a `debug` kwarg to RVs: it's a bit vague what it means, and it's not really discoverable anyway (none of the distribution arguments are). I'd rather have `Distribution` do fewer things, not more.
Instead I would propose to:
- Update the notebook
- Add an example where the initial point fails, and show `model.debug`
- Link to the notebook in the `model.debug` output, for the cases where that is not sufficient and the `Print` approach may work better.
- Add a nicer `print_value` helper so users don't have to instantiate the `Print` Op manually, which is certainly weird:
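```python
from pytensor.printing import Print


def print_value(var, name=None):
    """Print value of variable when it is computed during sampling.

    This is likely to affect sampling performance.
    """
    # Default to the variable's own name as the printed label
    if name is None:
        name = var.name
    return Print(name)(var)
```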
- Implement the `Print` Op in JAX (I think it can be done with https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html) for numpyro/blackjax-based samplers
@ricardoV94 I like all these suggestions.
I would like to work on this issue.
Great @itsdivya1309, do you have any questions on how to get going? Otherwise, feel free to open a draft PR and we can take it from there.
Correct me if I am wrong, but I need to update this notebook as suggested above, right?
@itsdivya1309 Correct.
I don't understand what you mean by 'Implement PrintOp in JAX'. Could you please explain?
Also, the `print_value()` function uses the `Print` Op anyway, which means it is still being initialized.
@itsdivya1309 You can treat that as a separate issue. As you can see in the NB, we're using `Print()` to output debug values. `Print` is implemented for the C backend, but not currently for the JAX backend. In JAX, this functionality seems to be provided here: https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html.
But you can just make the changes to the NB for now.
Alright
@twiecki, I am also interested in this issue, and it is not assigned to anyone. Can I open a PR, since none has been opened yet?
My understanding of the issue:
- You suggested introducing a new `debug` keyword argument to the `pm.Normal` distribution which, when set to `True`, internally wraps the random variable with a `Print` Op.
- @ricardoV94 suggested using a `print_value` function instead, to hide the awkward direct use of the `Print` Op.
- He also suggested adding some examples to the notebook and, mainly, implementing the `Print` Op in JAX.
@AryanNanda17 you can work on the JAX part if you want.