pytest

Allow disabling assertion rewriting at a function level

Status: Open. chaudhary1337 opened this issue 1 year ago (0 comments).

Background

I was working on a function decorated with numba.jit(nopython=True) that contains an assert statement. Its test fails with Numba errors pointing at the assert statement, even though the function's logic is correct.

Example:

```python
# test_numba.py

from numba import jit
import numpy as np


@jit(nopython=True)
def go_fast(a):
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])

    assert (a + trace).all()
    return a + trace


def test_go_fast():
    x = np.arange(4).reshape(2, 2)
    ans = go_fast(x).round(3)
    assert (
        ans.tolist() == [[0.995, 1.995], [2.995, 3.995]]
    )
```

What's the problem this feature will solve?

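pytest's assertion rewriting parses each test module, replaces every assert statement (including those in non-test helpers such as go_fast) with explicit checks that call pytest's own helper functions, and caches the resulting bytecode. Numba compiles a jitted function from its bytecode, so in nopython mode it receives the rewritten assert rather than a plain one, and it cannot compile those calls into pytest's helpers. As a result, numba-decorated functions defined in test modules fail to compile.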
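To see what Numba ends up compiling, the sketch below runs pytest's rewriter over a small module by hand. It assumes a recent pytest (7.x or 8.x) and uses rewrite_asserts from _pytest.assertion.rewrite, which is a private, undocumented helper whose location and signature may change between releases. The source string mirrors go_fast without the Numba decorator, so the example runs without Numba installed.

```python
import ast

# Private pytest helper that rewrites the asserts of a parsed module in place.
# Not a public API: its location and signature may change between pytest releases.
from _pytest.assertion.rewrite import rewrite_asserts

# Same shape as go_fast above, minus the decorator, so Numba is not needed here.
SOURCE = """
def go_fast(a):
    trace = 0.0
    assert (a + trace).all()
    return a + trace
"""

tree = ast.parse(SOURCE)
rewrite_asserts(tree, SOURCE.encode())

rewritten = ast.unparse(tree)  # ast.unparse requires Python 3.9+
print(rewritten)
```

The printed source no longer contains a plain assert. It is replaced by temporaries such as @py_assert1 and a failure branch that calls functions from the imported @pytest_ar module. Numba type-checks every branch of a nopython function, including the one that only runs when the assert fails, so those helper calls are most likely what trips the compiler.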

Describe the solution you'd like

The current workarounds are:

  1. Turn off pytest's rich asserts (assertion rewriting) for the entire test run with --assert=plain.
  2. Turn off rich asserts per module by adding PYTEST_DONT_REWRITE to the module docstring.
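Workaround (2) can be checked directly. This sketch again assumes the private rewrite_asserts helper from _pytest.assertion.rewrite; the only difference from a normal module is that the docstring contains the PYTEST_DONT_REWRITE marker, which makes the rewriter return before touching any assert.

```python
import ast

from _pytest.assertion.rewrite import rewrite_asserts  # private pytest helper

# The marker only needs to appear somewhere in the module docstring.
SOURCE = '''"""Numba kernels under test. PYTEST_DONT_REWRITE"""

def go_fast(a):
    assert (a + 1.0).all()
    return a + 1.0
'''

tree = ast.parse(SOURCE)
rewrite_asserts(tree, SOURCE.encode())  # returns early: nothing is rewritten

# The plain assert survives, so Numba would see an ordinary assert statement.
print(ast.unparse(tree))
```

For workaround (1), passing --assert=plain stops pytest from installing its rewriting import hook at all, so no module is rewritten.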

While (2) is better than (1), it still requires manually editing every module that contains jitted functions, and it disables rich asserts for every test in those modules, not just for the jitted code.

The ideal solution would be for pytest to disable assertion rewriting automatically at the function level: when it finds a jitted function, it skips rewriting only that function, and all remaining tests in the module keep their rich asserts.
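A first step toward that behaviour is deciding, from the AST alone, which functions are jitted. The helper below is a hypothetical sketch (is_nopython_jitted is not a pytest or Numba API): it treats njit, and jit called with nopython=True, as nopython functions, whether imported directly or accessed as numba.jit / numba.njit. Matching by name is a heuristic; an aliased import such as "from numba import jit as fast" would be missed.

```python
import ast
from typing import Optional

# Numba decorators assumed to always compile in nopython mode.
NOPYTHON_DECORATORS = {"njit"}


def _decorator_name(dec: ast.expr) -> Optional[str]:
    """Return the trailing name of a decorator: jit, numba.jit and jit(...) all give 'jit'."""
    if isinstance(dec, ast.Call):
        dec = dec.func
    if isinstance(dec, ast.Name):
        return dec.id
    if isinstance(dec, ast.Attribute):
        return dec.attr
    return None


def _passes_nopython_true(dec: ast.expr) -> bool:
    """True if the decorator is a call with the literal keyword nopython=True."""
    if not isinstance(dec, ast.Call):
        return False
    return any(
        kw.arg == "nopython"
        and isinstance(kw.value, ast.Constant)
        and kw.value.value is True
        for kw in dec.keywords
    )


def is_nopython_jitted(func: ast.FunctionDef) -> bool:
    """Hypothetical check: is this function compiled by Numba in nopython mode?"""
    for dec in func.decorator_list:
        name = _decorator_name(dec)
        if name in NOPYTHON_DECORATORS:
            return True
        if name == "jit" and _passes_nopython_true(dec):
            return True
    return False


SOURCE = """
import numba
from numba import jit, njit

@jit(nopython=True)
def a(x): return x

@numba.njit
def b(x): return x

@njit(cache=True)
def c(x): return x

@jit
def d(x): return x

def test_e(): assert a(1) == 1
"""

flags = {
    node.name: is_nopython_jitted(node)
    for node in ast.walk(ast.parse(SOURCE))
    if isinstance(node, ast.FunctionDef)
}
print(flags)  # {'a': True, 'b': True, 'c': True, 'd': False, 'test_e': False}
```

Since Numba 0.59, a bare @jit also defaults to nopython mode, so a real implementation would probably want to treat the d case as jitted too, unless nopython=False or forceobj=True is passed explicitly.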

Alternative Solutions

My current workaround is to patch rewrite.AssertionRewriter.run (in _pytest.assertion.rewrite). In the part that handles ast.FunctionDef nodes (and related node types), I check node.decorator_list to see whether the function uses any Numba decorator. If it is a nopython=True function, I skip rewriting all asserts inside it. The remaining logic stays the same.
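The function-level skip can be illustrated without patching pytest internals. ToyAssertRewriter below is a hypothetical stand-in for AssertionRewriter, not pytest's code: it turns each assert into an explicit "if not ...: raise AssertionError(...)" (far simpler than pytest's real rewriting), but returns any function whose decorators pass nopython=True unchanged, so its asserts stay plain. The decorator check is deliberately minimal and only looks for the nopython=True keyword.

```python
import ast


def has_nopython_decorator(node: ast.FunctionDef) -> bool:
    """Return True if any decorator is a call with the literal keyword nopython=True."""
    for dec in node.decorator_list:
        if isinstance(dec, ast.Call):
            for kw in dec.keywords:
                if (
                    kw.arg == "nopython"
                    and isinstance(kw.value, ast.Constant)
                    and kw.value.value is True
                ):
                    return True
    return False


class ToyAssertRewriter(ast.NodeTransformer):
    """Hypothetical rewriter: expands asserts everywhere except in nopython functions."""

    def _visit_function(self, node):
        if has_nopython_decorator(node):
            return node  # skip the whole function, asserts included
        self.generic_visit(node)  # rewrite asserts (and nested functions) inside
        return node

    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function

    def visit_Assert(self, node: ast.Assert):
        # Use the assert's source text as the failure message.
        message = ast.Constant(value=ast.unparse(node.test))
        error = ast.Call(
            func=ast.Name(id="AssertionError", ctx=ast.Load()),
            args=[message],
            keywords=[],
        )
        replacement = ast.If(
            test=ast.UnaryOp(op=ast.Not(), operand=node.test),
            body=[ast.Raise(exc=error, cause=None)],
            orelse=[],
        )
        return ast.copy_location(replacement, node)


SOURCE = """
from numba import jit

@jit(nopython=True)
def go_fast(a):
    assert a > 0
    return a

def test_go_fast():
    assert go_fast(1) == 1
"""

tree = ast.fix_missing_locations(ToyAssertRewriter().visit(ast.parse(SOURCE)))
code = compile(tree, "<toy>", "exec")  # the rewritten tree is still valid Python
rewritten = ast.unparse(tree)
print(rewritten)
```

In the output, go_fast keeps its plain assert while test_go_fast gets the expanded check, which is the per-function split the issue asks pytest to make.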

chaudhary1337, Jul 15 '24 06:07