torchinterp1d
Problems reproducing values given by np.interp()
Hello, I'm trying to rewrite some histogram matching code in PyTorch, which relies on a few 1D interpolations.
I've noticed that while most of the values in my torchinterp1d result match what I expect, a few of them are off by more than an order of magnitude.
Here's some code to reproduce the issue:
import numpy as np
import torch
from torchinterp1d import Interp1d
interp1d = Interp1d()
# histogram matching with numpy
random_state = np.random.RandomState(12345) # not all seeds have this issue, but this is one that does
bins = 64
target = random_state.normal(size=(128 * 128)) * 2 # some random data between about -8 and 8
source = random_state.normal(size=(128 * 128)) * 2
matched = np.empty_like(target)
lo = min(target.min(), source.min())
hi = max(target.max(), source.max())
target_hist_np, bin_edges_np = np.histogram(target, bins=bins, range=[lo, hi])
source_hist_np, _ = np.histogram(source, bins=bins, range=[lo, hi])
target_cdf_np = target_hist_np.cumsum()
target_cdf_np = target_cdf_np / target_cdf_np[-1]
source_cdf_np = source_hist_np.cumsum()
source_cdf_np = source_cdf_np / source_cdf_np[-1]
remapped_cdf_np = np.interp(target_cdf_np, source_cdf_np, bin_edges_np[1:])
matched_np = np.interp(target, bin_edges_np[1:], remapped_cdf_np, left=0, right=bins)
# now with pytorch
target = torch.from_numpy(target)
source = torch.from_numpy(source)
target_hist = torch.histc(target, bins, lo, hi)
source_hist = torch.histc(source, bins, lo, hi)
assert np.allclose(target_hist_np, target_hist.numpy())
assert np.allclose(source_hist_np, source_hist.numpy())
target_cdf = target_hist.cumsum(0)
target_cdf = target_cdf / target_cdf[-1]
assert np.allclose(target_cdf_np, target_cdf.numpy())
source_cdf = source_hist.cumsum(0)
source_cdf = source_cdf / source_cdf[-1]
assert np.allclose(source_cdf_np, source_cdf.numpy())
bin_edges = torch.linspace(lo, hi, bins + 1)
assert np.allclose(bin_edges_np, bin_edges.numpy())
remapped_cdf = interp1d(source_cdf, bin_edges[1:], target_cdf).squeeze()
# ^^^ the first few entries of this come out around -138 instead of about -8?!
print(remapped_cdf_np)
print(remapped_cdf.numpy())
assert np.allclose(remapped_cdf_np, remapped_cdf.numpy()) # fails
matched = interp1d(bin_edges[1:], remapped_cdf, target)
assert np.allclose(matched_np, matched.numpy())
The above code gives me the output:
[-8.04819874 -8.04819874 -8.04819874 -7.03412467 -6.52708763 -6.34600297
-6.27356911 -6.10455677 -5.89329133 -5.55526664 -5.28932075 -5.00597652
-4.81837282 -4.66795183 -4.43309052 -4.17367044 -3.93438144 -3.670879
-3.44365304 -3.19894227 -2.97056192 -2.7420723 -2.47732906 -2.21208839
-1.96009338 -1.69844422 -1.44815496 -1.20431557 -0.94311239 -0.68723275
-0.44108403 -0.18912467 0.055417 0.30790917 0.5585027 0.81660576
1.0688232 1.33458219 1.60847022 1.85890728 2.12938742 2.38900627
2.66416974 2.93036861 3.17321839 3.41920686 3.64490881 3.92116168
4.1785585 4.43336298 4.75240842 4.99895072 5.34133486 5.58747586
5.77523205 5.9234885 6.12066957 6.29370601 6.36613988 6.52911607
7.6699494 7.6699494 7.6699494 7.92346792]
[-1.37849957e+02 -1.37849957e+02 -1.37849957e+02 -7.28813667e+00
-6.78085313e+00 -6.34605302e+00 -6.27363935e+00 -6.10459291e+00
-5.89331551e+00 -5.55530036e+00 -5.28934614e+00 -5.00599635e+00
-4.81837965e+00 -4.66795432e+00 -4.43309173e+00 -4.17367124e+00
-3.93438180e+00 -3.67087919e+00 -3.44365297e+00 -3.19894198e+00
-2.97056138e+00 -2.74207334e+00 -2.47732997e+00 -2.21208787e+00
-1.96009280e+00 -1.69844372e+00 -1.44815450e+00 -1.20431574e+00
-9.43111794e-01 -6.87232221e-01 -4.41083541e-01 -1.89125193e-01
5.54165081e-02 3.07908676e-01 5.58502213e-01 8.16605249e-01
1.06882271e+00 1.33458229e+00 1.60847020e+00 1.85890732e+00
2.12938745e+00 2.38900625e+00 2.66416942e+00 2.93036823e+00
3.17321798e+00 3.41920633e+00 3.64490848e+00 3.92116086e+00
4.17855749e+00 4.43336134e+00 4.75240455e+00 4.99894268e+00
5.34132004e+00 5.58744816e+00 5.77521847e+00 5.92348256e+00
6.12062090e+00 6.29366587e+00 6.36607955e+00 6.52899272e+00
6.90914693e+00 6.90914693e+00 6.90914693e+00 7.92322078e+00]
Traceback (most recent call last):
File "histmatch.py", line 256, in <module>
assert np.allclose(remapped_cdf_np, remapped_cdf.numpy()) # fails
AssertionError
The values printed in the second array are from torchinterp1d, while the first array is from np.interp for the same inputs (as evidenced by the earlier asserts not triggering). Note that the order of arguments for torchinterp1d is slightly different from that of np.interp, but I believe they should produce the same result.
In fact, most of the printed values agree. Take the last value of the array, for example: 7.92322078e+00 is pretty close to 7.92346792. The same holds for almost all values in the array. The glaring exception is the first three, which come out at about -137.85 instead of -8.05, more than an order of magnitude too low. A few other entries near both ends also differ noticeably (for example 6.90914693 versus 7.6699494).
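To narrow this down, here is a minimal diagnostic sketch. It assumes (I have not confirmed this) that the mismatch comes from edge cases that interpolation routines commonly treat differently: query points outside the range of the sample points, and repeated sample points. np.interp returns fp[0] and fp[-1] by default for out-of-range queries; whether torchinterp1d clamps or extrapolates there is something I'd need to check. The helper name `interp_edge_cases` and the toy histogram counts are made up for illustration.

```python
import numpy as np

def interp_edge_cases(x, xp):
    """Hypothetical helper: list the inputs that interpolators often handle differently."""
    x = np.asarray(x)
    xp = np.asarray(xp)
    return {
        # queries left of the first sample point (np.interp returns fp[0] by default)
        "below_range": np.flatnonzero(x < xp[0]),
        # queries right of the last sample point (np.interp returns fp[-1] by default)
        "above_range": np.flatnonzero(x > xp[-1]),
        # indices i where xp[i] == xp[i + 1], i.e. zero-width intervals
        "flat_intervals": np.flatnonzero(np.diff(xp) == 0),
    }

# Toy CDFs built the same way as in the report; the counts are invented.
source_hist = np.array([3, 0, 0, 2, 5])
target_hist = np.array([1, 1, 3, 3, 2])
source_cdf = source_hist.cumsum() / source_hist.sum()
target_cdf = target_hist.cumsum() / target_hist.sum()

report = interp_edge_cases(target_cdf, source_cdf)
for name, idx in report.items():
    print(name, idx.tolist())
```

Running the same helper on `target_cdf_np` and `source_cdf_np` from the script above would show whether the mismatching positions line up with out-of-range queries or with empty histogram bins.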
To be concrete, these two lines give different results for the same inputs:
remapped_cdf_np = np.interp(x=target_cdf_np, xp=source_cdf_np, fp=bin_edges_np[1:])
remapped_cdf = interp1d(x=source_cdf, y=bin_edges[1:], xnew=target_cdf).squeeze()
What's going on here? Is there a way to exactly reproduce NumPy's results with PyTorch?
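For reference, this is roughly the behaviour I'm after, written as a minimal sketch in plain PyTorch rather than with torchinterp1d. The function name `interp_like_numpy` is hypothetical, and it only covers 1D inputs with a sorted `xp`. It uses `torch.searchsorted` to find the interval for each query and returns `fp[0]` / `fp[-1]` for queries outside the range, which is np.interp's default. I have only checked it against np.interp on small inputs like the one below, so I can't promise it matches in every case.

```python
import numpy as np
import torch

def interp_like_numpy(x, xp, fp):
    """Hypothetical 1D linear interpolation that clamps outside [xp[0], xp[-1]]."""
    n = xp.numel()
    # right=True gives idx with xp[idx - 1] <= x < xp[idx] for in-range queries
    idx = torch.searchsorted(xp, x, right=True).clamp(1, n - 1)
    x0, x1 = xp[idx - 1], xp[idx]
    y0, y1 = fp[idx - 1], fp[idx]
    dx = x1 - x0
    # guard against zero-width intervals (repeated xp values) at the clamped ends
    safe_dx = torch.where(dx == 0, torch.ones_like(dx), dx)
    y = y0 + (x - x0) / safe_dx * (y1 - y0)
    # out-of-range queries take the boundary values instead of being extrapolated
    y = torch.where(x < xp[0], fp[0], y)
    y = torch.where(x >= xp[-1], fp[-1], y)
    return y

# Small check with repeated sample points and out-of-range queries.
xp = torch.tensor([0.0, 0.0, 0.0, 0.2, 0.7, 0.7, 1.0], dtype=torch.float64)
fp = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], dtype=torch.float64)
x = torch.tensor([-0.5, 0.0, 0.1, 0.2, 0.45, 0.7, 0.85, 1.0, 1.5], dtype=torch.float64)

result = interp_like_numpy(x, xp, fp)
expected = np.interp(x.numpy(), xp.numpy(), fp.numpy())
print(np.allclose(result.numpy(), expected))
```

Argument order here follows np.interp (`x`, `xp`, `fp`), not torchinterp1d's (`x`, `y`, `xnew`).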