Use multi-threading instead of multi-processing
The main functions in reproject, such as `reproject_interp`, accept `parallel=` and `block_size=` arguments which, if used, will leverage dask behind the scenes to split the data into chunks and then use a multi-processing scheduler to distribute the work.
Ideally we should be using multi-threading instead of multi-processing, but currently we don't because there appear to be some issues with some output pixels not having the right value when using multi-threading.
Fixing this will provide two main benefits:
- Avoiding the need to dump arrays to memmaps, so that input arrays can be used as-is without duplication
- Avoiding nasty surprises for users who select `return_type='dask'` and then use the default scheduler to compute the array (see the sketch below)
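For context, here is a minimal sketch (not taken from the issue; the file names, output shape and block size are placeholders) of how a user could end up on the threaded scheduler without ever asking for threads explicitly:

```python
# Hypothetical usage sketch: reproject with parallel/block_size/return_type,
# then compute the resulting dask array with dask's default scheduler.
from astropy.io import fits
from astropy.wcs import WCS
from reproject import reproject_interp

hdu = fits.open("input.fits")[0]                 # placeholder input image
target_wcs = WCS(fits.getheader("target.fits"))  # placeholder target WCS

array, footprint = reproject_interp(
    hdu,
    target_wcs,
    shape_out=(1024, 1024),
    parallel=True,
    block_size=(256, 256),
    return_type="dask",
)

# dask's default scheduler for arrays is threaded, so this .compute() runs the
# per-chunk WCS transformations concurrently in threads.
result = array.compute()
```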
Here's a notebook illustrating the issues: https://gist.github.com/astrofrog/e8808ee3ee8b7b86a979e0cb305d518b - note that while there is no explicit mention of threads anywhere, in the `compare` function `array2` is a dask array, and when it gets passed to Matplotlib, `.compute()` gets called and the default scheduler uses threads.
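To illustrate that mechanism (a sketch, not the notebook code itself): passing a dask array to Matplotlib converts it via `np.asarray`, which calls `.compute()` under the default threaded scheduler:

```python
# Minimal sketch of the implicit compute; the array here is a stand-in for
# the reprojected dask array returned with return_type='dask'.
import dask.array as da
import matplotlib.pyplot as plt

array2 = da.random.random((512, 512), chunks=(128, 128))

# imshow calls np.asarray(array2), which triggers array2.compute() with the
# default threaded scheduler - no threads are mentioned anywhere in user code.
plt.imshow(array2)
plt.show()
```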
At the moment all algorithms seem to have issues, though the symptoms appear to differ between them. Note that for adaptive I sometimes have to run a few times to see issues.
@svank - just out of curiosity, do you have any sense of what could be making the adaptive code not be thread-safe?
My first thought was about how the Cython function calls back into Python-land for the coordinates calculations, but I guess that's just a GIL matter rather than a thread-safety thing (and wouldn't clearly cause the glitch your notebook shows). I'll have to play with it a bit
I just played with this a little bit and found that if I pass `roundtrip_coords=False`, the multi-threading output glitches are much more rare (about one in ten runs, versus almost every run with `roundtrip_coords=True`). Could (at least one) problem be in `astropy.coordinates` (or `wcslib`?)?
Here's an expanded notebook: https://gist.github.com/svank/d63ef6bdf4e146577d7a78111ad85855
@svank yes I think I've come to the same conclusion that this could be in `astropy.wcs` or in `astropy.coordinates` (perhaps in ERFA if it is called), and the `roundtrip_coords=True` option (default) makes it so more WCS transformations are happening and increases the chance of issues. I switched to using a custom APE 14 WCS that does not use `astropy.wcs` and the issues disappear.
Interestingly https://www.atnf.csiro.au/people/mcalabre/WCS/wcslib/threads.html says that WCSLIB is basically thread-safe, but https://www.gnu.org/software/gnuastro/manual/html_node/World-Coordinate-System.html#:~:text=The%20wcsprm%20structure%20of%20WCSLIB,the%20same%20wcsprm%20structure%20pointer. says:

> The wcsprm structure of WCSLIB is not thread-safe: you can't use the same pointer on multiple threads. For example, if you use gal_wcs_img_to_world simultaneously on multiple threads, you shouldn't pass the same wcsprm structure pointer. You can use gal_wcs_copy to keep and use separate copies of the main structure within each thread, and later free the copies with gal_wcs_free.
and I think we do use wcsprm so maybe it is related to that.
cc @Manodeep
It is indeed a WCS issue - the third commit in https://github.com/astropy/reproject/pull/434 fixes the multi-threaded results (but presumably has a performance penalty)
Yeah - just reading through the references that you posted, it definitely seemed likely that `wcs` would be the culprit.
This commit to casacore adds mutexes for thread-unsafe WCS - however, those are targeting WCSLIB < 5.18, whereas the WCSLIB bundled with astropy seems to be 8.2.2.
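For reference, a minimal Python-level sketch of the same mutex idea (a hypothetical helper, not what the casacore commit or PR #434 actually does): serialise all calls that touch the shared WCS behind a lock, trading parallelism in the coordinate transforms for correctness.

```python
# Hypothetical workaround sketch: wrap the WCS transformation in a lock so
# only one thread uses the shared (thread-unsafe) wcsprm at a time.
import threading
from astropy.wcs.utils import pixel_to_pixel

wcs_lock = threading.Lock()

def locked_pixel_to_pixel(wcs_from, wcs_to, *coords):
    # Serialises access to the underlying wcsprm structures.
    with wcs_lock:
        return pixel_to_pixel(wcs_from, wcs_to, *coords)
```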
Small example to reproduce the bug with `astropy.wcs`:

```python
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
hdu2 = fits.open(get_pkg_data_filename("galactic_center/gc_msx_e.fits"))[0]

wcs1 = WCS(hdu1.header)
wcs2 = WCS(hdu2.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))
yp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xp, yp = pixel_to_pixel(wcs1, wcs2, xp, yp)
        xp, yp = pixel_to_pixel(wcs2, wcs1, xp, yp)
    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
```
The above outputs:
```
Mismatching elements: 62 62
Mismatching elements: 0 0
Mismatching elements: 787 787
Mismatching elements: 967 967
Mismatching elements: 23 23
Mismatching elements: 175 175
Mismatching elements: 15 15
Mismatching elements: 26 26
```
And now even simpler:
```python
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]

wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 0)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 0)
    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
```
It seems accessing `.lng` on the `Wcsprm` between the two conversions is what causes the issues. Accessing `.lat` causes the same issue, but accessing e.g. `.equinox` does not.
And even simpler, this time using the conversion functions on Wcsprm directly:
```python
import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel, 0)["world"]
    wcs.wcs.lat
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel


for n_threads in [1, 2, 8]:
    print("N_threads:", n_threads)
    pool = ThreadPool(n_threads)
    results = pool.map(repeated_transforms, (pixel,) * n_threads)
    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")
```
gives:
```
N_threads: 1
Mismatching: 0
N_threads: 2
Mismatching: 2000000
Mismatching: 2000000
N_threads: 8
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
```
Interestingly, in this case, because the inner function consists almost exclusively of the C calls, it seems all the data gets corrupted.
> And now even simpler: [...] it seems accessing `.lng` on the `Wcsprm` between the two conversions is what causes the issues. Accessing `.lat` causes the same issue, but accessing e.g. `.equinox` does not.
Just so I understand - adding that access to `wcs.lng`/`wcs.lat`, and access only (i.e., no writes), causes the code snippet to fail?
> And even simpler, this time using the conversion functions on Wcsprm directly: [...] Interestingly in this case, because the inner function is almost exclusively the C functions, it seems all the data gets corrupted.
At the very least, this convincingly demonstrates (to me) that wcs is not threadsafe.
Does the error go away if you make a copy of wcs within the function and use that copy?
@manodeep - hmm, in this example: https://github.com/astropy/reproject/issues/394#issuecomment-2020132623 it doesn't seem to actually matter, I can remove the access to `wcs.lng` or `wcs.lat`. In this example: https://github.com/astropy/reproject/issues/394#issuecomment-2020091302 it does seem to matter, and removing it fixes my issues.
Hmm, well now I'm puzzled - the following example also shows the issue, even if I make a whole new WCS object inside each thread:
```python
import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel


for n_threads in [0, 1, 2, 8]:
    print("N_threads:", n_threads)
    if n_threads == 0:
        results = [repeated_transforms(pixel)]
    else:
        pool = ThreadPool(n_threads)
        results = pool.map(repeated_transforms, (pixel,) * n_threads)
    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")
```
gives:
```
N_threads: 0
Mismatching: 0
N_threads: 1
Mismatching: 0
N_threads: 2
Mismatching: 2000000
Mismatching: 2000000
N_threads: 8
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
```
I wonder if I'm doing something wrong with the calls to `p2s` and `s2p` here, as it's a bit suspicious that suddenly all values are different compared to earlier examples, but maybe it's also because a higher fraction of time is spent in C code. It's also weird that the issue persists here even when creating a new WCS object inside each thread.
Having said that, maybe it is indeed highlighting the issue - if I switch to a process-based Pool, the issue goes away:
```python
import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import Pool


def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel


def main():
    N = 1_000_000

    pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)

    for n_proc in [0, 1, 2, 8]:
        print("N_processes:", n_proc)
        if n_proc == 0:
            results = [repeated_transforms(pixel)]
        else:
            pool = Pool(n_proc)
            results = pool.map(repeated_transforms, (pixel,) * n_proc)
        for pixel2 in results:
            print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")


if __name__ == "__main__":
    main()
```
```
N_processes: 0
Mismatching: 0
N_processes: 1
Mismatching: 0
N_processes: 2
Mismatching: 0
Mismatching: 0
N_processes: 8
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
```
Ok well on that basis maybe https://github.com/astropy/reproject/issues/394#issuecomment-2020162013 is the best example to go with to reproduce the issue? (so the access to `.lng`/`.lat` might have been a red herring). If that example is actually really showing a WCSLIB issue, I wonder if somehow there is a global variable being accessed/changed?
I am so confused by this sample - how can a race condition occur if you are creating a new `wcs` within the function?! It seems unlikely that `pixel` is being modified within that function, and those two (`wcs`, `pixel`) are the only two variables available.
Is there a typo in this sample - the for loop variable is called `n_proc` but `n_threads` is used within the for loop?
FWIW the issue is definitely in WCSLIB - if I edit the `wcslib_wrap.c` code in astropy and move the `Py_BEGIN_ALLOW_THREADS` and `Py_END_ALLOW_THREADS` to be just around `wcss2p` and `wcsp2s`, the failure still appears.
> Is there a typo in https://github.com/astropy/reproject/issues/394#issuecomment-2020166130 - the for loop variable is called `n_proc` but `n_threads` is used within the for loop.
Yes, sorry - I tried editing the code in the comment directly to change thread -> proc but clearly failed (fixed now).
> I am so confused by https://github.com/astropy/reproject/issues/394#issuecomment-2020162013 - how can a race condition occur if you are creating a new wcs within the function?! It seems unlikely that pixel is being modified within that function and those two (wcs, pixel) are the only two variables available.
I am very confused by this too, and this would suggest perhaps that there is some kind of global variable in WCSLIB that is being accessed and modified by different threads?
For fun, I tried checking what the actual offset between expected and actual pixel positions is, by doing:

```python
print(np.unique(pixel - pixel2))
```

which gives:
```
N_threads: 0
[0.]
N_threads: 1
[0.]
N_threads: 2
[-1.]
[-1.]
N_threads: 8
[-4. -3.]
[-5. -4. -3.]
[-7. -6. -5.]
[-7. -6. -5.]
[-7. -6. -5. -4.]
[-7. -6. -5. -4.]
[-6. -5.]
[-7. -6. -5. -4.]
```
So the results seem to be off by small integer values, which appear to be related to the number of threads running concurrently.
Ah, changing the second argument of `s2p` and `p2s` to 1, meaning the default FITS origin, seems to resolve the issue, so somehow it must be related to the subtraction/addition of 1 because of the non-default origin.
Hmm, I wonder if it's something dumb like `preoffset_array` in astropy's wrapper modifying the input array in-place and then modifying it back afterwards?? Calling that with origin=1 is a no-op, so that could explain it. It also seems like a very dangerous thing to do 😅
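To see why that would explain the small integer offsets, here is a toy reproduction (a sketch using plain numpy and threads, with no WCS involved) of an in-place pre-offset/undo pattern applied to a shared input array:

```python
# Toy sketch: an in-place "subtract offset, do work, add offset back" pattern
# applied to an array shared between threads. Copies taken while other threads
# are mid-offset come out shifted by small integers, similar to the glitches
# above.
import numpy as np
from multiprocessing.pool import ThreadPool

data = np.zeros(1_000_000)


def transform(arr):
    arr -= 1          # pre-offset in place (stand-in for the origin=0 fix-up)
    out = arr.copy()  # stand-in for the actual WCSLIB call
    arr += 1          # undo the offset in place
    return out


pool = ThreadPool(8)
results = pool.map(transform, (data,) * 8)

for out in results:
    # With one thread this would print only [-1.]; with several threads other
    # small negative integers typically show up as well.
    print(np.unique(out))
```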
Ok so I think there must be two separate issues, because now if I go back to one of the earlier examples and change the origin argument in that to be 1, there is still a problem:
```python
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]

wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 1)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 1)
    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )
```
and in this case the offsets between expected and actual positions are not integers. I think the off-by-one issue above might not be the one we are running into in reproject, because it doesn't seem to be triggered when calling `wcs.all_pix2world` and `wcs.all_world2pix`, just `s2p` and `p2s`.
Ok, so to summarize:

1. I think the issue highlighted by this example is that when calling `s2p` and `p2s` directly, the input pixel array is modified in-place in the astropy.wcs wrapper around WCSLIB to apply the offset for the different origin. This causes all the results to be wrong when using multiple threads and explains why even making new WCSes inside each thread didn't help. This is not actually the issue we are seeing in reproject though, and it does not appear to be triggered when calling `all_pix2world` and `all_world2pix`, so I'm not sure if it's worth fixing/investigating further as `s2p` and `p2s` are not really supposed to be public API.
2. This example is the one showing the issue we are actually encountering, and in this case changing the origin argument to 1 does not change anything. However, it is still true that accessing `wcs.lng` or `wcs.lat` does trigger the issue (commenting out this line fixes everything). The offset between input and output pixels is random here - not integer values, and only a small fraction of pixels are affected. So this is the issue that I think we should spend time on.
For the second case, I've now managed to make an example that replicates the issue while calling just `s2p` and `p2s`, and it still fails only when the access to `wcs.wcs.lng` is present:
```python
import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)
wcs.wcs.crpix = [-234.75, 8.3393]
wcs.wcs.cdelt = np.array([-0.066667, 0.066667])
wcs.wcs.crval = [0, -90]
wcs.wcs.ctype = ["RA---AIR", "DEC--AIR"]
wcs.wcs.set()

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel.copy(), 1)["world"]
    wcs.wcs.lng  # this access causes issues, without it all works
    pixel = wcs.wcs.s2p(world, 1)["pixcrd"]
    return pixel


pool = ThreadPool(8)
results = pool.map(repeated_transforms, (pixel,) * 8)

for pixel2 in results:
    print(f"Mismatching elements: {np.sum(~np.isclose(pixel, pixel2))}")
```
gives:
```
Mismatching elements: 64
Mismatching elements: 64
Mismatching elements: 76
Mismatching elements: 104
Mismatching elements: 52
Mismatching elements: 62
Mismatching elements: 66
Mismatching elements: 66
```
so I think this is the example we should use going forward. The key is that it makes a copy of the array before calling `p2s`, which is equivalent to what is happening in our use case and prevents issues with multiple threads modifying the input array.
There might be multiple issues at play here: i) a thread race condition within WCSLIB (which somehow manifests when there is an access to a memory location!), and ii) applying the FITS-origin fix to the input array in-place and then undoing it afterwards.
I still don't understand why `wcs_[in/out].deepcopy()` resolved the failure within reproject but does not do so within these code samples.

Plus, why is there such a big discrepancy in the number of wrong pixel values between the (mostly)-C-functions case (all values are wrong) and the more-Python-functions test cases (~100 values are wrong)?

I will also check that both the multi-process and multi-thread versions are actually splitting up the work. For example, if the multi-process version is spawning (don't know why) 8 copies of the same task in serial, then it stands to reason that the 8 (identical) serial tasks would produce correct results.
@manodeep - yes, as mentioned in this summary I think there are two separate issues, and I think we should focus on the second one I mention (2.), which I think is what you call i).
Just to be clear, `.deepcopy()` does fix the issue in the example code given in my summary above (issue 2). The red herring was that in this example what matters is not whether the WCS is a copy or not - even initializing a new WCS inside each thread results in an issue, because at that point the problem is that the input array is being modified in-place, so it's the sharing of the input array that is the problem (issue 1).
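As a sketch of what that workaround looks like (an illustration of the `.deepcopy()` fix applied to the issue-2 reproducer above, not a reproject patch):

```python
# Hypothetical variant of the issue-2 reproducer: each worker uses a private
# deep copy of the WCS, so the shared wcsprm is never used concurrently.
import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)
wcs.wcs.crpix = [-234.75, 8.3393]
wcs.wcs.cdelt = np.array([-0.066667, 0.066667])
wcs.wcs.crval = [0, -90]
wcs.wcs.ctype = ["RA---AIR", "DEC--AIR"]
wcs.wcs.set()

N = 1_000_000
pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    local_wcs = wcs.deepcopy()  # per-call private copy of the shared WCS
    world = local_wcs.wcs.p2s(pixel.copy(), 1)["world"]
    local_wcs.wcs.lng  # the access that triggers the glitch with a shared WCS
    pixel = local_wcs.wcs.s2p(world, 1)["pixcrd"]
    return pixel


pool = ThreadPool(8)
results = pool.map(repeated_transforms, (pixel,) * 8)

for pixel2 in results:
    # With the per-thread copy, the mismatch counts should all be 0.
    print(f"Mismatching elements: {np.sum(~np.isclose(pixel, pixel2))}")
```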
@manodeep - for issue 1, which is where all values are wrong, this is because the whole array is being offset multiple times in-place by the different threads. For issue 2, only some of the values are wrong, even if one uses mostly the C functions, as shown in the example here.