Implement ScalarLoop in torch backend
Description
Adds `ScalarLoop` support to the PyTorch backend. I implemented it as an explicit loop rather than trying to vectorize it... let me know if you'd prefer the vectorized approach instead.
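For context, the dispatch looks roughly like the sketch below. Treat it as a simplified outline: the `pytorch_funcify` registration mirrors the existing dispatch modules, and it assumes `pytorch_funcify` can also be called on the inner `FunctionGraph` (as in the other backends); the while/until case is omitted, so this is not the final code.

```python
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node=None, **kwargs):
    # compile the inner update graph into a torch-callable function
    update = pytorch_funcify(op.fgraph)
    n_carry = len(op.fgraph.outputs)

    def scalar_loop(n_steps, *init_and_constants):
        carry = init_and_constants[:n_carry]
        constants = init_and_constants[n_carry:]
        # explicit Python loop rather than a vectorized formulation;
        # a while/until variant would additionally check a stop condition
        for _ in range(int(n_steps)):
            carry = update(*carry, *constants)
        return carry[0] if n_carry == 1 else tuple(carry)

    return scalar_loop
```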
Related Issue
- [ ] Closes #
- [ ] Related to #939
Checklist
- [X] Checked that the pre-commit linting/style checks pass
- [X] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [X] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [X] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
@Ch0ronomato thanks for taking a stab, I left some comments above
Codecov Report
Attention: Patch coverage is 81.25000% with 9 lines in your changes missing coverage. Please review.
Project coverage is 82.10%. Comparing base (ef97287) to head (521ad67). Report is 131 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #958 +/- ##
==========================================
+ Coverage 82.09% 82.10% +0.01%
==========================================
Files 183 185 +2
Lines 48010 48130 +120
Branches 8653 8669 +16
==========================================
+ Hits 39412 39519 +107
- Misses 6435 6444 +9
- Partials 2163 2167 +4
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/link/pytorch/linker.py | 100.00% <100.00%> (ø) | |
| pytensor/link/pytorch/dispatch/elemwise.py | 69.11% <81.81%> (+2.45%) | :arrow_up: |
| pytensor/link/pytorch/dispatch/scalar.py | 74.07% <77.27%> (+2.19%) | :arrow_up: |
@ricardoV94 - these CI failures look a bit strange; I'll look into them before merging... hopefully they go away after merging main
@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?
If we can't Elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of using vmap for this Op.
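For a concrete picture of why vmap is a problem here, a minimal illustration (assuming torch >= 2.0's `torch.vmap`; `run_loop` is just a stand-in for the scalar loop body):

```python
import torch


def run_loop(n_steps, x0):
    x = x0
    # the iteration count comes from a tensor value, so the Python loop
    # needs a concrete integer
    for _ in range(int(n_steps)):
        x = x * 2.0
    return x


# works fine on unbatched inputs
print(run_loop(torch.tensor(3), torch.tensor(1.0)))  # tensor(8.)

# under vmap the batched n_steps can't be turned into a single int,
# so torch raises a data-dependent control flow error
try:
    torch.vmap(run_loop)(torch.tensor([2, 3]), torch.tensor([1.0, 1.0]))
except Exception as e:
    print(type(e).__name__)
```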
How is unbind(0) different from [x[i] for i in range(x.size(0))]?
https://discuss.pytorch.org/t/the-purpose-of-unbind/98648
It's essentially the same, maybe faster
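For concreteness, the two give the same slices (quick check, not from the PR):

```python
import torch

x = torch.arange(6).reshape(3, 2)

# tuple of views along dim 0
rows_unbind = torch.unbind(x, 0)
# explicit indexing
rows_index = [x[i] for i in range(x.size(0))]

assert all(torch.equal(a, b) for a, b in zip(rows_unbind, rows_index))
```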
But if we index in the loop after raveling we don't need all the slices in memory. This is looking like a custom Elemwise with explicit broadcasting:
n_outputs = len(node.outputs)

bcasted_inputs = torch.broadcast_tensors(*inputs)
raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
# map PyTensor dtype strings (e.g. "float64") to the corresponding torch dtypes
raveled_outputs = [
    torch.empty(out_size, dtype=getattr(torch, out.dtype)) for out in node.outputs
]

for i in range(out_size):
    core_outs = core_func(*(inp[i] for inp in raveled_inputs))
    if n_outputs == 1:
        raveled_outputs[0][i] = core_outs
    else:
        for o in range(n_outputs):
            raveled_outputs[o][i] = core_outs[o]

outputs = tuple(out.view(out_shape) for out in raveled_outputs)
if n_outputs == 1:
    return outputs[0]
else:
    return outputs
Also note that nothing here is specific to ScalarLoop, so it can serve as a (non-performant) fallback for all sorts of Elemwise ops.
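To illustrate, here's the same idea wrapped as a standalone function and checked against a toy core function; the `elemwise_loop_fallback` name and the single-dtype shortcut are just for this example, not part of the suggestion:

```python
import torch


def elemwise_loop_fallback(core_func, inputs, n_outputs):
    bcasted_inputs = torch.broadcast_tensors(*inputs)
    raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

    out_shape = bcasted_inputs[0].size()
    out_size = out_shape.numel()
    # simplification: reuse the first input's dtype for all outputs
    raveled_outputs = [
        torch.empty(out_size, dtype=bcasted_inputs[0].dtype)
        for _ in range(n_outputs)
    ]

    for i in range(out_size):
        core_outs = core_func(*(inp[i] for inp in raveled_inputs))
        if n_outputs == 1:
            core_outs = (core_outs,)
        for o in range(n_outputs):
            raveled_outputs[o][i] = core_outs[o]

    outputs = tuple(out.view(out_shape) for out in raveled_outputs)
    return outputs[0] if n_outputs == 1 else outputs


a = torch.randn(3, 1)
b = torch.randn(1, 4)
res = elemwise_loop_fallback(lambda x, y: x + y, [a, b], n_outputs=1)
assert torch.allclose(res, a + b)
```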
That looks great. I think we'll still need to have some dispatch logic to know what can't be vmap'd; do we want to keep the current method? How does your approach merge with #1032?
Yes, this can be a fallback only for registered Ops (and specifically only ScalarLoop for the time being).
If my suggestion works, it should be better than the nested unbind, unless torch is really weird.
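Roughly, the gating could be as simple as the following. This is illustrative only: `elemwise_wrapper` and `MANUAL_LOOP_OPS` are made-up names, the fallback branch reuses the `elemwise_loop_fallback` helper sketched a few comments up, and the default branch just stands in for whatever the existing Elemwise path ends up being:

```python
from pytensor.scalar.loop import ScalarLoop

# ops whose core function has data-dependent control flow and therefore
# can't go through torch.vmap; only ScalarLoop is registered for now
MANUAL_LOOP_OPS = (ScalarLoop,)


def elemwise_wrapper(scalar_op, core_func, n_outputs):
    if isinstance(scalar_op, MANUAL_LOOP_OPS):
        # manual ravel-and-index loop (elemwise_loop_fallback from above)
        return lambda *inputs: elemwise_loop_fallback(core_func, inputs, n_outputs)
    # everything else keeps the existing Elemwise path (broadcasting torch
    # ops / vmap), represented here by just returning the core function
    return core_func
```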
I can squash commits...
Also got conflicts here
@ricardoV94 this is ready (it'll probably have conflicts with #1066)