MHKiT-Python
Xarray Strategy causing significant MHKiT computation times (MHKiT slow)
Not a bug, but I was trying to analyze 30 years of hourly data with `xarray`, `dask`, and `mhkit` and was having a lot of trouble getting things to run in a reasonable amount of time. I found that by simply rewriting some of the `mhkit` functions in pure `numpy`, I got speedups on the order of 3000x. Looking at the `mhkit` code, could this be due entirely to converting between types?
| Function | Mean time per call |
|---|---|
| `mhkit.wave.resource.significant_wave_height` | 2.4 s |
| pure `numpy` (`sig_wave_height`) | 0.8 ms |
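Written out as equations, the numpy path in the script below computes the following, where $E(f,\theta)$ is the directional spectrum read from the file. `DataArray.integrate` uses the trapezoidal rule over direction, and the bin widths $\Delta f_i$ follow the convention in `moment` (the first bin spans 0 to $f_1$). This is a summary of the script, not of MHKiT's internals; MHKiT's handling of the first frequency bin may differ, which could show up as small nonzero values in the agreement check.

$$
S(f) = \frac{1}{\rho g}\int E(f,\theta)\,\mathrm{d}\theta, \qquad \rho = 1025\ \mathrm{kg/m^3}, \quad g = 9.81\ \mathrm{m/s^2}
$$

$$
m_n = \sum_{i=1}^{N} f_i^{\,n}\,S(f_i)\,\Delta f_i, \qquad \Delta f_1 = f_1, \quad \Delta f_i = f_i - f_{i-1}\ \ (i > 1)
$$

$$
H_{m0} = 4\sqrt{m_0}
$$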
```python
# %%
import timeit

import numpy as np
import xarray as xr
from mhkit.wave.resource import significant_wave_height

# %% [markdown]
# # Load data

# %%
ds = xr.open_dataset('cape_hatteras_download_spectra_2000.nc')
ds = ds.rename({'time_index': 'time'})
ds['frequency'].attrs['units'] = 'Hz'
ds['frequency'].attrs['long_name'] = 'Frequency'
ds['direction'] = ds['direction']
ds['direction'].attrs['units'] = 'rad'
ds['direction'].attrs['long_name'] = 'Direction'
# Move the scalar 'gid' variable into the dataset attributes
ds.attrs['gid'] = ds.gid.item()
ds = ds.drop_vars('gid').squeeze()
# Stack the remaining data variable into one DataArray and drop the helper dim
ds = ds.to_array()
ds = ds.drop_vars('variable').squeeze()
# Divide by rho*g (1025 kg/m^3 * 9.81 m/s^2): energy density -> variance density
ds = ds / (1025*9.81)
# Integrate over direction (trapezoidal rule) to get the omnidirectional S(f)
dso = ds.integrate('direction')
dso.attrs['units'] = 'm$^2$/Hz'
dso.attrs['long_name'] = 'Spectral density'
dso.name = 'S'
dso

# %% [markdown]
# # Timing

# %%
time = {}  # mean wall-clock seconds per call
n = 20

# %% [markdown]
# ## Using MHKiT

# %%
# The timed call includes converting dso to a pandas DataFrame
# (transposed so that frequency is the index)
time['mhkit'] = timeit.timeit(
    lambda: significant_wave_height(dso.to_pandas().transpose()), number=n)/n

# %% [markdown]
# ## Using numpy

# %%
def moment(da, order=0):
    """Spectral moment of the given order; assumes da has dims (time, frequency)."""
    f = da['frequency'].values
    # Bin widths: the first bin spans 0 to f[0], the rest are backward differences
    df = np.insert(np.diff(f), 0, f[0])
    m = np.sum(df*da.data*f**order, axis=1)
    return m


def sig_wave_height(da):
    """Significant wave height, Hm0 = 4*sqrt(m0)."""
    return 4*np.sqrt(moment(da, 0))

# %%
time['numpy'] = timeit.timeit(
    lambda: sig_wave_height(dso), number=n)/n

# %%
time

# %%
# Speedup factor of the numpy version over MHKiT
time['mhkit']/time['numpy']

# %% [markdown]
# # Check that they agree

# %%
diff = (significant_wave_height(dso.to_pandas().transpose()).to_numpy().squeeze()
        - sig_wave_height(dso))
np.abs(diff).max()  # largest absolute difference in Hm0 (m)
```