
Fix `OpFromGraph` with disconnected output gradients

Open ricardoV94 opened this issue 10 months ago • 2 comments

Description

PyTensor uses DisconnectedType and NullType variables to raise informative errors when users request gradients with respect to inputs for which they cannot be computed. This is a problem for OpFromGraph, which may include parallel graphs, some of which are disconnected/null and others not. We don't want to fail when the user only needs the gradient that's supported.

There was already some special logic to handle cases where NullType and DisconnectedType arise from the OFG inner graph. Instead of outputting those types (which OFG cannot produce out of thin air, as they are root variables), we were outputting dummy zeros and then masking those with the original NullType or DisconnectedType variables created in the internal call to grad/Rop. This seems reasonable, if a bit tedious. This PR first refactors this code to avoid the dummy outputs altogether (there's no reason for them!).

Then it extends this logic to also handle cases where NullType/DisconnectedType arise before the inner graph of OpFromGraph. This was the case behind one of the issues described in #1. When an OFG has multiple outputs and the requested gradient only uses a subset, PyTensor will feed DisconnectedType variables in place of the output_gradients used by the L_op. The solution to this problem is to filter out these unused input variables. This should be safe(?), in that if the inner graph of the OFG needs these variables and we don't provide them, an error about the missing inputs will naturally be raised.
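A minimal pure-Python sketch of the filtering idea (`Disconnected` and `filter_output_gradients` are hypothetical stand-ins, not PyTensor's actual DisconnectedType or OFG internals):

```python
# `Disconnected` is a stand-in sentinel for PyTensor's DisconnectedType;
# the function name and shapes are hypothetical, not OFG's real API.
class Disconnected:
    pass

def filter_output_gradients(inner_outputs, output_gradients):
    # Drop the (output, gradient) pairs whose gradient is disconnected.
    # If the inner graph actually needs a dropped gradient, building it
    # later fails with a missing-input error, which is the informative
    # behavior we want.
    return [
        (out, g)
        for out, g in zip(inner_outputs, output_gradients)
        if not isinstance(g, Disconnected)
    ]

pairs = filter_output_gradients(
    ["out1", "out2", "out3"], [1.0, Disconnected(), 2.0]
)
assert pairs == [("out1", 1.0), ("out3", 2.0)]
```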

This, however, means we may need distinct OFGs for different patterns of disconnected gradients. Accordingly, caching is now done per pattern.

I suspect this is the issue behind #652.

This PR also deprecates grad_overrides and the custom logic for invalid connection_patterns. Hopefully this helps make OpFromGraph more maintainable.

Related Issue

  • [x] Closes #1
  • [x] Related to #652

Checklist

Type of change

  • [ ] New feature / enhancement
  • [x] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

ricardoV94 avatar Apr 20 '24 20:04 ricardoV94

Codecov Report

Attention: Patch coverage is 79.50820% with 25 lines in your changes missing coverage. Please review.

Project coverage is 80.94%. Comparing base (fc21336) to head (a96e5a1). Report is 208 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/compile/builders.py | 79.50% | 14 Missing and 11 partials :warning: |

```diff
@@            Coverage Diff             @@
##             main     #723      +/-   ##
==========================================
+ Coverage   80.85%   80.94%   +0.08%
==========================================
  Files         162      162
  Lines       47043    46945      -98
  Branches    11514    11481      -33
==========================================
- Hits        38038    37998      -40
+ Misses       6750     6706      -44
+ Partials     2255     2241      -14
```
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/gradient.py | 77.37% \<ø\> (+0.54%) :arrow_up: |
| pytensor/compile/builders.py | 88.38% \<79.50%\> (+10.93%) :arrow_up: |

... and 4 files with indirect coverage changes

codecov[bot] avatar Apr 20 '24 21:04 codecov[bot]

Superficial first pass across the PR. I cannot make an informed comment about the actual meat of the changes until I fire up a debugger and try to grok what OpFromGraph is actually doing. I will make an effort to do this in the next 48 hours and give a more meaningful review.

Thanks! It may help to convince yourself that no behavior was changed until the second-to-last commit, where the bug fix is done (other than the deprecations and the removal of special behavior in the connection pattern).

ricardoV94 avatar Apr 21 '24 11:04 ricardoV94

I found another issue: if the outputs of an OpFromGraph are not independent, the existing logic fails. Instead of adding the contributions coming from each output, it overrides them, due to how the `known_grads` we are using internally behaves.

The new test cases in the last commit illustrate this. Any case that depends on out3 fails numerically because we ignore/mask the contributions coming from it.

import numpy as np

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import verify_grad
from pytensor.tensor import dscalars

x, y = dscalars("x", "y")
rng = np.random.default_rng(594)
point = list(rng.normal(size=(2,)))

out1 = x + y
out2 = x * y
out3 = out1 + out2  # Create dependency between outputs
op = OpFromGraph([x, y], [out1, out2, out3])
verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)

If instead we define `out3` explicitly as `out3 = (x + y) * (x * y)` (without reusing `out1`/`out2`), it works fine again.

@aseyboldt any idea how we could handle this? In an outer graph, I think this would be handled by adding the direct contributions to out1/out2 to the indirect ones coming from out3.

It seems like I want to initialize those variables' grads to the output_grad values while still allowing them to be updated, rather than setting them as known, which doesn't allow any further updates?
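For the example above, the correct totals can be derived by hand with the chain rule. A quick pure-Python check (no PyTensor needed) that the direct contributions to out1/out2 plus the indirect ones through out3 give the right answer:

```python
# out1 = x + y, out2 = x * y, out3 = out1 + out2
# S = out1 + out2 + out3 = 2*(x + y) + 2*x*y

def total_grad_x(x, y):
    # Direct contributions: d(out1)/dx = 1, d(out2)/dx = y
    direct = 1 + y
    # Indirect, through out3: d(out3)/dx = d(out1)/dx + d(out2)/dx = 1 + y
    indirect = 1 + y
    return direct + indirect  # = 2 + 2*y

def S(x, y):
    out1, out2 = x + y, x * y
    return out1 + out2 + (out1 + out2)

# Central finite-difference check of the closed form above
x, y, eps = 0.3, -1.2, 1e-6
numeric = (S(x + eps, y) - S(x - eps, y)) / (2 * eps)
assert abs(numeric - total_grad_x(x, y)) < 1e-6
```

Masking the contribution flowing through out3 would instead yield only the direct part, `1 + y`, which is what the failing tests observe.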

ricardoV94 avatar May 29 '24 10:05 ricardoV94

Found a nice(?) hack. Instead of calling Lop internally with `known_grads=dict(zip(inner_outputs, output_gradients))`, I do it with `known_grads=dict(zip(identity_inner_outputs, output_gradients))`, where `identity_inner_outputs` is each inner_output wrapped in a dummy Identity operation. This way we correctly accumulate the direct and indirect contributions coming from other inner outputs.
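A toy pure-Python reverse-mode sketch of why the Identity wrapping helps: seeding the gradient at a wrapper node leaves the output's own gradient slot free to accumulate the indirect contributions from out3 (everything here is a simplification, not PyTensor code):

```python
# Graph: out1 = x + y; out2 = x * y; out3 = out1 + out2
# We want the grad of (out1 + out2 + out3) wrt x and y.
x, y = 0.3, -1.2

# Gradient slots for the inner outputs and inputs, all free to accumulate.
g = {"out1": 0.0, "out2": 0.0, "x": 0.0, "y": 0.0}

# Identity wrappers carry the seeds (the output_gradients). Backprop
# through a wrapper just forwards its seed into the wrapped output's
# accumulating slot; it does not freeze that slot as "known".
for out in ("out1", "out2"):
    g[out] += 1.0
g_out3_wrapper = 1.0

# out3 = out1 + out2: indirect contributions accumulate on top of seeds.
g["out1"] += g_out3_wrapper
g["out2"] += g_out3_wrapper

# out1 = x + y
g["x"] += g["out1"]
g["y"] += g["out1"]
# out2 = x * y
g["x"] += g["out2"] * y
g["y"] += g["out2"] * x

# The total is 2*(x+y) + 2*x*y, so d/dx = 2 + 2*y and d/dy = 2 + 2*x.
assert abs(g["x"] - (2 + 2 * y)) < 1e-12
assert abs(g["y"] - (2 + 2 * x)) < 1e-12
```

Seeding `known_grads` directly at `out1`/`out2` would instead fix those slots at 1.0, discarding the `+= g_out3_wrapper` accumulation.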

ricardoV94 avatar May 29 '24 10:05 ricardoV94