
ENH: Improve debug docs and function helpers

Open twiecki opened this issue 1 year ago • 12 comments

Before

x = pm.Normal("x")
x_print = Print("x")(x)

After

x = pm.Normal("x", debug=True)

Context for the issue:

On https://www.pymc.io/projects/examples/en/latest/howto/howto_debugging.html we show how pytensor.printing.Print can be used to debug-print RVs, which is helpful for debugging model issues. This could be done with a nicer API if we added a debug kwarg that wraps the RV internally with a Print Op.

twiecki avatar Aug 01 '23 16:08 twiecki

It seems like the notebook is doing its job?

I don't think we should add a debug kwarg to RVs: it's a bit vague what it means, and it's not really discoverable anyway (none of the distribution arguments are). I would rather have Distribution do fewer things than more.

Instead I would propose to:

  1. Update the notebook
  2. Add an initial point failed example and show model.debug
  3. Link to the notebook in model.debug output, for the cases where that is not sufficient and the print thing may be better.
  4. Add a nicer print_value helper so users don't have to initialize the Print Op manually which is certainly weird:
from pytensor.printing import Print

def print_value(var, name=None):
  """Print value of variable when it is computed during sampling.

  This is likely to affect sampling performance.
  """
  if name is None:
    name = var.name
  return Print(name)(var)
  5. Implement PrintOp in JAX (I think it can be done with https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html) for numpyro/blackjax based samplers

ricardoV94 avatar Aug 02 '23 07:08 ricardoV94

@ricardoV94 I like all these suggestions.

twiecki avatar Aug 02 '23 08:08 twiecki

I would like to work on this issue.

itsdivya1309 avatar Feb 01 '24 12:02 itsdivya1309

Great @itsdivya1309, do you have any questions on how to get going? Otherwise, feel free to open a draft PR and we can take it from there.

twiecki avatar Feb 01 '24 12:02 twiecki

Correct me if I am wrong, but I need to update this notebook as suggested above, right?

itsdivya1309 avatar Feb 01 '24 12:02 itsdivya1309

@itsdivya1309 Correct.

twiecki avatar Feb 01 '24 16:02 twiecki

I don't understand what you mean by 'Implement PrintOp in JAX'. Can you please explain? Also, the print_value() function uses the Print Op anyway, which means it's still being initialized.

itsdivya1309 avatar Feb 03 '24 06:02 itsdivya1309

@itsdivya1309 You can treat that as a separate issue. As you can see in the NB, we're using Print() to output debug values. Print is implemented for the C backend, but not currently for the JAX backend. In JAX, this functionality seems to be available here: https://jax.readthedocs.io/en/latest/debugging/print_breakpoint.html.

But you can just do the changes to the NB for now.
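For context on the JAX side, the linked page describes `jax.debug.print`, which prints runtime values from inside jitted functions. The sketch below shows only that JAX primitive on a toy function; it is not the PyTensor JAX dispatch for Print, and the function name and formula are assumptions made up for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_logp(x):
    # Illustrative computation only; not a PyMC log-probability.
    z = (x - 1.0) / 2.0
    # A plain Python print would only run once at trace time and show a
    # tracer; jax.debug.print shows the runtime value on every call.
    jax.debug.print("z = {z}", z=z)
    return -0.5 * jnp.sum(z**2)


val = toy_logp(jnp.array([1.0, 3.0]))
```

A JAX implementation of the Print Op would presumably call something like this from the function PyTensor generates for the Op, but how that is wired up is left open in this thread.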

twiecki avatar Feb 05 '24 03:02 twiecki

Alright

itsdivya1309 avatar Feb 05 '24 04:02 itsdivya1309

@twiecki, I am also interested in this issue, and it is not assigned to anyone. Can I open a PR, since none has been opened yet?

AryanNanda17 avatar Feb 05 '24 18:02 AryanNanda17

My understanding of the issue:

  • You suggested introducing a new debug keyword argument to the pm.Normal distribution, which, when set to True, internally wraps the random variable with a Print Op.
  • @ricardoV94 suggested using a print_value function instead, to hide the weird nature of the Print Op.
  • He also suggested adding some examples to the notebook and, mainly, implementing the Print Op in JAX.

AryanNanda17 avatar Feb 05 '24 18:02 AryanNanda17

@AryanNanda17 you can work on the JAX part if you want.

itsdivya1309 avatar Feb 09 '24 17:02 itsdivya1309