Support for rasterisation to dask arrays
Hi folks,
I've had a quick look through the source code and haven't been able to find this functionality, so apologies if it exists and I've missed something.
What do you think about the feasibility/attractiveness of being able to run make_geocube as a delayed operation, so that the returned xarray wraps a dask array rather than an in-memory numpy ndarray (by, say, passing a chunks argument somewhere, as in rioxarray.open_rasterio here)?
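To illustrate the kind of API I have in mind, a hypothetical call could look like the following (the chunks keyword on make_geocube doesn't exist today, and the file and column names are placeholders; this is purely illustrative):

import geopandas as gpd
from geocube.api.core import make_geocube

gdf = gpd.read_file("labels.gpkg")  # placeholder vector data
out_grid = make_geocube(
    vector_data=gdf,
    resolution=(-0.0001, 0.0001),
    chunks={"x": 2048, "y": 2048},  # hypothetical argument
)
type(out_grid["my_column"].data)  # would be dask.array.Array rather than numpy.ndarray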
An example use case is a heavy machine learning workload, where a neural network is trained on an O(10-100) GB dataset of high-resolution aerial photography with rasterised vector layers representing ground-truth data.
I'm happy to take a look at this, but I don't have the familiarity with the codebase to know where a good seam for it would be, or whether it's possible without breaking things downstream, so it would be nice to hear your thoughts.
Cheers!
What do you think about the feasibility/attractiveness of being able to run make_geocube as a delayed operation, so that the returned xarray wraps a dask array rather than an in-memory numpy ndarray (by, say, passing a chunks argument somewhere, as in rioxarray.open_rasterio here)?
This definitely sounds like something useful and worth looking into :+1:. If you do proceed, I would definitely want to preserve the current behavior and have the dask part be something that can be toggled on or off, maybe with the chunks argument.
I am not entirely sure about the feasibility, if you are thinking about that in terms of the difficulty level; that is something that will take some research. My first thought was that it may be useful to see if you could incorporate dask-geopandas.
You might be able to update geocube.vector_to_cube.VectorToCube to use dask. However, you may have to create a VectorToCubeDask class as well, depending on how things go.
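Roughly what I am imagining for the dask-geopandas route, as a sketch only (make_geocube itself is unchanged here, and how the per-partition cubes would be recombined is exactly the part that needs research):

import dask
import dask_geopandas
import geopandas as gpd
from geocube.api.core import make_geocube

gdf = gpd.read_file("vectors.gpkg")  # placeholder input file
# partition the vector data so each task rasterizes only a subset of geometries
dgdf = dask_geopandas.from_geopandas(gdf, npartitions=8)

# rasterize each partition lazily; merging/mosaicking the resulting
# per-partition cubes into one output is left open
delayed_cubes = [
    dask.delayed(make_geocube)(partition, resolution=(-0.0001, 0.0001))
    for partition in dgdf.to_delayed()
]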
This would be super helpful! I am just this moment fighting with a huge aerial photo (~340 GB) for which I need to create a rasterised ROI mask based on a geopandas dataframe... My kernel/notebook is constantly crashing due to insufficient memory...
Hey folks,
I've not had the time to implement and test this properly and open a PR, but I did make a slightly hacky attempt that uses the "like" protocol to construct a rasterised dask array with geocube from a reference raster, copying its geospatial metadata (my use case). It could be a useful starting point for a general implementation:
import itertools
from functools import partial
from typing import Optional, List

import xarray as xr
import geopandas as gpd
import numpy as np
import dask.array as da
import rioxarray  # noqa: F401 -- registers the .rio accessor on xarray objects
import geocube
from geocube.api.core import make_geocube
from dask import delayed
from rasterio.windows import Window


def make_geocube_like_dask(
    df: gpd.GeoDataFrame,
    measurements: Optional[List[str]],
    like: xr.DataArray,
    fill: int = 0,
    rasterize_function: callable = partial(geocube.rasterize.rasterize_image, all_touched=True),
    **kwargs,
):
    """
    Implements dask support for a subset of the make_geocube API.

    Requires using the "like" protocol for array construction. Chunk spatial
    dimensions are copied from the reference "like" raster.

    Parameters
    ----------
    df :
        A geopandas dataframe with the polygons to be rasterised.
    measurements :
        Columns of the geopandas dataframe to rasterise.
    like :
        A reference (rio)xarray DataArray. Georeferencing and spatial dimensions
        from this are used to construct a matching raster from the polygons.
    fill :
        Fill value for rasterised shapes.
    rasterize_function :
        Function used for rasterisation. See the geocube make_geocube documentation.
    kwargs :
        Passed directly to make_geocube.
    """
    # take spatial dims from the "like" xarray to construct chunk window
    # boundaries (rioxarray has channels first)
    row_chunks, col_chunks = [list(np.cumsum(c)) for c in like.chunks][1:]
    row_chunks[:0], col_chunks[:0] = [0], [0]
    # construct 2-tuples of row and col slices corresponding to the first/last
    # indices in each chunk
    row_slices = list(zip(row_chunks, row_chunks[1:]))
    col_slices = list(zip(col_chunks, col_chunks[1:]))
    # store blocks for recombination of mask chunks into the output dask array
    delayed_blocks = [[] for _ in range(len(row_slices))]
    # go through each chunk of the dask array pointing to the reference raster
    for ix, (rs, cs) in enumerate(itertools.product(row_slices, col_slices)):
        # calculate the block index in the final dask array
        row_ix, col_ix = np.unravel_index(ix, like.data.numblocks[1:])
        # slice out the chunk of the raster array with isel_window to preserve
        # the correct geospatial metadata
        window = Window.from_slices(rs, cs)
        blk_window = like.rio.isel_window(window)
        # attempted fix for the CRS multithreaded access bug when writing the
        # raster to disk (see: https://github.com/pyproj4/pyproj/issues/589)
        blk_window['spatial_ref'] = blk_window.spatial_ref.copy(deep=True)
        # create a delayed xarray with make_geocube for this chunk
        delayed_task = delayed(make_geocube)(
            df,
            measurements=measurements,
            like=blk_window,
            fill=fill,
            rasterize_function=rasterize_function,
        )
        # access the (delayed) numpy array underlying it
        delayed_block = delayed_task.to_array().data
        # convert to a dask array chunk of the correct shape
        chunk_shape = (
            len(measurements) if measurements is not None else 1,
            *like.data.blocks[0, row_ix, col_ix].shape[1:],
        )
        dask_block = da.from_delayed(delayed_block, shape=chunk_shape, dtype=np.uint8)
        # insert into the block list at the correct position
        delayed_blocks[row_ix].append(dask_block)
    # assemble blocks into the output array
    mask_darr = da.block(delayed_blocks)
    # modify spatial metadata to reflect the common spatial reference and the
    # different band meanings
    coords = {
        'band': list(measurements) if measurements is not None else ['mask'],
        'y': like.coords['y'],
        'x': like.coords['x'],
    }
    xarr = xr.DataArray(data=mask_darr, coords=coords, dims=['band', 'y', 'x'])
    xarr['spatial_ref'] = like.spatial_ref
    xarr.attrs = like.attrs
    return xarr
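For context, I'm driving it roughly like this (the path and column name are placeholders):

import geopandas as gpd
import rioxarray

gdf = gpd.read_file("labels.gpkg")  # placeholder ground-truth polygons
# open the reference raster lazily with the chunking we want for the mask
aerial = rioxarray.open_rasterio("aerial.tif", chunks={"x": 2048, "y": 2048})
mask = make_geocube_like_dask(df=gdf, measurements=["class_id"], like=aerial)
# nothing is rasterised yet; compute a single block to inspect it
mask.data.blocks[0, 0, 0].compute()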
This produces a dask-backed xarray of the correct shape and gives sensible results when plotted for some small test rasters and geodataframes.
However, I am running into an issue with this implementation which I think is related to the concurrent access to the CRS object described in UnicodeDecodeError when run multithreaded #589. It only seems to occur for larger rasters, where dask is crunching multiple chunks at a time.
When trying to compute pieces of the array, the traceback is as follows:
---------------------------------------------------------------------------
UnicodeDecodeError Traceback (most recent call last)
<ipython-input-11-4037d4c44d6f> in <module>
----> 1 y0_multi = tds_multi.y.blocks[0].compute()
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/base.py in compute(self, **kwargs)
277 dask.base.compute
278 """
--> 279 (result,) = compute(self, traverse=False, **kwargs)
280 return result
281
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/base.py in compute(*args, **kwargs)
565 postcomputes.append(x.__dask_postcompute__())
566
--> 567 results = schedule(dsk, keys, **kwargs)
568 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
569
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/threaded.py in get(dsk, result, cache, num_workers, pool, **kwargs)
74 pools[thread][num_workers] = pool
75
---> 76 results = get_async(
77 pool.apply_async,
78 len(pool._pool),
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/local.py in get_async(apply_async, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, **kwargs)
484 _execute_task(task, data) # Re-execute locally
485 else:
--> 486 raise_exception(exc, tb)
487 res, worker_id = loads(res_info)
488 state["cache"][key] = res
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/local.py in reraise(exc, tb)
314 if exc.__traceback__ is not tb:
315 raise exc.with_traceback(tb)
--> 316 raise exc
317
318
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/local.py in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
220 try:
221 task, data = loads(task_info)
--> 222 result = _execute_task(task, data)
223 id = get_id()
224 result = dumps((result, id))
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in <genexpr>(.0)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in <genexpr>(.0)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in <genexpr>(.0)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/core.py in _execute_task(arg, cache, dsk)
119 # temporaries by their reference count and can execute certain
120 # operations in-place.
--> 121 return func(*(_execute_task(a, cache) for a in args))
122 elif not ishashable(arg):
123 return arg
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/dask/utils.py in apply(func, args, kwargs)
28 def apply(func, args, kwargs=None):
29 if kwargs:
---> 30 return func(*args, **kwargs)
31 else:
32 return func(*args)
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/geocube/api/core.py in make_geocube(vector_data, measurements, datetime_measurements, output_crs, resolution, align, geom, like, fill, group_by, interpolate_na_method, categorical_enums, rasterize_function)
84 geobox_maker = GeoBoxMaker(output_crs, resolution, align, geom, like)
85
---> 86 return VectorToCube(
87 vector_data=vector_data,
88 geobox_maker=geobox_maker,
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/geocube/vector_to_cube.py in make_geocube(self, measurements, datetime_measurements, group_by, interpolate_na_method, rasterize_function)
133 )
134 # reproject vector data to the projection of the output raster
--> 135 vector_data = self.vector_data.to_crs(self.geobox.crs.wkt)
136
137 # convert to datetime
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/datacube/utils/geometry/_base.py in wkt(self)
188 @property
189 def wkt(self) -> str:
--> 190 return self.to_wkt(version="WKT1_GDAL")
191
192 def to_epsg(self) -> Optional[int]:
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/datacube/utils/geometry/_base.py in to_wkt(self, pretty, version)
184 version = self.DEFAULT_WKT_VERSION
185
--> 186 return self._crs.to_wkt(pretty=pretty, version=version)
187
188 @property
pyproj/_crs.pyx in pyproj._crs.Base.to_wkt()
pyproj/_crs.pyx in pyproj._crs._to_wkt()
pyproj/_crs.pyx in pyproj._crs.cstrdecode()
/opt/conda/envs/gim_cv_gpu/lib/python3.8/site-packages/pyproj/compat.py in pystrdecode(cstr)
21 """
22 try:
---> 23 return cstr.decode("utf-8")
24 except AttributeError:
25 return cstr
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
Any suggestions would be welcome!
This is great, thanks for sharing :+1:. I am fairly confident that you are running into UnicodeDecodeError when run multithreaded #589. The only way around that is to construct the pyproj.CRS object in each thread; you cannot share the same pyproj.CRS object among threads. This means you will need to find a way to store the CRS WKT on the geodataframe and re-generate the CRS in each thread.
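A minimal sketch of that idea, with an illustrative helper (not a geocube API), assuming the CRS round-trips through WKT cleanly:

import geopandas as gpd
import pyproj

def with_thread_local_crs(df: gpd.GeoDataFrame, crs_wkt: str) -> gpd.GeoDataFrame:
    # rebuild the CRS inside the worker thread so that no pyproj.CRS
    # object is ever shared across threads
    return df.set_crs(pyproj.CRS.from_wkt(crs_wkt), allow_override=True)

df = gpd.read_file("vectors.gpkg")  # placeholder input
# on the main thread, capture the WKT; plain strings are safe to share.
# each delayed task would call with_thread_local_crs(df, crs_wkt) before
# handing the frame to make_geocube.
crs_wkt = df.crs.to_wkt()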
@LiamRMoore, a fix for this is in the master branch of pyproj (https://github.com/pyproj4/pyproj/pull/782; https://github.com/pyproj4/pyproj/pull/793).
I reimplemented this using map_blocks. It works, but I'm not sure I handled the coords/names in the best way (it only works for a single band).
from functools import partial
from typing import Optional, List

import xarray as xr
import geopandas as gpd
import geocube
from geocube.api.core import make_geocube


def make_geocube_like_dask2(
    df: gpd.GeoDataFrame,
    measurements: Optional[List[str]],
    like: xr.DataArray,
    fill: int = 0,
    rasterize_function: callable = partial(geocube.rasterize.rasterize_image, all_touched=True),
    **kwargs,
):
    def rasterize_block(block):
        # rasterise the polygons onto this block's grid, then stack the
        # measurement variables into a band dimension matching the template
        return (
            make_geocube(
                df,
                measurements=measurements,
                like=block,
                fill=fill,
                rasterize_function=rasterize_function,
            )
            .to_array(measurements[0])
            .assign_coords(block.coords)
        )

    # rename the band dimension after the (single) measurement so the
    # output matches the template's dims
    like = like.rename(dict(zip(['band'], measurements)))
    return like.map_blocks(rasterize_block, template=like)