
Why is the nested numba function so slow?

Open wukan1986 opened this issue 2 years ago • 6 comments

Research

  • [X] I have searched the above polars tags on Stack Overflow for similar questions.

  • [ ] I have asked my usage related question on Stack Overflow.

Link to question on Stack Overflow

https://stackoverflow.com/questions/55291117/nested-numba-function-performance

Question about Polars

Why is the nested numba function so slow?

With a nested function, it seems that the multithreaded execution falls back to a single thread.


import time

import numba
import numpy as np
import polars as pl
from polars import DataFrame

print(pl.__version__)
# 0.15.7
print(numba.__version__)
# 0.56.4

# 10000 rows x 5 float columns; column_0 is the group key (10 rows per key, 1000 groups).
arr = np.ones((10000, 5), dtype=np.float64)
arr[:, 0] = np.arange(len(arr)) // 10
df = pl.from_numpy(arr)


# groupby().apply() hands each group to these functions as an eager DataFrame.
def func1(df: DataFrame):
    # Native polars cumsum.
    df = df.with_columns([
        pl.col('column_1').cumsum().alias('column_1_cumsum')
    ])
    return df


def func2(df: DataFrame):
    # numpy cumsum called through a Python UDF.
    df = df.with_columns([
        pl.col('column_2').map(lambda x: pl.Series(np.cumsum(x))).alias('column_2_cumsum')
    ])
    return df


# Module-level numba kernel: decorated once, compiled on first use, then reused.
@numba.jit(nopython=True, cache=True, nogil=True)
def cumsum_nb3(x):
    s = 0
    y = x.copy()
    for i in range(len(x)):
        s += x[i]
        y[i] = s
    return y


def cumsum_py3(x):
    return pl.Series(cumsum_nb3(x.to_numpy()))


def cumsum_py4(x):
    # Nested numba kernel: the decorator runs on every call of cumsum_py4,
    # so every call gets a brand-new jitted function.
    @numba.jit(nopython=True, cache=True, nogil=True)
    def _cumsum_nb4(x):
        s = 0
        y = x.copy()
        for i in range(len(x)):
            s += x[i]
            y[i] = s
        return y

    return pl.Series(_cumsum_nb4(x.to_numpy()))


def func3(df: DataFrame):
    df = df.with_columns([
        pl.col('column_3').map(lambda x: cumsum_py3(x)).alias('column_3_cumsum')
    ])
    return df


def func4(df: DataFrame):
    df = df.with_columns([
        pl.col('column_4').map(lambda x: cumsum_py4(x)).alias('column_4_cumsum')
    ])
    return df


# Each apply() runs the Python function once per group (1000 groups).
t0 = time.perf_counter()
df = df.groupby('column_0').apply(lambda x: func1(x))
t1 = time.perf_counter()
df = df.groupby('column_0').apply(lambda x: func2(x))
t2 = time.perf_counter()
df = df.groupby('column_0').apply(lambda x: func3(x))
t3 = time.perf_counter()
df = df.groupby('column_0').apply(lambda x: func4(x))
t4 = time.perf_counter()

print('polars cumsum', t1 - t0)
print('numpy  cumsum', t2 - t1)
print('numba  cumsum', t3 - t2)
print('numba  cumsum(nested)', t4 - t3)
# Output:
# polars cumsum 0.11049989999999998
# numpy  cumsum 0.23705550000000009
# numba  cumsum 0.46816349999999995
# numba  cumsum(nested) 7.726379700000001

wukan1986, Dec 21 '22 06:12

Maybe because _cumsum_nb4 needs to be compiled several times?
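This hypothesis can be illustrated with a toy model. The sketch below uses a hypothetical `counting_jit` decorator as a stand-in for `numba.jit`: it does not compile anything and does not model numba's on-disk cache (`cache=True`); it only counts how often a freshly created dispatcher has to do its first-call setup. A kernel defined at module level pays that cost once, while a kernel re-decorated inside the calling function pays it on every call:

```python
import functools

compile_count = 0  # how many times a "compilation" was triggered


def counting_jit(func):
    """Hypothetical stand-in for numba.jit (not numba's real implementation).

    Each application of the decorator returns a new dispatcher that
    "compiles" on its first call and reuses that result afterwards.
    """
    compiled = None

    @functools.wraps(func)
    def dispatcher(*args):
        global compile_count
        nonlocal compiled
        if compiled is None:
            compile_count += 1  # stands in for numba's expensive compile step
            compiled = func
        return compiled(*args)

    return dispatcher


@counting_jit
def cumsum_module_level(xs):
    total, out = 0.0, []
    for v in xs:
        total += v
        out.append(total)
    return out


def cumsum_nested(xs):
    @counting_jit  # re-applied on every call, like _cumsum_nb4 in cumsum_py4
    def _kernel(xs):
        total, out = 0.0, []
        for v in xs:
            total += v
            out.append(total)
        return out

    return _kernel(xs)


# 1000 calls, mirroring the 1000 groups in the benchmark.
compile_count = 0
for _ in range(1000):
    cumsum_module_level([1.0, 2.0, 3.0])
module_level_compiles = compile_count

compile_count = 0
for _ in range(1000):
    cumsum_nested([1.0, 2.0, 3.0])
nested_compiles = compile_count

print(module_level_compiles, nested_compiles)  # 1 1000
```

In real numba the first-call step is type inference plus LLVM compilation (with `cache=True` it may instead be a cache lookup and load), which is far more expensive than a 10-element cumulative sum.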

gab23r, Dec 21 '22 08:12

@gab23r I have used cache=True.

I think the single-threaded execution is the key point.

wukan1986, Dec 21 '22 08:12

We hold the GIL during a groupby. Any help on adapting the Python apply so that we release the GIL would be welcome.

ritchie46, Dec 21 '22 09:12

cumsum_py3 and cumsum_py4 both hold the GIL, so why is cumsum_py4 so slow?

They are the same code, except for where the numba function is defined.
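Where the function is defined is exactly what matters: a `def` statement inside a function body is executed every time the enclosing function runs, so each call produces a new function object, and applying `@numba.jit` to it creates a new dispatcher with no compiled code yet. A minimal plain-Python sketch (the function names are hypothetical, and numba is not needed to show the effect):

```python
def make_kernel_nested():
    # This def runs on every call, creating a new function object each time.
    def _kernel(x):
        return x * 2

    return _kernel


def _kernel_module(x):
    return x * 2


def make_kernel_hoisted():
    # Always returns the one function object created when the module was loaded.
    return _kernel_module


nested_a, nested_b = make_kernel_nested(), make_kernel_nested()
hoisted_a, hoisted_b = make_kernel_hoisted(), make_kernel_hoisted()

print(nested_a is nested_b)    # False: two distinct objects
print(hoisted_a is hoisted_b)  # True: the same object is reused
```

Keeping the jitted kernel at module level, as cumsum_nb3 does, is therefore the straightforward fix here.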

wukan1986, Dec 21 '22 12:12

I don't think the issue is related to Polars; it comes from numba itself:

# Create a version that uses a numpy array as input instead of a polars Series.
In [150]: def cumsum_py7(x):
     ...:     @numba.jit(nopython=True, cache=True, nogil=True)
     ...:     def _cumsum_nb7(x):
     ...:         s = 0
     ...:         y = x.copy()
     ...:         for i in range(x.shape[0]):
     ...:             s += x[i]
     ...:             y[i] = s
     ...:         return y
     ...: 
     ...:     return _cumsum_nb7(x)
     ...: 

# Get a column as a numpy array
In [151]: x = df["column_4"].to_numpy()

# External numba jit function with polars Series as input.
In [153]: %timeit cumsum_py3(pl.Series("x", x))
38.6 µs ± 209 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# Nested numba jit function with polars Series as input.
In [152]: %timeit cumsum_py4(pl.Series("x", x))
7.71 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# Nested numba jit function with pure numpy array as input.
In [155]: %timeit cumsum_py7(x)
8.21 ms ± 628 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
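As a rough cross-check against the original benchmark (assuming `groupby('column_0').apply` calls `func4`, and therefore `cumsum_py4`, once per group), dividing the nested total by the number of groups lands very close to the per-call `%timeit` of `cumsum_py4` above. This suggests nearly all of the 7.7 s is per-call overhead of the nested kernel:

```python
# Numbers copied from this thread; nothing here is re-measured.
n_rows = 10_000
rows_per_group = 10                    # arr[:, 0] = np.arange(len(arr)) // 10
n_groups = n_rows // rows_per_group    # number of distinct keys in column_0

nested_total_s = 7.726379700000001     # "numba  cumsum(nested)" from the benchmark
per_group_ms = nested_total_s / n_groups * 1000

per_call_ms = 7.71                     # %timeit cumsum_py4(...) mean from this comment

print(f"{n_groups} groups -> {per_group_ms:.2f} ms per group; "
      f"%timeit per call: {per_call_ms} ms")
```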

ghuls, Dec 23 '22 10:12

An old issue, but very likely still relevant: https://github.com/numba/numba/issues/1810

ghuls, Dec 23 '22 10:12

(Closing as this is how numba behaves with nested jit compilation; not really a polars issue).

alexander-beedie, Feb 01 '23 11:02