Fix `OpFromGraph` with disconnected output gradients
Description
PyTensor uses `DisconnectedType` and `NullType` variables to raise informative errors when users request gradients with respect to inputs for which they cannot be computed. This is a problem for `OpFromGraph`, which may include parallel graphs, some of which are disconnected/null and others not. We don't want to fail when the user only needs the gradient that is supported.
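For concreteness, here is a minimal sketch (illustrative names, not taken from the PR's tests) of the kind of request that should succeed even though one output's gradient is disconnected:
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.dscalars("x", "y")
op = OpFromGraph([x, y], [x + y, x * y])
out_sum, out_prod = op(x, y)

# Only out_sum feeds the cost, so the output gradient passed to the OFG's
# L_op for out_prod is a DisconnectedType variable; this should not fail.
gx, gy = pytensor.grad(out_sum**2, [x, y])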
There was already some special logic to handle cases where `NullType` and `DisconnectedType` arise from the OFG inner graph. Instead of outputting those types (which the OFG cannot produce out of thin air, as they are root variables), we were outputting dummy zeros and then masking them with the original `NullType` or `DisconnectedType` variables created in the internal call to `grad`/`Rop`. This seems reasonable, if a bit tedious. This PR first refactors this code to avoid the dummy outputs altogether (there's no reason for them!).
It then extends this logic to also handle cases where `NullType`/`DisconnectedType` arise before the inner graph of `OpFromGraph`. This was the cause of one of the issues described in #1. When an OFG has multiple outputs and the requested gradient only uses a subset of them, PyTensor will feed `DisconnectedType` variables in place of the `output_gradients` used by the `L_op`. The solution is to filter out these unused input variables. This should be safe(?): if the inner graph of the OFG needs these variables and we don't provide them, an error will naturally be raised that they are missing.
This does mean we may need distinct OFGs for different patterns of disconnected gradients. Accordingly, the cache is now keyed per pattern.
I suspect this is the issue behind #652.
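A hypothetical sketch of the per-pattern caching described above (names like `lop_cache` and `make_lop_op` are illustrative, not the actual attributes in `builders.py`):
from pytensor.gradient import DisconnectedType

def disconnected_pattern(output_gradients):
    # One boolean per output: True where the outer graph passed a
    # disconnected output gradient that should be filtered out.
    return tuple(isinstance(g.type, DisconnectedType) for g in output_gradients)

# e.g. lop_cache.setdefault(disconnected_pattern(output_gradients), make_lop_op(...))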
This PR also deprecates `grad_overrides` and the custom logic for invalid `connection_pattern`s. Hopefully this helps make `OpFromGraph` more maintainable.
Related Issue
- [x] Closes #1
- [x] Related to #652
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [x] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
Codecov Report
Attention: Patch coverage is 79.50820% with 25 lines in your changes missing coverage. Please review.
Project coverage is 80.94%. Comparing base (fc21336) to head (a96e5a1). Report is 208 commits behind head on main.
| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| pytensor/compile/builders.py | 79.50% | 14 Missing and 11 partials :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #723 +/- ##
==========================================
+ Coverage 80.85% 80.94% +0.08%
==========================================
Files 162 162
Lines 47043 46945 -98
Branches 11514 11481 -33
==========================================
- Hits 38038 37998 -40
+ Misses 6750 6706 -44
+ Partials 2255 2241 -14
| Files with missing lines | Coverage Δ | |
| --- | --- | --- |
| pytensor/gradient.py | 77.37% <ø> (+0.54%) | :arrow_up: |
| pytensor/compile/builders.py | 88.38% <79.50%> (+10.93%) | :arrow_up: |
Superficial first pass across the PR. I cannot make an informed comment about the actual meat of the changes until I fire up a debugger and try to grok what OpFromGraph is actually doing. I will make an effort to do this in the next 48 hours and give a more meaningful review.
Thanks! It may help to convince yourself that no behavior is changed until the second-to-last commit, where the bug fix is done (other than deprecations and the removal of the special behavior in `connection_pattern`).
I found another issue: if the outputs of an OpFromGraph are not independent, the existing logic fails. Instead of adding the contributions coming from each output, it overrides them, due to how the `known_grads` we use internally behaves.
The new test cases in the last commit illustrate this. Any case that depends on `out3` fails numerically, because we ignore/mask the contributions coming from it.
import numpy as np
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import verify_grad
from pytensor.tensor import dscalars

x, y = dscalars("x", "y")
rng = np.random.default_rng(594)
point = list(rng.normal(size=(2,)))

out1 = x + y
out2 = x * y
out3 = out1 + out2  # Create dependency between outputs
op = OpFromGraph([x, y], [out1, out2, out3])

# Any check whose cost involves out3 fails numerically before this fix
verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
If instead we define out3 explicitly as `out3 = (x + y) * (x * y)`, it works fine again.
@aseyboldt any idea how we could handle this? In an outer function, I think this would be handled by adding the direct contributions to out1/out2 together with the indirect ones coming from out3.
It seems like I want to initialize those variables' grads to the output_grad values, but still allow them to be updated, rather than setting them as known, which doesn't allow any further updates.
Found a nice(?) hack. Instead of calling `Lop` internally with `known_grads=dict(zip(inner_outputs, output_gradients))`, I do it with `known_grads=dict(zip(identity_inner_outputs, output_gradients))`, where `identity_inner_outputs` is each `inner_output` wrapped in a dummy `Identity` operation. This way we correctly accumulate the direct and indirect contributions coming from other inner outputs.
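A self-contained sketch of the idea, using the graph from the example above; `tensor_copy` (an elementwise identity) stands in here for whatever dummy Identity wrapper is actually used, and the `g*` variables play the role of the output gradients fed in by the outer graph:
import pytensor.tensor as pt
from pytensor.gradient import grad
from pytensor.tensor.basic import tensor_copy

x, y = pt.dscalars("x", "y")
out1 = x + y
out2 = x * y
out3 = out1 + out2
inner_outputs = [out1, out2, out3]

# Output gradients that the outer graph would feed to the L_op
output_gradients = [pt.dscalar(f"g{i}") for i in range(3)]

# Keying known_grads on copies of the inner outputs seeds the output
# gradients without marking out1/out2/out3 themselves as "known", so the
# contributions flowing from out3 into out1 and out2 are still accumulated.
identity_inner_outputs = [tensor_copy(out) for out in inner_outputs]

input_grads = grad(
    cost=None,
    wrt=[x, y],
    known_grads=dict(zip(identity_inner_outputs, output_gradients)),
)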