
Use multi-threading instead of multi-processing

Open astrofrog opened this issue 1 year ago • 40 comments

The main functions in reproject, such as reproject_interp, accept parallel= and block_size= arguments which, if used, will leverage dask behind the scenes to split up the data into chunks and then use a multi-processing scheduler to distribute the work.

Ideally we should be using multi-threading instead of multi-processing, but currently we don't because there appear to be some issues with some output pixels not having the right value when using multi-threading.

Fixing this will provide two main benefits:

  • Avoiding the whole business of dumping arrays to memmaps, since the input arrays could then be used as-is without duplication
  • Avoiding nasty surprises for users who select return_type='dask' and then compute the array with the default (threaded) scheduler

Here's a notebook illustrating the issues: https://gist.github.com/astrofrog/e8808ee3ee8b7b86a979e0cb305d518b - note that while there is no explicit mention of threads anywhere, in the compare function array2 is a dask array and when it gets passed to Matplotlib, .compute() gets called and the default scheduler uses threads.
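For context, here is a minimal sketch of the kind of call being discussed (not taken from the notebook; the filenames and block size are placeholders):

from astropy.io import fits
from reproject import reproject_interp

hdu_in = fits.open("input.fits")[0]        # image to reproject (placeholder file)
hdu_ref = fits.open("reference.fits")[0]   # defines the target header/WCS (placeholder file)

# parallel= and block_size= make reproject chunk the work via dask; at the
# moment the chunks are processed with a multi-processing scheduler.
array, footprint = reproject_interp(
    hdu_in, hdu_ref.header, block_size=(512, 512), parallel=True
)

# return_type='dask' (the second bullet above) instead hands back a lazy dask
# array, which the default scheduler will then compute using threads.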

At the moment all algorithms seem to have issues, though the issues appear to be different for each. Note that for adaptive I sometimes have to run it a few times to see issues.

astrofrog avatar Sep 14 '23 10:09 astrofrog

@svank - just out of curiosity, do you have any sense of what could be making the adaptive code not be thread-safe?

astrofrog avatar Sep 14 '23 10:09 astrofrog

My first thought was about how the Cython function calls back into Python-land for the coordinate calculations, but I guess that's just a GIL matter rather than a thread-safety thing (and wouldn't obviously cause the glitch your notebook shows). I'll have to play with it a bit.

svank avatar Sep 14 '23 15:09 svank

I just played with this a little bit and found that if I pass roundtrip_coords=False, the multi-threading output glitches are much more rare (about one in ten runs, versus almost every run with roundtrip_coords=True). Could (at least one) problem be in astropy.coordinates (or wcslib?)?
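For reference, roundtrip_coords here is (as far as I can tell from the thread) the reproject_adaptive option; a rough sketch of the comparison, reusing the sample files that appear in the reproducers later in this thread:

from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from reproject import reproject_adaptive

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
hdu2 = fits.open(get_pkg_data_filename("galactic_center/gc_msx_e.fits"))[0]

# roundtrip_coords=True (the default) performs extra WCS transformations per
# pixel; with it, the multi-threaded glitches show up on almost every run.
array_rt, _ = reproject_adaptive(hdu2, hdu1.header, roundtrip_coords=True)
array_fast, _ = reproject_adaptive(hdu2, hdu1.header, roundtrip_coords=False)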

Here's an expanded notebook: https://gist.github.com/svank/d63ef6bdf4e146577d7a78111ad85855

svank avatar Sep 26 '23 16:09 svank

@svank yes I think I've come to the same conclusion that this could be in astropy.wcs or in astropy.coordinates (perhaps in ERFA if it is called), and the roundtrip_coords=True option (the default) means more WCS transformations happen, which increases the chance of issues. I switched to using a custom APE 14 WCS that does not use astropy.wcs and the issues disappear.

Interestingly https://www.atnf.csiro.au/people/mcalabre/WCS/wcslib/threads.html says that WCSLIB is basically thread-safe but https://www.gnu.org/software/gnuastro/manual/html_node/World-Coordinate-System.html#:~:text=The%20wcsprm%20structure%20of%20WCSLIB,the%20same%20wcsprm%20structure%20pointer. says:

The wcsprm structure of WCSLIB is not thread-safe: you can’t use the same pointer on multiple threads. For example, if you use gal_wcs_img_to_world simultaneously on multiple threads, you shouldn’t pass the same wcsprm structure pointer. You can use gal_wcs_copy to keep and use separate copies the main structure within each thread, and later free the copies with gal_wcs_free.

and I think we do use wcsprm so maybe it is related to that.
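Translated to astropy terms, the per-thread-copy workaround the gnuastro docs describe would look roughly like this (a sketch only; the helper names are made up):

import copy

def make_worker(shared_wcs):
    """Build a worker that uses its own copy of the WCS (and hence its own wcsprm)."""
    def worker(xp, yp):
        local_wcs = copy.deepcopy(shared_wcs)  # analogue of gal_wcs_copy: one wcsprm per thread
        xw, yw = local_wcs.all_pix2world(xp, yp, 0)
        return local_wcs.all_world2pix(xw, yw, 0)
    return worker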

cc @Manodeep

astrofrog avatar Mar 25 '24 23:03 astrofrog

It is indeed a WCS issue - the third commit in https://github.com/astropy/reproject/pull/434 fixes the multi-threaded results (but presumably has a performance penalty)

astrofrog avatar Mar 26 '24 00:03 astrofrog

Yeah - just reading through the references that you posted, it definitely seemed likely that wcs would be the culprit

manodeep avatar Mar 26 '24 00:03 manodeep

This commit to casacore adds mutexes for thread-unsafe WCS calls - however, those target WCSLIB < 5.18, whereas the WCSLIB bundled with astropy seems to be 8.2.2.
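A Python analogue of that mutex approach would be to serialize all use of the shared WCS behind a lock (sketch only; this restores correctness at the cost of parallelism in the WCS calls themselves):

import threading

wcs_lock = threading.Lock()

def locked_pix2world(wcs, xp, yp):
    # Only one thread at a time is allowed to drive the shared wcsprm.
    with wcs_lock:
        return wcs.all_pix2world(xp, yp, 0)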

manodeep avatar Mar 26 '24 00:03 manodeep

Small example to reproduce the bug with astropy.wcs:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool

hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]
hdu2 = fits.open(get_pkg_data_filename("galactic_center/gc_msx_e.fits"))[0]

wcs1 = WCS(hdu1.header)
wcs2 = WCS(hdu2.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))
yp = np.random.randint(1, 100, N).astype(float).reshape((1000, 1000))


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xp, yp = pixel_to_pixel(wcs1, wcs2, xp, yp)
        xp, yp = pixel_to_pixel(wcs2, wcs1, xp, yp)

    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )

astrofrog avatar Mar 26 '24 10:03 astrofrog

The above outputs:

Mismatching elements: 62 62
Mismatching elements: 0 0
Mismatching elements: 787 787
Mismatching elements: 967 967
Mismatching elements: 23 23
Mismatching elements: 175 175
Mismatching elements: 15 15
Mismatching elements: 26 26

astrofrog avatar Mar 26 '24 10:03 astrofrog

And now even simpler:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool


hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]

wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 0)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 0)

    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )

It seems accessing .lng on the Wcsprm between the two conversions is what causes the issues. Accessing .lat causes the same issue, but accessing e.g. .equinox does not.

astrofrog avatar Mar 26 '24 10:03 astrofrog

And even simpler, this time using the conversion functions on Wcsprm directly:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel, 0)["world"]
    wcs.wcs.lat
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel


for n_threads in [1, 2, 8]:
    print("N_threads:", n_threads)

    pool = ThreadPool(n_threads)
    results = pool.map(repeated_transforms, (pixel,) * n_threads)

    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

gives:

N_threads: 1
Mismatching: 0
N_threads: 2
Mismatching: 2000000
Mismatching: 2000000
N_threads: 8
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000

Interestingly, in this case, because the inner function is almost exclusively the C functions, it seems all the data gets corrupted.

astrofrog avatar Mar 26 '24 11:03 astrofrog

(Quoting astrofrog's "And now even simpler" example above, and its conclusion that accessing .lng on the Wcsprm between the two conversions is what causes the issues.)

Just so I understand - adding that access to wcs.lng/wcs.lat (and only an access, i.e., no writes) causes the code snippet to fail?

manodeep avatar Mar 26 '24 11:03 manodeep

(Quoting astrofrog's earlier example that calls p2s and s2p on the Wcsprm directly, where every value comes back wrong with more than one thread.)

At the very least, this convincingly demonstrates (to me) that wcs is not thread-safe.

Does the error go away if you make a copy of wcs within the function and use that copy?
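For reference, that check applied to the Wcsprm example above might look like this (sketch only; wcs and pixel are reused from that snippet, and astropy's WCS.deepcopy() gives each thread its own copy):

def repeated_transforms_copied(pixel):
    local_wcs = wcs.deepcopy()  # per-thread copy of the shared WCS/wcsprm
    world = local_wcs.wcs.p2s(pixel, 0)["world"]
    local_wcs.wcs.lat
    pixel = local_wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel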

manodeep avatar Mar 26 '24 11:03 manodeep

@manodeep - hmm in this example:

https://github.com/astropy/reproject/issues/394#issuecomment-2020132623

it doesn't seem to actually matter, I can remove the access to wcs.lng or wcs.lat. In this example:

https://github.com/astropy/reproject/issues/394#issuecomment-2020091302

It does seem to matter, and removing it fixes my issues.

astrofrog avatar Mar 26 '24 11:03 astrofrog

Hmm, well, now I'm puzzled: the following example also shows the issue, even if I make a whole new WCS object inside each thread:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel


for n_threads in [0, 1, 2, 8]:
    print("N_threads:", n_threads)

    if n_threads == 0:
        results = [repeated_transforms(pixel)]
    else:
        pool = ThreadPool(n_threads)
        results = pool.map(repeated_transforms, (pixel,) * n_threads)

    for pixel2 in results:
        print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

gives:

N_threads: 0
Mismatching: 0
N_threads: 1
Mismatching: 0
N_threads: 2
Mismatching: 2000000
Mismatching: 2000000
N_threads: 8
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000
Mismatching: 2000000

I wonder if I'm doing something wrong with the calls to p2s and s2p here, as it's a bit suspicious that suddenly all values are different compared to the earlier examples; then again, maybe it's because a higher fraction of the time is spent in C code. It's also weird that the issue persists here even when creating a new WCS object inside each thread.

astrofrog avatar Mar 26 '24 11:03 astrofrog

Having said that, maybe it is indeed highlighting the issue: if I switch to a process-based Pool, the problem goes away:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import Pool


def repeated_transforms(pixel):
    wcs = WCS(naxis=2)
    world = wcs.wcs.p2s(pixel, 0)["world"]
    pixel = wcs.wcs.s2p(world, 0)["pixcrd"]
    return pixel



def main():

    N = 1_000_000

    pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)

    for n_proc in [0, 1, 2, 8]:
        print("N_processes:", n_proc)

        if n_proc == 0:
            results = [repeated_transforms(pixel)]
        else:
            pool = Pool(n_proc)
            results = pool.map(repeated_transforms, (pixel,) * n_proc)

        for pixel2 in results:
            print(f"Mismatching: {np.sum(~np.isclose(pixel, pixel2))}")

if __name__ == "__main__":
    main()

gives:

N_processes: 0
Mismatching: 0
N_processes: 1
Mismatching: 0
N_processes: 2
Mismatching: 0
Mismatching: 0
N_processes: 8
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0
Mismatching: 0

astrofrog avatar Mar 26 '24 11:03 astrofrog

Ok, well, on that basis maybe https://github.com/astropy/reproject/issues/394#issuecomment-2020162013 is the best example to go with to reproduce the issue (so the access to .lng/.lat might have been a red herring). If that example really is showing a WCSLIB issue, I wonder if somehow a global variable is being accessed/changed?

astrofrog avatar Mar 26 '24 11:03 astrofrog

I am so confused by this sample - how can a race condition occur if you are creating a new wcs within the function?! It seems unlikely that pixel is being modified within that function and those two (wcs, pixel) are the only two variables available.

Is there a typo in this sample - the for loop variable is called n_proc but n_threads is used within the for loop.

manodeep avatar Mar 26 '24 11:03 manodeep

FWIW the issue is definitely in WCSLIB - if I edit the wcslib_wrap.c code in astropy and move the Py_BEGIN_ALLOW_THREADS and Py_END_ALLOW_THREADS to be just around wcss2p and wcsp2s, the failure appears.

astrofrog avatar Mar 26 '24 11:03 astrofrog

Is there a typo in https://github.com/astropy/reproject/issues/394#issuecomment-2020166130 - the for loop variable is called n_proc but n_threads is used within the for loop.

Yes, sorry - I tried editing the code in the comment directly to change thread -> proc but clearly failed (fixed now).

astrofrog avatar Mar 26 '24 11:03 astrofrog

I am so confused by https://github.com/astropy/reproject/issues/394#issuecomment-2020162013 - how can a race condition occur if you are creating a new wcs within the function?! It seems unlikely that pixel is being modified within that function and those two (wcs, pixel) are the only two variables available.

I am very confused by this too; it would perhaps suggest that there is some kind of global variable in WCSLIB that is being accessed and modified by different threads?

astrofrog avatar Mar 26 '24 11:03 astrofrog

For fun, I tried checking what the actual offset between expected and actual pixel positions is, by doing:

        print(np.unique(pixel - pixel2))

which gives:

N_threads: 0
[0.]
N_threads: 1
[0.]
N_threads: 2
[-1.]
[-1.]
N_threads: 8
[-4. -3.]
[-5. -4. -3.]
[-7. -6. -5.]
[-7. -6. -5.]
[-7. -6. -5. -4.]
[-7. -6. -5. -4.]
[-6. -5.]
[-7. -6. -5. -4.]

So the results seem to be off by small integer values, which appear to be related to the number of threads running concurrently.

astrofrog avatar Mar 26 '24 12:03 astrofrog

Ah, changing the second argument of s2p and p2s to 1, meaning the default FITS origin, seems to resolve the issue, so somehow it must be related to the subtraction/addition of 1 for the non-default origin.

astrofrog avatar Mar 26 '24 12:03 astrofrog

Hmm, I wonder if it's something dumb like preoffset_array in astropy's wrapper modifying the input array in place and then modifying it back afterwards?? Calling that with origin=1 is a no-op, so that could explain it. It also seems like a very dangerous thing to do 😅
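As a toy illustration (not astropy code) of why an in-place offset on a shared input array breaks under threads: each worker shifts the shared array, computes, then shifts it back, so concurrent workers see each other's temporary shifts.

import numpy as np
from multiprocessing.pool import ThreadPool

shared = np.zeros(1_000_000)

def worker(arr):
    arr -= 1           # stand-in for the in-place "preoffset"
    out = arr.copy()   # stand-in for the actual WCS transformation
    arr += 1           # offset undone afterwards
    return out

with ThreadPool(8) as pool:
    results = pool.map(worker, (shared,) * 8)

for out in results:
    print(np.unique(out))  # small negative integers, much like the offsets seen above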

astrofrog avatar Mar 26 '24 12:03 astrofrog

Ok so I think there must be two separate issues, because now if I go back to one of the earlier examples and change the origin argument in that to be 1, there is still a problem:

import numpy as np
from astropy.wcs import WCS
from astropy.wcs.utils import pixel_to_pixel
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from multiprocessing.pool import ThreadPool


hdu1 = fits.open(get_pkg_data_filename("galactic_center/gc_2mass_k.fits"))[0]

wcs = WCS(hdu1.header)

N = 1_000_000
N_iter = 1

xp = np.random.randint(-1000, 1000, N).astype(float)
yp = np.random.randint(-1000, 1000, N).astype(float)


def repeated_transforms(xp, yp):
    for i in range(N_iter):
        xw, yw = wcs.all_pix2world(xp, yp, 1)
        wcs.wcs.lng  # this access causes issues, without it all works
        xp, yp = wcs.all_world2pix(xw, yw, 1)

    return xp, yp


pool = ThreadPool(8)
results = pool.starmap(repeated_transforms, ((xp, yp),) * 8)

for xp2, yp2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(xp, xp2))} {np.sum(~np.isclose(yp, yp2))}"
    )

and in this case the offsets between expected and actual positions are not integers. I think the off-by-one issue above might not be the one we are running into in reproject, because it doesn't seem to be triggered when calling wcs.all_pix2world and wcs.all_world2pix, only when calling s2p and p2s directly.

astrofrog avatar Mar 26 '24 12:03 astrofrog

Ok, so to summarize:

  1. I think the issue highlighted by this example is that when calling s2p and p2s directly, the input pixel array is modified in-place in the astropy.wcs wrapper around WCSLIB to apply the offset for the different origin. This causes all the results to be wrong when using multiple threads and explains why even making new WCSes inside each thread didn't help. This is not actually the issue we are seeing in reproject though, and this issue does not appear to be triggered when calling all_pix2world and all_world2pix so I'm not sure if it's worth fixing/investigating further as s2p and p2s are not really supposed to be public API.
  2. This example is the one showing the issue we are actually encountering, and in this case changing the origin argument to 1 does not change anything. However, it is still true that accessing wcs.lng or wcs.lat does trigger the issue (commenting out this line fixes everything). The offset between input and output pixels is random here - not integer values, and only a small fraction of pixels are affected. So this is the issue that I think we should spend time on.

For the second case, I've now managed to make an example that replicates the issue using just s2p and p2s, and which still fails only when the access to wcs.wcs.lng is present:

import numpy as np
from astropy.wcs import WCS
from multiprocessing.pool import ThreadPool

wcs = WCS(naxis=2)
wcs.wcs.crpix = [-234.75, 8.3393]
wcs.wcs.cdelt = np.array([-0.066667, 0.066667])
wcs.wcs.crval = [0, -90]
wcs.wcs.ctype = ["RA---AIR", "DEC--AIR"]
wcs.wcs.set()

N = 1_000_000

pixel = np.random.randint(-1000, 1000, N * 2).reshape((N, 2)).astype(float)


def repeated_transforms(pixel):
    world = wcs.wcs.p2s(pixel.copy(), 1)['world']
    wcs.wcs.lng  # this access causes issues, without it all works
    pixel = wcs.wcs.s2p(world, 1)['pixcrd']
    return pixel


pool = ThreadPool(8)
results = pool.map(repeated_transforms, (pixel,) * 8)

for pixel2 in results:
    print(
        f"Mismatching elements: {np.sum(~np.isclose(pixel, pixel2))}"
    )

gives:

Mismatching elements: 64
Mismatching elements: 64
Mismatching elements: 76
Mismatching elements: 104
Mismatching elements: 52
Mismatching elements: 62
Mismatching elements: 66
Mismatching elements: 66

so I think this is the example we should use going forward. The key is that it copies the array before calling p2s, which is equivalent to what happens in our use case and prevents issues with multiple threads modifying the shared input array.

astrofrog avatar Mar 26 '24 12:03 astrofrog

There might be multiple issues at play here: i) a thread race condition within WCSLIB (which somehow manifests when there is an access to a memory location!), and ii) applying the FITS-origin fix to the input array in place and then undoing it afterwards.

I still don't understand why wcs_[in/out].deepcopy() resolved the failure within reproject but does not do so within these code samples.

manodeep avatar Mar 26 '24 23:03 manodeep

Plus, why is there such a big discrepancy in the number of wrong pixel values between the (mostly) C-functions test case (all values are wrong) and the more-Python-functions test cases (~100 values are wrong)?

I will also check that both the multi-process and multi-thread versions are actually splitting up the work. For example, if the multi-process version is spawning (for whatever reason) 8 copies of the same task in serial, then it stands to reason that the 8 (identical) serial tasks would produce correct results.

manodeep avatar Mar 26 '24 23:03 manodeep

@manodeep - yes, as mentioned in this summary, I think there are two separate issues, and we should focus on the second one I mention (2.), which I think is what you call i).

Just to be clear, .deepcopy() does fix the issue in the example code given in my summary above (issue 2). The red herring was that in this example what matters is not whether the WCS is a copy or not: even initializing a new WCS inside each thread results in an issue, because at that point the input array is being modified in-place, so it's the sharing of the input array that is the problem (issue 1).
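For concreteness, the fix being described, applied to the summary example for issue 2, would be something along these lines (sketch only; wcs and pixel are reused from that example, and the input array is already copied before p2s there, so issue 1 does not come into play):

def repeated_transforms_fixed(pixel):
    local_wcs = wcs.deepcopy()  # each thread works on its own WCS copy
    world = local_wcs.wcs.p2s(pixel.copy(), 1)["world"]
    local_wcs.wcs.lng  # harmless once the WCS is no longer shared
    pixel = local_wcs.wcs.s2p(world, 1)["pixcrd"]
    return pixel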

astrofrog avatar Mar 26 '24 23:03 astrofrog

@manodeep - for issue 1, where all values are wrong, this is because the whole array is being offset multiple times in-place by the different threads. For issue 2, only some of the values are wrong, even if one mostly uses the C functions, as shown in the example here.

astrofrog avatar Mar 26 '24 23:03 astrofrog