
Make coords and data always mutable

Open · ricardoV94 opened this issue 2 years ago · 1 comment

Closes #6972

This PR makes coords and data always mutable. It also provides a new model transform that freezes RV dims that depend on coords, as well as mutable data, for anyone worried about performance or about incompatibilities with JAX's dynamic-shape limitations. I expect most users won't need this: the default C backend doesn't really exploit static shapes. I believe the simpler API is a net win for users.

Note that JAX's dynamic-shape limitations are not relevant when using the JAX samplers, because we already replace all shared variables with constants there. They only matter when compiling PyTensor functions with mode="JAX".

The big picture here is that we define the most general model first and specialize later if needed. Going from a model with constant shapes to one with different constant shapes is generally not possible, because PyTensor eagerly computes static output shapes for intermediate nodes, and rebuilding a graph with inputs of different constant types is not always supported.

Starting with more general models could be quite helpful for producing predictive models automatically.

Note: If there's resistance, this PR can be narrowed in scope to only remove the distinction between coords_mutable and coords, while keeping the MutableData vs. ConstantData distinction.

ricardoV94 · Dec 04 '23

Codecov Report

Attention: Patch coverage is 31.74603% with 43 lines in your changes missing coverage. Please review.

Project coverage is 39.54%. Comparing base (81d31c8) to head (28ded31).

Note: the current head 28ded31 differs from the pull request's most recent head 55693f4. Consider uploading reports for commit 55693f4 to get more accurate results.

@@             Coverage Diff             @@
##             main    #7047       +/-   ##
===========================================
- Coverage   92.30%   39.54%   -52.77%     
===========================================
  Files         100      101        +1     
  Lines       16895    16835       -60     
===========================================
- Hits        15595     6657     -8938     
- Misses       1300    10178     +8878     
File                                    Coverage Δ
pymc/gp/hsgp_approx.py                  23.80% <ø>      (-71.82%) ↓
pymc/sampling/forward.py                14.47% <ø>      (-81.43%) ↓
pymc/data.py                            40.62% <33.33%> (-48.85%) ↓
pymc/model/fgraph.py                    25.12% <0.00%>  (-72.27%) ↓
pymc/model_graph.py                     14.28% <20.00%> (-63.11%) ↓
pymc/model/core.py                      57.62% <63.15%> (-34.61%) ↓
pymc/model/transform/conditioning.py    15.83% <17.24%> (-79.92%) ↓

... and 85 files with indirect coverage changes

codecov[bot] · Dec 04 '23

Love it.

twiecki · Mar 27 '24

Woohoo!

twiecki · Apr 03 '24

We need to adapt the pymc-examples (and maybe the NBs here too?).

twiecki · Apr 03 '24

> We need to adapt the pymc-examples (and maybe the NBs here too?).

Possibly.

ricardoV94 · Apr 03 '24

Code using the old API shouldn't fail immediately; it will just issue a warning.

Only when we remove the kwarg will it fail hard.

ricardoV94 · Apr 03 '24
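The warn-first, fail-later approach described above can be sketched in plain Python. This is a minimal, hypothetical example: build_coords is not a PyMC function, and the warning text and the choice of FutureWarning are assumptions rather than PyMC's actual messages.

```python
import warnings


def build_coords(coords=None, coords_mutable=None):
    """Hypothetical helper: merge a deprecated kwarg into coords with a warning."""
    merged = dict(coords or {})
    if coords_mutable is not None:
        # Deprecated path: keep working for now, but tell the user to migrate.
        warnings.warn(
            "coords_mutable is deprecated; all coords are now mutable. "
            "Pass them via coords instead.",
            FutureWarning,
            stacklevel=2,
        )
        merged.update(coords_mutable)
    return merged


# Old-style call: still works, but emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = build_coords(coords_mutable={"obs": [0, 1, 2]})
print(result, [w.category.__name__ for w in caught])
```

Once the deprecation period ends, the coords_mutable parameter would simply be removed, at which point old calls fail with a TypeError.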

Hi @ricardoV94! @michaelosthege and I found that the latest changes cause problems when using pm.ConstantData (or pm.Data) and setting a dtype explicitly. We don't understand why, because pytensor.shared has no problem with the dtype argument.

Here is an example:

```python
import pymc as pm

with pm.Model():
    pm.Data("b", [True, False], dtype=bool)
```

```
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pymc\data.py", line 420, in Data
    x = pytensor.shared(arr, name, **kwargs)
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pytensor\compile\sharedvalue.py", line 202, in shared
    var = shared_constructor(
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
TypeError: tensor_constructor() got an unexpected keyword argument 'dtype'
```

Do you have any idea why this is happening? Thanks in advance for your help!

lhelleckes · Apr 22 '24