Make coords and data always mutable
Closes #6972
This PR provides a new model transform that freezes RV dims that depend on coords, as well as mutable data, for anyone worried about performance or about JAX's dynamic-shape limitations. I expect most users won't need it: the default C backend doesn't really exploit static shapes, and I believe the simpler API is a net win for users.
Note that JAX's dynamic-shape limitations are not relevant when using JAX samplers, because we already replace any shared variables with constants anyway. They only matter when compiling PyTensor functions with mode="JAX".
The big picture here is that we define the most general model first and specialize it later if needed. Going from a model with constant shapes to one with different constant shapes is generally not possible, because PyTensor eagerly computes static output shapes for intermediate nodes, and rebuilding with different constant types is not always supported.
Starting with more general models could be quite helpful for producing predictive models automatically.
Note: if there's resistance, this PR can be narrowed in scope to just removing the distinction between coords_mutable and coords, while keeping the MutableData vs ConstantData distinction.
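For concreteness, here is a minimal sketch of the intended workflow. The transform name `freeze_dims_and_data` and its import path are my assumption about how this PR exposes it:

```python
import numpy as np
import pymc as pm

# Assumed import path for the transform added in this PR
from pymc.model.transform.optimization import freeze_dims_and_data

coords = {"trial": range(3)}
with pm.Model(coords=coords) as m:
    # Data and dims are always mutable now: shapes stay dynamic by default
    x = pm.Data("x", np.array([1.0, 2.0, 3.0]), dims="trial")
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu + x, sigma=1.0, dims="trial")

# Specialize later: freeze dims and data into constants with static shapes,
# e.g. for anyone worried about performance or JAX dynamic-shape limits
frozen_m = freeze_dims_and_data(m)
```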
Codecov Report
Attention: Patch coverage is 31.74603%, with 43 lines in your changes missing coverage. Please review.
Project coverage is 39.54%. Comparing base (81d31c8) to head (28ded31).
:exclamation: Current head 28ded31 differs from the pull request's most recent head 55693f4. Consider uploading reports for commit 55693f4 to get more accurate results.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #7047       +/-   ##
===========================================
- Coverage   92.30%   39.54%   -52.77%
===========================================
  Files         100      101        +1
  Lines       16895    16835       -60
===========================================
- Hits        15595     6657     -8938
- Misses       1300    10178     +8878
```
| Files | Coverage Δ | |
|---|---|---|
| pymc/gp/hsgp_approx.py | 23.80% <ø> (-71.82%) | :arrow_down: |
| pymc/sampling/forward.py | 14.47% <ø> (-81.43%) | :arrow_down: |
| pymc/data.py | 40.62% <33.33%> (-48.85%) | :arrow_down: |
| pymc/model/fgraph.py | 25.12% <0.00%> (-72.27%) | :arrow_down: |
| pymc/model_graph.py | 14.28% <20.00%> (-63.11%) | :arrow_down: |
| pymc/model/core.py | 57.62% <63.15%> (-34.61%) | :arrow_down: |
| pymc/model/transform/conditioning.py | 15.83% <17.24%> (-79.92%) | :arrow_down: |
Love it.
Woohoo!
We need to adapt the pymc-examples (and maybe the NBs here too?).
Possibly.
It shouldn't fail immediately, just issue a warning. Only when we remove the kwarg will it fail hard.
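To illustrate the soft-deprecation path being described (a hypothetical sketch; the exact kwarg and message are assumptions):

```python
import warnings

def Data(name, value, *, mutable=None, **kwargs):
    # Hypothetical sketch: passing the deprecated kwarg only triggers a
    # warning for now. Removing the kwarg from the signature in a later
    # release is what would make this fail hard.
    if mutable is not None:
        warnings.warn(
            "`mutable` has no effect anymore; coords and data are always mutable.",
            FutureWarning,
        )
    ...  # create the shared variable as usual
```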
Hi @ricardoV94!
@michaelosthege and I found out that the latest changes create issues when using pm.ConstantData (or pm.Data) and setting a dtype explicitly. We don't understand why, because pytensor.shared has no problem with the dtype argument.
Here is an example:
```python
with pm.Model():
    pm.Data("b", [True, False], dtype=bool)
```

```
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pymc\data.py", line 420, in Data
    x = pytensor.shared(arr, name, **kwargs)
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pytensor\compile\sharedvalue.py", line 202, in shared
    var = shared_constructor(
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
TypeError: tensor_constructor() got an unexpected keyword argument 'dtype'
```
Do you have any idea why this is happening? Thanks in advance for your help!
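Not an answer to the root cause, but a possible workaround until this is sorted out might be to cast with NumPy up front, so no `dtype` kwarg is forwarded to `pytensor.shared` (untested sketch):

```python
import numpy as np
import pymc as pm

with pm.Model():
    # Cast before passing, instead of forwarding dtype= to pytensor.shared
    pm.Data("b", np.asarray([True, False], dtype=bool))
```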