High Overhead of `pprint.pformat` Calls During Compilation
Reporting a bug
- [x] I have tried using the latest released version of Numba (most recent is visible in the release notes (https://numba.readthedocs.io/en/stable/release-notes-overview.html)).
- [x] I have included a self contained code sample to reproduce the problem. i.e. it's possible to run as 'python bug.py'.
Building on the performance test run in issue https://github.com/numba/numba/issues/9700, we tested a larger-scale numerical algorithm with Numba and noticed that about 5% of compilation time was spent in the `pprint` Python module. Looking further, these calls all come from `pprint.pformat`, not from `pprint.pprint`.
cProfile Output
cProfile output, filtered to rows from `pprint.py`:
```
      ncalls  tottime  percall  cumtime  percall filename:lineno(function)
146771/73897    0.076    0.000    0.256    0.000 pprint.py:466(format)
146771/73897    0.074    0.000    0.224    0.000 pprint.py:554(_safe_repr)
       73897    0.040    0.000    0.304    0.000 pprint.py:457(_repr)
  39796/1600    0.028    0.000    0.441    0.000 pprint.py:171(_format)
        1546    0.028    0.000    0.226    0.000 pprint.py:380(_format_dict_items)
       66043    0.028    0.000    0.035    0.000 pprint.py:102(_safe_tuple)
      207753    0.013    0.000    0.013    0.000 pprint.py:95(__lt__)
      132086    0.007    0.000    0.007    0.000 pprint.py:92(__init__)
   2770/1402    0.003    0.000    0.025    0.000 pprint.py:416(_format_items)
        1600    0.002    0.000    0.447    0.000 pprint.py:57(pformat)
        1546    0.002    0.000    0.277    0.000 pprint.py:209(_pprint_dict)
        1600    0.002    0.000    0.443    0.000 pprint.py:159(pformat)
        1368    0.001    0.000    0.012    0.000 pprint.py:247(_pprint_tuple)
        1600    0.001    0.000    0.001    0.000 pprint.py:107(__init__)
        1402    0.001    0.000    0.026    0.000 pprint.py:239(_pprint_list)
          90    0.000    0.000    0.056    0.001 pprint.py:473(_pprint_default_dict)
```
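For anyone reproducing this, a filtered listing like the one above can be produced with the standard `pstats` module. The sketch below profiles a stand-in workload (the `workload` function is hypothetical and only imitates eager `pformat` calls; it is not Numba's compiler) and passes a regular expression to `print_stats`, which keeps only rows whose `filename:lineno(function)` string matches.

```python
import cProfile
import io
import pprint
import pstats


def workload():
    # Hypothetical stand-in for a compiler pass that formats IR-like data eagerly.
    data = {f"block{i}": list(range(20)) for i in range(50)}
    for _ in range(100):
        pprint.pformat(data)


profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

buf = io.StringIO()
stats = pstats.Stats(profiler, stream=buf)
# Sort like cProfile.run(..., sort="tottime"), then keep only the rows
# whose "filename:lineno(function)" matches the regex.
stats.sort_stats("tottime").print_stats(r"pprint\.py")
report = buf.getvalue()
print(report)
```

Filtering by regex rather than by module object keeps the output comparable across Python versions, where the line numbers inside `pprint.py` differ.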
Looking through the codebase, I see that:
- `pprint.pformat` is mainly used in logging or debugging code.
- All of the uses in `numba/core/byteflow.py` go through an intermediate class, `_lazy_pformat`, so that `pformat` is only called when logging is enabled.

In more complex cases, like inside Bodo or when the IR is larger, the `pformat` overhead is even higher. Thus, we should use `_lazy_pformat` in most of the other cases:
- Whenever it's an argument to logging
- In one situation where we pass data to `ev.trigger_event`, which I believe is for compiler debugging purposes only
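The lazy-formatting idea can be sketched as below. This is a hypothetical reimplementation for illustration only: the class name `LazyPformat` and the `CountingRepr` helper are assumptions, not Numba's actual `_lazy_pformat`. The wrapper stores its arguments and only calls `pprint.pformat` from `__str__`, which `logging` invokes only when a record is actually emitted.

```python
import logging
from pprint import pformat


class LazyPformat:
    """Hypothetical sketch: defer pprint.pformat until the value is rendered."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        # Only runs when a handler formats the record ("%s" calls str()).
        return pformat(*self.args, **self.kwargs)


class CountingRepr:
    """Helper that counts how many times its repr is computed."""

    calls = 0

    def __repr__(self):
        CountingRepr.calls += 1
        return "CountingRepr()"


logger = logging.getLogger("lazy_pformat_demo")
logger.setLevel(logging.WARNING)  # DEBUG records are filtered out

state = {"block": [CountingRepr()]}

# Eager: pformat runs before logger.debug even checks the level.
logger.debug("state: %s", pformat(state))
eager_calls = CountingRepr.calls

# Lazy: logger.debug drops the call, so __str__ (and pformat) never runs.
logger.debug("state: %s", LazyPformat(state))
lazy_calls = CountingRepr.calls - eager_calls
```

With an eager `pformat(...)` argument, the formatting cost is paid before `logger.debug` checks the level; the wrapper moves that cost behind the check.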
bug.py
```python
import cProfile
import logging

import numba
import numpy as np

logger = logging.getLogger()
logger.disabled = True


def interpolate(npulses, nu, nv, k_ui, k_vi, ku, kv, phs, win1, win2):
    # Radially interpolate kx and ky data from polar raster
    # onto evenly spaced kx_i and ky_i grid for each pulse
    real_rad_interp = np.zeros((npulses, nu))
    imag_rad_interp = np.zeros((npulses, nu))
    ky_new = np.zeros((npulses, nu))
    for i in numba.prange(npulses):
        # print('range interpolating for pulse %i'%(i+1))
        real_rad_interp[i, :] = np.interp(
            # Numba code change: left/right arguments not supported
            k_ui,
            ku[i, :],
            phs.real[i, :] * win1,  # left=0, right=0
        )
        imag_rad_interp[i, :] = np.interp(
            # Numba code change: left/right arguments not supported
            k_ui,
            ku[i, :],
            phs.imag[i, :] * win1,  # left=0, right=0
        )
        ky_new[i, :] = np.interp(k_ui, ku[i, :], kv[i, :])

    # Interpolate in along track direction to obtain polar formatted data
    real_polar = np.zeros((nv, nu))
    imag_polar = np.zeros((nv, nu))
    isSort = ky_new[npulses // 2, nu // 2] < ky_new[npulses // 2 + 1, nu // 2]
    if isSort:
        for i in numba.prange(nu):
            # print('cross-range interpolating for sample %i'%(i+1))
            real_polar[:, i] = np.interp(
                # Numba code change: left/right arguments not supported
                k_vi,
                ky_new[:, i],
                real_rad_interp[:, i] * win2,  # left=0, right=0
            )
            imag_polar[:, i] = np.interp(
                # Numba code change: left/right arguments not supported
                k_vi,
                ky_new[:, i],
                imag_rad_interp[:, i] * win2,  # left=0, right=0
            )
    else:
        for i in numba.prange(nu):
            # print('cross-range interpolating for sample %i'%(i+1))
            real_polar[:, i] = np.interp(
                # Numba code change: left/right arguments not supported
                k_vi,
                ky_new[::-1, i],
                real_rad_interp[::-1, i] * win2,  # left=0, right=0
            )
            imag_polar[:, i] = np.interp(
                # Numba code change: left/right arguments not supported
                k_vi,
                ky_new[::-1, i],
                imag_rad_interp[::-1, i] * win2,  # left=0, right=0
            )

    real_polar = np.nan_to_num(real_polar)
    imag_polar = np.nan_to_num(imag_polar)
    phs_polar = np.nan_to_num(real_polar + 1j * imag_polar)
    return phs_polar


def main():
    # An explicit signature makes njit compile eagerly, so the profile
    # below captures compilation time only.
    dispatcher = numba.njit(
        (
            numba.int64, numba.int64, numba.int64,
            numba.float64[:], numba.float64[:],
            numba.float64[:, :], numba.float64[:, :],
            numba.complex64[:, :],
            numba.float64[:], numba.float64[:],
        ),
        parallel=True,
    )(interpolate)


cProfile.run("main()", sort="tottime")
```
I have a draft implementation that removes all `pprint.py` calls from our test script and reduces compilation time by ~5%. Will open a PR.
Thanks for writing this patch. Fixing this has been on the to-do list for a while. I've also observed pprint debugging/logging calls showing up relatively high in performance profiles.