
A graph replace that changes the type from dynamic to static can lead to miscompilation

Open · aseyboldt opened this issue 1 year ago • 2 comments

Description

If I replace a tensor with shape (None,) with one with static shape (1,), LLVM aborts during codegen for me. Maybe some rewrite doesn't behave properly and we produce invalid LLVM code somewhere?

import pytensor.tensor as pt
import pytensor

x = pt.vector("x")
logp = -(x ** 2).sum()

pytensor.dprint(logp, print_type=True)

# Works if I set the shape to x.shape
#x_known = pt.vector("x_known", shape=x.type.shape)
x_known = pt.vector("x_known", shape=(1,))

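# Swap x for x_known: the static shape changes from (None,) to (1,)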
logp = pytensor.graph_replace(logp, {x: x_known})

print("after")
pytensor.dprint(logp, print_type=True)
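# Compiling with the Numba backend is where LLVM aborts during codegen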
pytensor.function((x_known,), logp, mode=pytensor.compile.NUMBA)

Maybe we should just check that we don't change the type in a graph_replace?
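A minimal sketch of what such a guard could look like (strict_graph_replace is a hypothetical wrapper, and requiring exact type equality is an assumed policy, not current graph_replace behavior):

import pytensor

def strict_graph_replace(outputs, replacements):
    # Hypothetical guard: reject any replacement whose type differs from
    # the original variable's type, even if the new type is more specific.
    for old, new in replacements.items():
        if old.type != new.type:
            raise TypeError(
                f"Cannot replace {old} of type {old.type} "
                f"with {new} of type {new.type}"
            )
    return pytensor.graph_replace(outputs, replacements)

With such a check, the snippet above would raise at replacement time instead of producing invalid LLVM code, since the (None,) and (1,) vector types compare unequal.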

aseyboldt · May 02 '24 11:05

Here is a more direct example that triggers the same failure without graph_replace, by building an Apply node whose declared output type does not match what the Op would infer:

import pytensor.tensor as pt
import pytensor

x = pt.vector("x", shape=(1,))
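# Build the Apply node by hand so the output type can be declared as (None,),
# even though the Op would infer (1,) from these inputs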
pow_node = pytensor.graph.Apply(
    pt.pow,
    [x, pt.as_tensor(2)[None]],
    [pt.vector("x", shape=(None,))]
)
pow_out = pow_node.default_output()

func = pytensor.function((x,), pow_out, mode="NUMBA")
pytensor.dprint(func, print_destroy_map=True)
func([0.5])

Edit: Updated
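One way to make the inconsistency explicit (a sketch, assuming pt.pow is the same Elemwise Op instance passed to Apply above, so its make_node infers the output type from the inputs):

import pytensor.tensor as pt

x = pt.vector("x", shape=(1,))
exponent = pt.as_tensor(2)[None]

# The output type the Op itself would infer for pow(x, 2):
inferred = pt.pow.make_node(x, exponent).outputs[0].type
# The output type declared by hand in the example above:
declared = pt.vector("x", shape=(None,)).type

print(inferred)  # static shape (1,)
print(declared)  # static shape (None,)
assert inferred != declared

The backend is handed the declared (None,) type while the input guarantees a length-1 vector, which mirrors the inconsistency that graph_replace introduces in the first example.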

ricardoV94 · May 02 '24 12:05

I also opened a Numba issue: https://github.com/numba/numba/issues/9554

aseyboldt · May 02 '24 14:05