
Allow specifying output dtype

Open dcherian opened this issue 2 years ago • 12 comments

Closes https://github.com/pydata/xarray/issues/6902

cc @Illviljan @tasansal

dcherian, Aug 09 '22

I rarely specify dtypes myself, so a close review would be very valuable if you have the time.

dcherian, Aug 09 '22

Thanks for the super quick fix!!!

I will try this and get back to you; I need to read the developer docs first to use this branch.

tasansal, Aug 09 '22

The easiest way would be to clone this repo and then use the GitHub CLI:

git clone ...
gh pr checkout 131
pip install -e .

dcherian, Aug 09 '22

@dcherian

It seems to be working fine, except that I would expect A == B below. The output dtype changes, and the values don't match the ground truth. I think it's still running the mean in fp32 and then, at the very end, casting it to fp64 when doing B.

The rest of them match, which is excellent!

ground_truth_fp64 = da.groupby("time.month").apply(np.mean, dtype='float64').compute()
ground_truth_fp32 = da.groupby("time.month").apply(np.mean).compute()  # (A)

flox_as_is = da.groupby("time.month").mean().compute()  # (B)

dtype_in_mean = da.groupby("time.month").mean(dtype="float64").compute()

cast_input_dtype = da.astype("float64").groupby("time.month").mean().compute()

cast_and_in_mean = da.astype("float64").groupby("time.month").mean(dtype="float64").compute()

**ground truth fp64**
input dtype: float32 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
 8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
 6.9561013470 7.1530014089]

**ground truth fp32** (A)
input dtype: float32 output_dtype float32
[6.5968837738 7.7664070129 7.1014094353 8.2660837173 8.4281892776
 8.3720579147 6.9553837776 7.0804729462 7.0591216087 6.8767600060
 6.9561014175 7.1530008316]

**As is Flox groupby mean()** (B)
input dtype: float32 output_dtype float64
[6.5968835755 7.7664066341 7.1014098027 8.2660843461 8.4281902792
 8.3720576534 6.9553839571 7.0804727753 7.0591213650 6.8767606127
 6.9561012551 7.1530014440]

**Flox with dtype in mean()**
input dtype: float32 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
 8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
 6.9561013470 7.1530014089]

**Flox with input cast to float64, mean() as is (also ground truth)**
input dtype: float64 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
 8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
 6.9561013470 7.1530014089]

**Flox with input cast to float64, mean() also with dtype**
input dtype: float64 output_dtype float64
[6.5968835698 7.7664068452 7.1014096086 8.2660835249 8.4281900699
 8.3720580059 6.9553840001 7.0804730689 7.0591219326 6.8767603768
 6.9561013470 7.1530014089]
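For reference, the plain NumPy behaviour that (A) and the dtype="float64" rows are being compared against can be checked on a tiny array. This is a minimal sketch with made-up values (not the dataset above): without a dtype argument np.nanmean keeps float32, and with dtype="float64" it accumulates in and returns float64.

```python
import numpy as np

# Made-up float32 values with one NaN, standing in for one month of data.
x = np.array([6.5, 7.25, np.nan, 8.0], dtype=np.float32)

# No dtype argument: NaNs are skipped and the result keeps the input dtype.
mean_default = np.nanmean(x)

# dtype="float64": the sum is accumulated in float64 and the result is float64.
mean_fp64 = np.nanmean(x, dtype="float64")

print(mean_default.dtype, mean_fp64.dtype)  # float32 float64
```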

tasansal, Aug 09 '22

Thanks for all your suggestions.

Now tests are failing only for engine="flox" and input dtype=float64.

===================================== short test summary info =====================================
FAILED tests/test_xarray.py::test_dtype[flox-True-float64-float640] - AssertionError: Left and r...
FAILED tests/test_xarray.py::test_dtype[flox-True-float64-float641] - AssertionError: Left and r...
FAILED tests/test_xarray.py::test_dtype[flox-True-float64-dtype_out2] - AssertionError: Left and...
FAILED tests/test_xarray.py::test_dtype[flox-False-float64-float640] - AssertionError: Left and ...
FAILED tests/test_xarray.py::test_dtype[flox-False-float64-float641] - AssertionError: Left and ...
FAILED tests/test_xarray.py::test_dtype[flox-False-float64-dtype_out2] - AssertionError: Left an...
============================ 6 failed, 30 passed, 2 warnings in 12.26s ============================

I'm not really sure why; probably a small roundoff difference.

@tasansal Does your sample dataset contain NaNs? The promotion of the output to float64 in (B) is odd.

dcherian, Aug 09 '22


Yes, it has NaNs. The normal Xarray mean skips them and returns a float32 mean.

I also have Bottleneck and numpy_groupies installed; could that have anything to do with it?

The precision is a bit strange; I would expect it to match NumPy and the flox=False version of Xarray.

Let me know if I can help any further!

tasansal, Aug 09 '22

assert_equal may be too strict. Maybe try allclose instead? Something like this:

np.testing.assert_allclose(actual, expected, rtol=0, atol=np.finfo(actual.dtype).eps)
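Note that atol=np.finfo(actual.dtype).eps is an absolute tolerance equal to the float spacing near 1.0. For monthly means around 10, as in this thread, adjacent float32 values are already about 9.5e-7 apart, which is larger than float32 eps (about 1.2e-7). A possible relative variant is sketched below for floating-point inputs; the helper name assert_allclose_scaled and the ulps factor are hypothetical choices for illustration, not part of any existing test suite.

```python
import numpy as np

def assert_allclose_scaled(actual, expected, ulps=4):
    """Hypothetical helper: compare with a relative tolerance of `ulps` machine epsilons.

    The epsilon is taken from the dtype of `actual`, so float32 results get a
    looser tolerance than float64 results. NaNs in matching positions compare equal.
    """
    actual = np.asarray(actual)
    rtol = ulps * np.finfo(actual.dtype).eps
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=0, equal_nan=True)

# Passes: the values differ by roughly one float32 spacing at 10.0.
assert_allclose_scaled(np.array([10.0, np.nan], dtype=np.float32), [10.000001, np.nan])
```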

Illviljan, Aug 09 '22

Yeah allclose is what's needed.

@tasansal I can't reproduce your results even after adding NaNs. Can you try to construct a synthetic example that shows the problem? Does passing engine="numpy" help?

dcherian, Aug 09 '22

@dcherian here you go.

I am calling Xarray without Flox "Vanilla Xarray," FYI.

Conclusions:

  1. The casting issue I've seen above is reproduced on Xarray + Flox + Dask (chunked) combination (STEP 7). All others look fine; the one with Dask returns a "float64" even though we didn't ask for it.
  2. None of the fp32 implementations except vanilla Xarray without Dask exactly matches the NumPy ground truth. That one matching makes sense because it uses NumPy, correct? The others not matching also makes sense ONLY IF NumPy isn't used in the backend, which I think is the case; then this is OK. The "float64" examples always match, which is good.
  3. Enabling Dask on vanilla Xarray also breaks the match from point 2. Assuming NumPy is still being used, that is not good, because it means there may be a problem with the partial means Dask + Xarray computes. Or it may be fine, because we may be losing some precision by computing partial means purely in float32. It would be nice to get someone else from the Dask or Xarray developers to comment on this. Can you tag some people?
  4. I recommend making sure the tests adjust the allclose tolerances based on precision. The allclose defaults (rtol=1e-05, atol=1e-08) are too loose for float64 values, and this is important for scientific computing.
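To make point 4 concrete: because np.allclose defaults to rtol=1e-05 and atol=1e-08, two float64 values that differ in the seventh significant digit still count as "close". A minimal check, reusing the float64 tolerances chosen in Step 1:

```python
import numpy as np

# Float64 tolerances from Step 1.
RTOL = 1e-15
ATOL = 1e-18

a = np.float64(1.0)
b = np.float64(1.000001)  # differs from a by about 1e-6

loose = np.allclose(a, b)                        # defaults: rtol=1e-05, atol=1e-08
tight = np.allclose(a, b, rtol=RTOL, atol=ATOL)  # float64-scale tolerances
print(loose, tight)  # True False
```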

Step 1:

  • import stuff
  • set the NumPy print precision to 18 decimal places
  • set RTOL and ATOL for np.allclose on "float64" variables to 1e-15 and 1e-18, respectively (based on here)
  • the default atol and rtol are fine for "float32"
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
np.set_printoptions(precision=18)

# Tolerances for fp64
# fp64 can have between 15-18 decimal digits of accuracy
RTOL = 1e-15
ATOL = 1e-18

Step 2: Make synthetic data with NaNs. The synthetic data is cast to float32, as in the real-world example. NOTE: since the random seed isn't fixed, your results may look different.

datetimes = pd.date_range("2010-01", "2015-01", freq="6H", inclusive="left")
samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(len(datetimes))) * 1
samples += np.random.randn(len(datetimes))
samples = samples.astype('float32')

nan_indices = np.random.default_rng().integers(0, len(samples), size=5_000)
samples[nan_indices] = np.nan
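Because the seed isn't fixed, reruns of Step 2 produce different data. A seeded variant of the same recipe is sketched below; make_samples and the seed value are hypothetical, and while a fixed seed makes runs reproducible, the numbers will not match the outputs posted in this thread. Note that integers() draws with replacement, so fewer than 5,000 distinct positions end up as NaN (the same is true of the original snippet).

```python
import numpy as np

def make_samples(n, n_nans=5_000, seed=0):
    """Hypothetical seeded version of the Step 2 generator (seed=0 is arbitrary)."""
    rng = np.random.default_rng(seed)
    samples = 10 + np.cos(2 * np.pi * 0.001 * np.arange(n))
    samples += rng.standard_normal(n)  # seeded replacement for np.random.randn
    samples = samples.astype("float32")
    nan_indices = rng.integers(0, n, size=n_nans)  # drawn with replacement
    samples[nan_indices] = np.nan
    return samples

# With the Step 2 index this would be: samples = make_samples(len(datetimes))
```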

Step 3: Plot

plt.figure(figsize=(15, 3))
line = plt.plot(datetimes, samples)

[Figure: line plot of samples vs. datetimes]

Step 4: Calculate ground truth with NumPy

  • iterate months
  • find month indices from pandas datetime
  • reduce (mean) as is, which yields fp32 output
  • reduce (mean) with dtype="float64" accumulator
  • reduce (mean) after casting to "float64", not specifying dtype in mean() to compare
ground_truth_fp32 = []
ground_truth_fp64 = []
cast_before_fp64 = []
for month in range(1, 13):
    month_mask = datetimes.month == month
    month_data = samples[month_mask]
    ground_truth_fp32.append(np.nanmean(month_data))
    ground_truth_fp64.append(np.nanmean(month_data, dtype="float64"))
    cast_before_fp64.append(np.nanmean(month_data.astype("float64")))
    
ground_truth_fp32 = np.asarray(ground_truth_fp32)
ground_truth_fp64 = np.asarray(ground_truth_fp64)
cast_before_fp64 = np.asarray(cast_before_fp64)
print("fp32 ground truth:\n", ground_truth_fp32)
print("fp64 ground truth:\n", ground_truth_fp64)
print("fp64 casted ground truth:\n", cast_before_fp64)
print("cast equals dtype arg:", np.equal(ground_truth_fp64, cast_before_fp64).all())

--- 

fp32 ground truth:
 [10.1700325 10.009565  10.092217   9.943765   9.769657   9.922329
  9.994079   9.99988   10.225714  10.153132  10.065705   9.93171  ]
fp64 ground truth:
 [10.170032398493728 10.009565626110946 10.092218036775465
  9.94376481957987   9.769656868304237  9.922329294098008
  9.994078634636194  9.999879809414468 10.225713932633004
 10.153132471137742 10.065704485859733  9.93170997271171 ]
fp64 casted ground truth:
 [10.170032398493728 10.009565626110946 10.092218036775465
  9.94376481957987   9.769656868304237  9.922329294098008
  9.994078634636194  9.999879809414468 10.225713932633004
 10.153132471137742 10.065704485859733  9.93170997271171 ]
cast equals dtype arg: True
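The per-month loop in Step 4 can also be written without a Python loop. A sketch, assuming integer group codes (for this data, datetimes.month.to_numpy() - 1 would give codes 0-11); groupby_nanmean_fp64 is a hypothetical helper, not Flox's implementation. np.bincount accumulates weights in float64, so this only reproduces the float64 reference, and because it sums sequentially rather than pairwise, expect agreement within float64 roundoff rather than bit-for-bit equality.

```python
import numpy as np

def groupby_nanmean_fp64(values, codes, n_groups):
    """NaN-skipping group means with float64 accumulation.

    `codes` holds an integer group code in [0, n_groups) for each value.
    A group with no valid values yields NaN (NumPy emits a RuntimeWarning).
    """
    valid = ~np.isnan(values)
    # Per-group sum of the non-NaN values; bincount converts weights to float64.
    sums = np.bincount(codes[valid], weights=values[valid], minlength=n_groups)
    # Per-group number of non-NaN values.
    counts = np.bincount(codes[valid], minlength=n_groups)
    return sums / counts

# With the Step 2 data: groupby_nanmean_fp64(samples, datetimes.month.to_numpy() - 1, 12)
```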

Step 5: Pandas groupby mean, for reference. Note that it doesn't support a dtype argument; pandas uses its own Cython or Numba aggregations.

series = pd.Series(samples, index=datetimes)
pd_mean_fp32 = series.groupby(series.index.month).mean().to_numpy()
pd_mean_fp64 = series.astype("float64").groupby(series.index.month).mean().to_numpy()
print("pandas fp32 all equal:", np.equal(pd_mean_fp32, ground_truth_fp32).all())
print("pandas fp32 all close:", np.allclose(pd_mean_fp32, ground_truth_fp32))
print("pandas fp64 all equal:", np.equal(pd_mean_fp64, ground_truth_fp64).all())
print("pandas fp64 all close:", np.allclose(pd_mean_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", pd_mean_fp32.dtype)
print("fp64 out dtype:", pd_mean_fp64.dtype)

---

pandas fp32 all equal: False
pandas fp32 all close: True
pandas fp64 all equal: True
pandas fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64

Step 6: Xarray + Flox + No Dask

da = xr.DataArray(samples, dims=("time",), coords=[datetimes])
da_mean_nodask_flox_fp32 = da.groupby("time.month").mean().values
da_mean_nodask_flox_fp64 = da.groupby("time.month").mean(dtype="float64").values
print("flox_nodask fp32 all equal:", np.equal(da_mean_nodask_flox_fp32, ground_truth_fp32).all())
print("flox_nodask fp32 all close:", np.allclose(da_mean_nodask_flox_fp32, ground_truth_fp32))
print("flox_nodask fp64 all equal:", np.equal(da_mean_nodask_flox_fp64, ground_truth_fp64).all())
print("flox_nodask fp64 all close:", np.allclose(da_mean_nodask_flox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_nodask_flox_fp32.dtype)
print("fp64 out dtype:", da_mean_nodask_flox_fp64.dtype)

---

flox_nodask fp32 all equal: False
flox_nodask fp32 all close: True
flox_nodask fp64 all equal: True
flox_nodask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64

Step 7: Xarray + Flox + With Dask (Chunked)

da_mean_dask_flox_fp32 = da.chunk(time=1024).groupby("time.month").mean().values
da_mean_dask_flox_fp64 = da.chunk(time=1024).groupby("time.month").mean(dtype="float64").values
print("flox dask fp32 all equal:", np.equal(da_mean_dask_flox_fp32, ground_truth_fp32).all())
print("flox dask fp32 all close:", np.allclose(da_mean_dask_flox_fp32, ground_truth_fp32))
print("flox dask fp64 all equal:", np.equal(da_mean_dask_flox_fp64, ground_truth_fp64).all())
print("flox dask fp64 all close:", np.allclose(da_mean_dask_flox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_dask_flox_fp32.dtype)
print("fp64 out dtype:", da_mean_dask_flox_fp64.dtype)

---

flox dask fp32 all equal: False
flox dask fp32 all close: True
flox dask fp64 all equal: True
flox dask fp64 all close: True
fp32 out dtype: float64  # HERE
fp64 out dtype: float64
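The property that breaks here (float32 input, no dtype argument, float64 output) can be written as a backend-agnostic check. A minimal sketch: reduce_fn is a hypothetical hook that would wrap whatever reduction is under test (for Step 7, something like a groupby mean on the chunked array); it is exercised here with np.nanmean only.

```python
import numpy as np

def check_mean_dtypes(reduce_fn, data):
    """Check the output-dtype expectations discussed in this thread.

    `reduce_fn(data)` and `reduce_fn(data, dtype=...)` must both be callable.
    """
    out_default = np.asarray(reduce_fn(data))
    out_fp64 = np.asarray(reduce_fn(data, dtype="float64"))
    # No dtype argument: the output should keep the input dtype.
    assert out_default.dtype == data.dtype, (out_default.dtype, data.dtype)
    # Explicit dtype="float64": the output should be float64.
    assert out_fp64.dtype == np.float64, out_fp64.dtype

data = np.array([1.0, np.nan, 3.0], dtype=np.float32)
check_mean_dtypes(np.nanmean, data)  # passes: NumPy meets both expectations
```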

Now turn off Flox with xr.set_options(use_flox=False).

Step 8: Xarray + No Flox + No Dask

da_mean_nodask_noflox_fp32 = da.groupby("time.month").mean().values
da_mean_nodask_noflox_fp64 = da.groupby("time.month").mean(dtype="float64").values
print("no flox nodask fp32 all equal:", np.equal(da_mean_nodask_noflox_fp32, ground_truth_fp32).all())
print("no flox nodask fp32 all close:", np.allclose(da_mean_nodask_noflox_fp32, ground_truth_fp32))
print("no flox nodask fp64 all equal:", np.equal(da_mean_nodask_noflox_fp64, ground_truth_fp64).all())
print("no flox nodask fp64 all close:", np.allclose(da_mean_nodask_noflox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_nodask_noflox_fp32.dtype)
print("fp64 out dtype:", da_mean_nodask_noflox_fp64.dtype)

---

no flox nodask fp32 all equal: True
no flox nodask fp32 all close: True
no flox nodask fp64 all equal: True
no flox nodask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64

Step 9: Xarray + No Flox + With Dask (Chunked)

da_mean_dask_noflox_fp32 = da.chunk(time=1024).groupby("time.month").mean().values
da_mean_dask_noflox_fp64 = da.chunk(time=1024).groupby("time.month").mean(dtype="float64").values
print("no flox dask fp32 all equal:", np.equal(da_mean_dask_noflox_fp32, ground_truth_fp32).all())
print("no flox dask fp32 all close:", np.allclose(da_mean_dask_noflox_fp32, ground_truth_fp32))
print("no flox dask fp64 all equal:", np.equal(da_mean_dask_noflox_fp64, ground_truth_fp64).all())
print("no flox dask fp64 all close:", np.allclose(da_mean_dask_noflox_fp64, ground_truth_fp64, rtol=RTOL, atol=ATOL))
print("fp32 out dtype:", da_mean_dask_noflox_fp32.dtype)
print("fp64 out dtype:", da_mean_dask_noflox_fp64.dtype)

---

no flox dask fp32 all equal: False
no flox dask fp32 all close: True
no flox dask fp64 all equal: True
no flox dask fp64 all close: True
fp32 out dtype: float32
fp64 out dtype: float64

tasansal, Aug 10 '22

Thanks @tasansal this is very valuable!

Would you mind writing up your post as a test and sending in a PR? It'd be nice to give you credit, and it would help immensely in tracking down the fix. We can xfail it for now to merge, and then I can modify this PR to fix it.

dcherian, Aug 10 '22


Sure thing. Do you think the test should go upstream to Xarray, or does it belong in Flox? Xarray may already have a similar test; I will check that too.

tasansal, Aug 10 '22

Good call, Xarray's test_groupby is a good place to put it.

One thing to note is that we don't expect dask + Xarray to "match" numpy + Xarray in general (it depends on what you mean by "match"). Because dask accumulates within chunks first, you'll get different roundoff error.
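A NumPy-only illustration of that point, loosely mimicking a chunked reduction: sum each chunk, then combine the partial sums. chunked_mean is a hypothetical helper and this is not Dask's or Flox's actual implementation; the result should be close to, but need not be bit-identical with, np.mean on the same float32 data.

```python
import numpy as np

def chunked_mean(x, chunk_size):
    """Mean from per-chunk partial sums, accumulated in x's own dtype."""
    partial_sums = np.array(
        [np.sum(x[i:i + chunk_size]) for i in range(0, x.size, chunk_size)],
        dtype=x.dtype,
    )
    total = np.sum(partial_sums)
    # Cast back explicitly so the result keeps the input dtype.
    return x.dtype.type(total / x.size)

rng = np.random.default_rng(0)
x = (10 + rng.standard_normal(10_000)).astype(np.float32)
print(chunked_mean(x, 1024), np.mean(x))  # close, but not guaranteed to be identical
```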

dcherian, Aug 10 '22

Ok, this didn't fix the particular dtype issues I'm having.

Maybe it's time to merge this then?

These tests failed when I added a dtype check, 4dab89a:

=========================== short test summary info ============================
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array1-by1-expected1-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array1-by1-expected1-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array3-by3-expected3-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array3-by3-expected3-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array5-by5-expected5-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-sum-array5-by5-expected5-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array7-by7-expected7-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array7-by7-expected7-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array9-by9-expected9-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array9-by9-expected9-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array11-by11-expected11-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-count-array11-by11-expected11-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array13-by13-expected13-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array13-by13-expected13-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array15-by15-expected15-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array15-by15-expected15-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array17-by17-expected17-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[flox-nanmean-array17-by17-expected17-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array9-by9-expected9-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array9-by9-expected9-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array11-by11-expected11-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array11-by11-expected11-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array13-by13-expected13-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array13-by13-expected13-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array15-by15-expected15-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array15-by15-expected15-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array17-by17-expected17-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-nanmean-array17-by17-expected17-None-False-1-int]
FAILED tests/test_core.py::test_bool_reductions[flox-max] - AssertionError: a...
FAILED tests/test_core.py::test_bool_reductions[flox-nanmax] - AssertionError...
FAILED tests/test_core.py::test_bool_reductions[flox-min] - AssertionError: a...
FAILED tests/test_core.py::test_bool_reductions[flox-nanmin] - AssertionError...
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array1-by1-expected1-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array1-by1-expected1-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array3-by3-expected3-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array3-by3-expected3-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array5-by5-expected5-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-sum-array5-by5-expected5-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array7-by7-expected7-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numpy-count-array7-by7-expected7-None-False-1-int]
FAILED tests/test_core.py::test_bool_reductions[numpy-max] - AssertionError: ...
FAILED tests/test_core.py::test_bool_reductions[numpy-nanmax] - AssertionErro...
FAILED tests/test_core.py::test_bool_reductions[numpy-min] - AssertionError: ...
FAILED tests/test_core.py::test_bool_reductions[numpy-nanmin] - AssertionErro...
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array1-by1-expected1-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array1-by1-expected1-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array3-by3-expected3-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array3-by3-expected3-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array5-by5-expected5-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-sum-array5-by5-expected5-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array7-by7-expected7-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array7-by7-expected7-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array9-by9-expected9-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array9-by9-expected9-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array11-by11-expected11-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-count-array11-by11-expected11-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array13-by13-expected13-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array13-by13-expected13-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array15-by15-expected15-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array15-by15-expected15-None-False-1-int]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array17-by17-expected17-None-False-1-float]
FAILED tests/test_core.py::test_groupby_reduce[numba-nanmean-array17-by17-expected17-None-False-1-int]
FAILED tests/test_core.py::test_bool_reductions[numba-max] - AssertionError: ...
FAILED tests/test_core.py::test_multiple_groupers - AssertionError: a and b h...
FAILED tests/test_core.py::test_bool_reductions[numba-nanmax] - AssertionErro...
FAILED tests/test_core.py::test_bool_reductions[numba-min] - AssertionError: ...
FAILED tests/test_core.py::test_bool_reductions[numba-nanmin] - AssertionErro...
FAILED tests/test_core.py::test_map_reduce_blockwise_mixed - AttributeError: ...
= 68 failed, 6438 passed, 1958 skipped, 174 xfailed, 3 xpassed, 1744 warnings in 356.04s (0:05:56) =
Error: Process completed with exit code 1.

Illviljan, Sep 25 '22

From https://github.com/xarray-contrib/flox/pull/131#issuecomment-1210870752

Conclusions:

  1. The casting issue I've seen above is reproduced on Xarray + Flox + Dask (chunked) combination (STEP 7). All others look fine; the one with Dask returns a "float64" even though we didn't ask for it.

@dcherian and @Illviljan

This one still is inconsistent.

Let me elaborate. When you run a mean operation on float32 data without the dtype argument, the result is also float32. However, when you pass dtype float64 without flox, the mean's accumulator (the sum of all values, which is then divided by the number of samples) is made float64 for precision, and the result is then cast back to float32. To summarize: input float32, output float32, with more precision than without the dtype argument.
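A minimal NumPy sketch of the accumulator pattern described here: sum in float64, divide by the count, then cast back to the input dtype. mean_cast_back is a hypothetical name; it illustrates the described behaviour only, not the code path of Xarray or Flox, and whether a given code path casts back is exactly what this thread is checking.

```python
import numpy as np

def mean_cast_back(x, acc_dtype=np.float64):
    """Mean with a wider accumulator, returned in the input dtype."""
    total = np.sum(x, dtype=acc_dtype)   # accumulate in float64
    return x.dtype.type(total / x.size)  # divide, then cast back to x.dtype

x = np.full(1_000, 0.1, dtype=np.float32)
print(mean_cast_back(x).dtype)  # float32
```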

With this PR, most examples in the comment have consistent behavior. However, the Xarray + Flox + Dask (must be chunked) edge case returns float64 even when the input is float32.

Note that the inconsistency is with the return dtype, I understand values may not be exactly the same.

I suggest this should be fixed before we merge for consistency across the board.

tasansal, Sep 25 '22

@tasansal should be all good here. Would you mind running your internal test suite and confirming?

dcherian, Oct 10 '22

@dcherian, thanks for the awesome fixes! Sorry, just catching up; I will test it in the next few days and let you know if anything acts up. Will this be in 0.5.11?

tasansal, Oct 11 '22


Yes

dcherian, Oct 11 '22