
Implement ScalarLoop in torch backend

Open Ch0ronomato opened this issue 1 year ago β€’ 3 comments

Description

Adds ScalarLoop for pytorch. I implemented it as a loop rather than trying to vectorize it; let me know if I should take the other approach instead.

Related Issue

  • [ ] Closes #
  • [ ] Related to #939

Checklist

Type of change

  • [X] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

Ch0ronomato avatar Aug 01 '24 03:08 Ch0ronomato

@Ch0ronomato thanks for taking a stab, I left some comments above

ricardoV94 avatar Aug 03 '24 14:08 ricardoV94

Codecov Report

Attention: Patch coverage is 81.25000% with 9 lines in your changes missing coverage. Please review.

Project coverage is 82.10%. Comparing base (ef97287) to head (521ad67). Report is 131 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/pytorch/dispatch/scalar.py 77.27% 4 Missing and 1 partial :warning:
pytensor/link/pytorch/dispatch/elemwise.py 81.81% 2 Missing and 2 partials :warning:
Additional details and impacted files

@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   82.09%   82.10%   +0.01%     
==========================================
  Files         183      185       +2     
  Lines       48010    48130     +120     
  Branches     8653     8669      +16     
==========================================
+ Hits        39412    39519     +107     
- Misses       6435     6444       +9     
- Partials     2163     2167       +4     
Files with missing lines Coverage Ξ”
pytensor/link/pytorch/linker.py 100.00% <100.00%> (ΓΈ)
pytensor/link/pytorch/dispatch/elemwise.py 69.11% <81.81%> (+2.45%) :arrow_up:
pytensor/link/pytorch/dispatch/scalar.py 74.07% <77.27%> (+2.19%) :arrow_up:

... and 6 files with indirect coverage changes


codecov[bot] avatar Aug 11 '24 19:08 codecov[bot]

@ricardoV94 - these CI failures look a bit strange; I'll look into them before merging... hopefully they go away after merging main 😓

Ch0ronomato avatar Sep 19 '24 22:09 Ch0ronomato

@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this pr?

Ch0ronomato avatar Oct 20 '24 23:10 Ch0ronomato

@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this pr?

If we can't elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of using vmap for this Op

ricardoV94 avatar Oct 21 '24 06:10 ricardoV94

How is unbind(0) different from [x[i] for i in range(x.size(0))]?

ricardoV94 avatar Nov 12 '24 15:11 ricardoV94

How is unbind(0) different from [x[i] for i in range(x.size(0))]?

https://discuss.pytorch.org/t/the-purpose-of-unbind/98648

It's essentially the same, maybe faster

Ch0ronomato avatar Nov 12 '24 16:11 Ch0ronomato
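A minimal sketch (assuming a 2-D tensor `x`) showing that `unbind(0)` and manual indexing produce the same slices, both as views without copying:

```python
import torch

x = torch.arange(6).reshape(3, 2)

# unbind(0) returns a tuple of views, one per row along dim 0
unbound = x.unbind(0)

# the equivalent manual indexing
indexed = [x[i] for i in range(x.size(0))]

assert all(torch.equal(a, b) for a, b in zip(unbound, indexed))
```

The practical difference is that unbind materializes all slice views up front in a single call, while indexing creates them one at a time.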

How is unbind(0) different from [x[i] for i in range(x.size(0))]?

https://discuss.pytorch.org/t/the-purpose-of-unbind/98648

It's essentially the same, maybe faster

But if we index in the loop after raveling we don't need all the slices in memory. This is looking like a custom Elemwise with explicit broadcasting:

bcasted_inputs = torch.broadcast_tensors(*inputs)
raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
raveled_outputs = [torch.empty(out_size, dtype=out.dtype) for out in node.outputs]

for i in range(out_size):
  core_outs = core_func(*(inp[i] for inp in raveled_inputs))
  if n_outputs == 1:
    raveled_outputs[0][i] = core_outs
  else:
    for o in range(n_outputs):
      raveled_outputs[o][i] = core_outs[o]

outputs = tuple(out.view(out_shape) for out in raveled_outputs)
if n_outputs == 1:
  return outputs[0]
else:
  return outputs

Also note that nothing here is specific to ScalarLoop, so it can serve as a (non-performant) fallback for all sorts of Elemwise Ops

ricardoV94 avatar Nov 12 '24 16:11 ricardoV94
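The sketch above can be exercised end-to-end as a standalone function. The version below is a hedged, self-contained adaptation: `elemwise_loop_fallback`, its signature, and the hard-coded `n_outputs` handling are illustrative, not PyTensor's actual dispatch code, and it omits the per-output dtype handling from the sketch for brevity:

```python
import torch

def elemwise_loop_fallback(core_func, n_outputs, *inputs):
    # Broadcast all inputs to a common shape, then ravel so one flat
    # loop covers every element (hypothetical helper for illustration)
    bcasted_inputs = torch.broadcast_tensors(*inputs)
    raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

    out_shape = bcasted_inputs[0].size()
    out_size = out_shape.numel()
    raveled_outputs = [torch.empty(out_size) for _ in range(n_outputs)]

    # Apply the scalar core function element by element
    for i in range(out_size):
        core_outs = core_func(*(inp[i] for inp in raveled_inputs))
        if n_outputs == 1:
            raveled_outputs[0][i] = core_outs
        else:
            for o in range(n_outputs):
                raveled_outputs[o][i] = core_outs[o]

    # Restore the broadcast shape
    outputs = tuple(out.view(out_shape) for out in raveled_outputs)
    return outputs[0] if n_outputs == 1 else outputs

# Usage: a scalar core function applied elemwise with broadcasting
a = torch.tensor([[1.0], [2.0]])   # shape (2, 1)
b = torch.tensor([10.0, 20.0])     # shape (2,)
result = elemwise_loop_fallback(lambda x, y: x + y, 1, a, b)
assert torch.equal(result, a + b)
```

Because the core function only ever sees scalars, data-dependent control flow in the loop body (the thing that blocks vmap) works without issue.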

That looks great. I think we'll still need to have some dispatch logic to know what can't be vmap'd; do we want to keep the current method? How does your approach merge with #1032?

Ch0ronomato avatar Nov 12 '24 16:11 Ch0ronomato

That looks great. I think we'll still need to have some dispatch logic to know what can't be vmap'd; do we want to keep the current method?

Yes this can be a fallback only for registered Ops (and specifically only ScalarLoop at the time being).

ricardoV94 avatar Nov 12 '24 16:11 ricardoV94

If my suggestion works it should be better than the nested unbind unless torch is really weird

ricardoV94 avatar Nov 12 '24 16:11 ricardoV94

I can squash commits...

Ch0ronomato avatar Nov 25 '24 04:11 Ch0ronomato

Also got conflicts here

ricardoV94 avatar Nov 25 '24 16:11 ricardoV94

@ricardoV94 this is ready (it'll probably have conflicts with #1066)

Ch0ronomato avatar Nov 25 '24 17:11 Ch0ronomato