Basic support for OpFromGraph
Torch OpFromGraph
Allows us to precompile a subset of the graph via torch.
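A minimal sketch (not taken from the PR) of the kind of usage this enables, assuming the PyTorch backend is selected via `mode="PYTORCH"` and `OpFromGraph` comes from `pytensor.compile.builders`:

```python
# Hypothetical example: wrap a subgraph in OpFromGraph and compile the
# surrounding function with the PyTorch backend, so the inner graph can
# be handled (and potentially precompiled) by torch.
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.vectors("x", "y")
inner_out = x * y + pt.exp(x)          # the subgraph to encapsulate
ofg = OpFromGraph([x, y], [inner_out])

a, b = pt.vectors("a", "b")
z = ofg(a, b) + 1.0
fn = pytensor.function([a, b], z, mode="PYTORCH")
print(fn([1.0, 2.0], [3.0, 4.0]))
```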
Related Issue
- [ ] Closes #
- [ ] Related to #939
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [X] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [X] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [X] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
@ricardoV94 I noted what the change ended up being. I'm not sure if that is the right call, but at least for now the tests pass, so I'm going to mark it as ready. It will fail on the build machine because we don't allow eager mode there:
FAILED tests/link/pytorch/test_basic.py::test_pytorch_OpFromGraph - torch._dynamo.exc.InternalTorchDynamoError: module 'pytensor.link.utils' has no attribute 'elemwise_fn'
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 81.74%. Comparing base (3e55a20) to head (0f18d8d). Report is 119 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #956 +/- ##
=======================================
Coverage 81.74% 81.74%
=======================================
Files 183 183
Lines 47724 47733 +9
Branches 11616 11616
=======================================
+ Hits 39011 39020 +9
Misses 6518 6518
Partials 2195 2195
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/link/pytorch/dispatch/basic.py | 93.93% <100.00%> (+0.60%) :arrow_up: | |
@Ch0ronomato I tried something different in a commit I just pushed. I'm disabling only one level of the stack, so the inner code is still compiled? Does that make sense?
The problem we're facing seems to be PyTorch trying to import the dynamically generated elemwise_fn from the module where fgraph_to_python is defined, but this is nonsensical since that function is defined in an inner scope, not at the module level. I also checked that allow_in_graph works (see the sketch below), but according to the torch docs it disables safety checks and shouldn't be used with functions that can mutate their inputs, so I don't think it's safe in our case?
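For illustration only, a hedged sketch of the allow_in_graph alternative mentioned above; the function names here are made up, only `torch.compiler.allow_in_graph` and `torch.compile` are real APIs:

```python
import torch

def opaque_fn(x):
    # Stand-in for the dynamically generated function dynamo chokes on.
    return x * 2.0

# Tell dynamo to put opaque_fn into the traced graph without inspecting
# its body. This is what skips the usual safety checks, so it is not
# appropriate if the function might mutate its inputs.
torch.compiler.allow_in_graph(opaque_fn)

@torch.compile
def wrapper(x):
    return opaque_fn(x) + 1.0
```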
I see - yeah, using the fine-grained API does seem to be the only way. I like that yours doesn't do recursive disabling, whereas the API I used, I think, does. I like it!
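A small sketch (assuming a recent torch where `torch.compiler.disable` accepts a `recursive` flag) of the difference being discussed:

```python
import torch

def inner(x):
    return torch.sin(x) + 1.0

def outer(x):
    return inner(x) * 2.0

# Recursive disable: dynamo skips `outer` and everything it calls.
outer_fully_disabled = torch.compiler.disable(outer, recursive=True)

# Non-recursive disable: only the `outer` frame is skipped; `inner` can
# still be traced and compiled when dynamo reaches it through a
# compiled caller.
outer_frame_disabled = torch.compiler.disable(outer, recursive=False)
```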
@ricardoV94 - I do see the same issue (I think) with this impl https://github.com/pymc-devs/pytensor/actions/runs/10763803733/job/29845771297?pr=988.
One other thing I'm not sure I get: why does the error go away when we remove the if / else logic on things declared outside the inner function?
> why does the error go away when we remove the if / else logic on things declared outside the inner function?
What do you mean? Can you show the code that doesn't error out vs the one that does?
Seems like we need to resolve some trivial conflicts.