TensorComprehensions
Initial version of reverse mode autodiff
There are still a few things that could be improved, but I think they can be done in subsequent PRs, and I wanted to get some feedback at this point. Some issues I see:
- It might be nicer to link it as a part of lang. On the other hand, this needs tc2halide, which is part of the core, and that would create a dependency cycle between libtc_lang.so and libtc_core.so.
- Not sure where the Python bindings should go. It's not really part of the engine, but I don't think we have a different pybind file that fits it better. I can create a new one if you want; just let me know.
yay :) @apaszke, I was wondering if you could give a high-level overview of the approach, for educational purposes? Thanks for adding this :)
cc @abadams, who might also have some ideas about autodiff in Halide.
I can of course answer some more specific questions, but most of this patch is basically a tiny bit of reverse mode AD code, and a ton of defensive programming to perform bookkeeping and shield it from unsupported features.
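For readers who asked for the high-level idea: the core of reverse-mode AD is to record, for every intermediate value, its inputs and the local derivatives with respect to them, then walk that graph backwards from the output, accumulating gradients. The sketch below is a generic scalar illustration of that idea under our own assumptions (the class Var and function backward are hypothetical names); it is not the code in this PR, which operates on TC statements rather than scalars.

```python
class Var:
    """A scalar node that remembers its parents and local derivatives."""

    def __init__(self, value, parents=()):
        self.value = value
        self.parents = parents  # sequence of (parent_var, d(self)/d(parent))
        self.grad = 0.0

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))


def backward(output):
    # Topologically sort the graph so a node's gradient is complete
    # before it is propagated to that node's parents.
    order, seen = [], set()

    def visit(v):
        if id(v) in seen:
            return
        seen.add(id(v))
        for parent, _ in v.parents:
            visit(parent)
        order.append(v)

    visit(output)
    output.grad = 1.0
    for v in reversed(order):
        for parent, local in v.parents:
            # Chain rule: accumulate d(output)/d(parent).
            parent.grad += local * v.grad


x, y = Var(3.0), Var(4.0)
z = x * y + x
backward(z)
print(x.grad, y.grad)  # 5.0 3.0
```

Note that x is used twice, so its gradient is accumulated from both uses (y + 1). That same accumulation is what becomes a scatter-add, and a potential race, once it is applied to indexed tensors.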
oh very cool, thanks for the reference @apaszke :) I'll look at that tutorial and then look over this PR.
The bulk of the work in the Halide autodiff was handling non-trivial indexing in ways that preserve parallelism. E.g. consider:
A(i) = 2 * B(some nasty expression in i)
You don't want to compute the derivatives as:
B'(some nasty expression in i) = A'(i) * 2
because you can't necessarily parallelize that over i, so it's going to be slow. This sort of thing comes up as soon as you have a stencil or a broadcast: one input influences multiple outputs, so when the computation is reversed, several values of i write to the same element of B', which introduces a race condition.
Dealing with bounds and shapes of the computed derivatives was also interesting in the Halide work. I think that's simpler in TC, because tensors have finite size.
I think this PR is the right sort of approach, and we should do the rewrite of B'(nasty) = A'(i) * 2 into B'(j) = A'(nasty^-1(j)) * 2 (summing over every i that maps to j) as a later pass inside lowering, using the Halide solver or polyhedral tools.
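To make the scatter-versus-gather point concrete, here is a minimal sketch in plain Python. It assumes a specific "nasty" index, f(i) = i // 2 (a broadcast, so each element of B feeds two elements of A), and uses the adjoint of A(i) = 2 * B(f(i)), which is B'(f(i)) += 2 * A'(i). The function names are hypothetical and only illustrate the two loop shapes; a real pass would derive the preimage with a solver rather than hard-code it.

```python
def f(i):
    # Assumed index expression: every B element feeds two consecutive A elements.
    return i // 2


def grad_scatter(grad_a, n_b):
    # Direct transposition: B'(f(i)) += 2 * A'(i).
    # Different i can hit the same B' slot, so running this loop in
    # parallel over i would race; it must be serialized or use atomics.
    grad_b = [0.0] * n_b
    for i, g in enumerate(grad_a):
        grad_b[f(i)] += 2.0 * g
    return grad_b


def grad_gather(grad_a, n_b):
    # Rewritten form: for each j, sum over the preimage {i : f(i) == j}.
    # For f(i) = i // 2 that preimage is {2j, 2j + 1}, clipped to A's bounds.
    # Each j writes only its own slot, so this loop is safe to parallelize.
    n_a = len(grad_a)
    return [2.0 * sum(grad_a[i] for i in (2 * j, 2 * j + 1) if i < n_a)
            for j in range(n_b)]


grad_a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(grad_scatter(grad_a, 3))  # [6.0, 14.0, 22.0]
print(grad_gather(grad_a, 3))   # [6.0, 14.0, 22.0]
```

Both versions compute the same gradient; only the gather form exposes parallelism over the output index j.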
@abadams Thanks for the review! Great, it seems that we're on the same page! 😄
I agree that putting the index formulas on the LHS isn't perfect, but rewriting them could easily be done as a post-processing step. I know how to implement this; it's just that I'd need a linear equation solver (which, from what I understand, is not available at this level). For now I decided to bail and leave the indexing expressions on the LHS, but that's likely to cause errors downstream anyway (I don't think it's supported in TC/Halide). It would be useful to support this in general, too.
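As an illustration of what "a linear equation solver" has to do here, the sketch below handles only the simplest case: a one-dimensional affine index B'(a*i + b). It solves j = a*i + b over the integers within A's bounds, which is enough to turn the scatter into a gather indexed by j. The names affine_preimage and grad_b_gather, and the concrete index 3*i + 1, are hypothetical; the solvers in Halide or polyhedral tools handle multiple dimensions and far more general expressions.

```python
def affine_preimage(a, b, j, n):
    """Return every i in range(n) with a * i + b == j."""
    if a == 0:
        # Constant index: either every i maps to j or none does.
        return list(range(n)) if j == b else []
    # divmod guarantees q * a + r == j - b, so r == 0 means q solves it exactly.
    q, r = divmod(j - b, a)
    if r != 0 or not (0 <= q < n):
        return []  # no in-bounds i writes to j
    return [q]


def grad_b_gather(grad_a, n_b, a=3, b=1):
    # Gradient of A(i) = 2 * B(a*i + b), written as a gather over j.
    return [2.0 * sum(grad_a[i] for i in affine_preimage(a, b, j, len(grad_a)))
            for j in range(n_b)]


grad_a = [1.0, 2.0, 3.0]  # A touches B at indices 1, 4 and 7
print(grad_b_gather(grad_a, 8))  # [0.0, 2.0, 0.0, 0.0, 4.0, 0.0, 0.0, 6.0]
```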
Bounds are another problem, and you can see that I had to implement a few checks to make sure that they can still be inferred automatically (usedIndexVars + requireAllIndexVarsOf). This is still not perfect, because a few simple things, like ranges given in where clauses, also cause an error for now, but they are fairly easy to add later. I didn't want to complicate the initial patch, so that we can get this in quickly and then work on improvements.
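A rough sketch of the kind of check described above, under our own assumptions: index expressions are represented as strings, and we use a deliberately conservative rule that every index variable used on the LHS of a gradient statement must also appear as a bare index of some RHS tensor access, so its range can be read off that tensor's shape. The helpers used_index_vars and ranges_inferable are hypothetical stand-ins, not the usedIndexVars and requireAllIndexVarsOf functions from the PR, and TC's real range inference is more capable than this rule.

```python
import re


def used_index_vars(index_exprs):
    # Collect every identifier that appears in a list of index expressions.
    return {name for e in index_exprs for name in re.findall(r"[A-Za-z_]\w*", e)}


def ranges_inferable(lhs_indices, rhs_accesses):
    """True if each LHS index variable appears as a bare index on the RHS."""
    bare = {e.strip() for access in rhs_accesses for e in access
            if e.strip().isidentifier()}
    return used_index_vars(lhs_indices) <= bare


# Gradient of O(i) = I(i + k) * W(k):  I'(i + k) += O'(i) * W(k)
print(ranges_inferable(["i + k"], [["i"], ["k"]]))  # True
# If k never appears bare on the RHS, its range cannot be recovered.
print(ranges_inferable(["i + k"], [["i"]]))         # False
```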
Equation solvers exist once you get down to Halide IR or polyhedral IR. It would be silly to implement yet another one at the TC-front-end-IR level. I'd just allow complex index expressions on the LHS for now, which I think is what you're doing? They may barf inside the tc2halide layer right now, but support for those is an outstanding near-term TODO - we need them for things like histogram computation.
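Histogram computation is a good example of why complex LHS indices need direct support rather than just a solver: the write location depends on the data, so it cannot be inverted statically. The sketch below shows the resulting scatter-add pattern in plain Python, assuming integer values already in [0, num_bins); the function name is hypothetical and this is not TC syntax.

```python
def histogram(values, num_bins):
    # Conceptually H(values(i)) += 1: the LHS index is a data-dependent
    # expression, so the loop over i is an inherently scatter-style update.
    hist = [0] * num_bins
    for v in values:
        hist[v] += 1
    return hist


print(histogram([0, 2, 2, 1, 2], 3))  # [1, 1, 3]
```

Parallelizing this over i requires atomics or privatized partial histograms, which is the same issue that arises when a reversed stencil or broadcast cannot be rewritten into a gather.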
Yep, that's exactly what I've been thinking. Great that we're on the same page!
Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours has expired.
Before we can review or merge your code, we need you to email [email protected] with your details so we can update your status.