
Rewrite scalar dot as multiplication

Open • ricardoV94 opened this issue 10 months ago • 1 comment

Description

In https://github.com/pymc-devs/pytensor/pull/1178 we rewrote away batched dots that are just multiplication, but left core dots as they were because they use BLAS operations (whether those are worth it is a question in its own right). But there is one case that is definitely not worth it: scalar multiplication.

The following graph should definitely be simplified:

```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1))
y = pt.tensor("y", shape=(1, 1))
out = x @ y
pytensor.function([x, y], out).dprint()
```

```
CGer{non-destructive} [id A] 2
 ├─ [[0.]] [id B]
 ├─ 1.0 [id C]
 ├─ DropDims{axis=1} [id D] 1
 │  └─ x [id E]
 └─ DropDims{axis=0} [id F] 0
    └─ y [id G]
```

Or, without the BLAS stuff:

```python
pytensor.function([x, y], out, mode="FAST_COMPILE").dprint()
```

```
Dot22 [id A] 0
 ├─ x [id B]
 └─ y [id C]
```

Both of these should just be a mul, because a mul can be fused with other Elemwise operations (and calling BLAS for it is the silliest thing ever).
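Here is a quick numeric check of why a mul is enough. It is a minimal sketch that uses NumPy as a stand-in for the compiled graph, not the code PyTensor would generate. For two (1, 1) matrices the contraction has a single term, so x @ y and the elementwise x * y produce the same (1, 1) array. Once the dot becomes a mul, a downstream expression such as exp(x @ y) + 1 turns into a purely elementwise chain.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 1))
y = rng.normal(size=(1, 1))

# The contracted axis has length 1, so the matmul "sum" has a single term:
# (x @ y)[0, 0] == x[0, 0] * y[0, 0]
dot_out = x @ y
mul_out = x * y

assert dot_out.shape == mul_out.shape == (1, 1)
assert np.allclose(dot_out, mul_out)

# A downstream elementwise expression. With mul instead of dot, every step
# is elementwise, so in a graph compiler the whole chain is a candidate for
# fusion; NumPy itself still evaluates it step by step.
via_dot = np.exp(x @ y) + 1.0
via_mul = np.exp(x * y) + 1.0
assert np.allclose(via_dot, via_mul)
```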

ricardoV94 • Feb 12 '25 17:02

We should also consider the remaining cases that are just multiplication, especially in non-default backends, where the BLAS question is completely irrelevant. Even in the C backend, I saw many cases where it was faster without BLAS (but some where it was slower :( ).
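The "remaining cases" can be described by shape. Whenever the contracted dimension has length 1, x @ y is a broadcasted outer product, which is exactly x * y. The sketch below defines a hypothetical helper, matmul_via_mul, which is not a PyTensor function. It checks this condition on NumPy arrays and uses a multiply when the condition holds. The helper assumes both inputs are at least 2-D: for 1-D inputs, matmul returns a scalar while x * y would keep shape (1,).

```python
import numpy as np


def matmul_via_mul(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Hypothetical helper: compute x @ y, using an elementwise multiply
    when the contracted dimension has length 1.

    Assumes x has shape (..., m, k) and y has shape (..., k, n), both >= 2-D.
    """
    if x.ndim < 2 or y.ndim < 2:
        raise ValueError("sketch only handles inputs with ndim >= 2")
    if x.shape[-1] != y.shape[-2]:
        raise ValueError(f"incompatible core shapes {x.shape} and {y.shape}")
    if x.shape[-1] == 1:
        # (..., m, 1) * (..., 1, n) broadcasts to (..., m, n); each output
        # entry is the single term x[..., i, 0] * y[..., 0, j].
        return x * y
    # Contracted dimension > 1: a genuine dot, keep matmul.
    return x @ y


rng = np.random.default_rng(0)
cases = [
    ((1, 1), (1, 1)),        # the scalar case from the issue
    ((4, 1), (1, 3)),        # outer product
    ((2, 5, 1), (2, 1, 3)),  # batched outer product
    ((2, 5, 1), (1, 3)),     # batch dims broadcast the same way in both ops
    ((4, 2), (2, 3)),        # k == 2: falls back to matmul
]
for xs, ys in cases:
    a = rng.normal(size=xs)
    b = rng.normal(size=ys)
    out = matmul_via_mul(a, b)
    ref = a @ b
    assert out.shape == ref.shape
    assert np.allclose(out, ref)
```

The batch dimensions need no special handling: matmul broadcasts its leading dimensions with the same rules as elementwise operations, so only the two core axes decide whether the shortcut applies.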

ricardoV94 • Feb 12 '25 17:02