xarray
xarray copied to clipboard
Parallelize map_over_subtree
Copied from https://github.com/xarray-contrib/datatree/issues/252
What is your issue?
I think there are some good opportunities to run `map_over_subtree` in parallel using `dask.delayed`.
Consider this example data:
import numpy as np
import xarray as xr
from datatree import DataTree
# Size of the synthetic benchmark tree: one group node per (file, group) pair.
number_of_files = 25
number_of_groups = 20
number_of_variables = 2000

# Every group holds an equal share of the total number of variables.
variables_per_group = number_of_variables // number_of_groups


def _build_group_dataset(file_index, group_index):
    """Return the synthetic dataset stored at ``file_<file_index>/group_<group_index>``.

    The data is deterministic: a linear ramp whose time span, length, slope and
    offset depend on the file and group indices, so nodes do not share a common
    time coordinate.
    """
    time_axis = np.linspace(0, 50 + file_index, 100 + group_index)
    signal = file_index * time_axis + group_index

    # All variables of a group share the same values along "time".
    data_vars = {}
    for variable_index in range(variables_per_group):
        data_vars[f"temperature_{group_index}{variable_index}"] = ("time", signal)

    # Append .chunk() to the returned dataset to benchmark dask-backed data.
    return xr.Dataset(data_vars=data_vars, coords={"time": ("time", time_axis)})


# Map each node path to its dataset, in file-major / group-minor order.
datasets = {
    f"file_{file_index}/group_{group_index}": _build_group_dataset(
        file_index, group_index
    )
    for file_index in range(number_of_files)
    for group_index in range(number_of_groups)
}
dt = DataTree.from_dict(datasets)

# %% Interpolate to same time coordinate
new_time = np.linspace(0, 150, 50)
dt_interp = dt.interp(time=new_time)
# Timings reported by the author:
# Original 10s, with dask.delayed 6s
# If datasets were chunked: Original 34s, with dask.delayed 10s
Here's my modified `map_over_subtree`:
def map_over_subtree(func: Callable, *, parallel: bool = True) -> Callable:
    """
    Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.

    Applies a function to every dataset in one or more subtrees, returning new trees which store the results.

    The function will be applied to any non-empty dataset stored in any of the nodes in the trees. The returned trees
    will have the same structure as the supplied trees.

    `func` needs to return one or more Datasets, DataArrays, or None in order to be able to rebuild the subtrees
    after mapping, as each result will be assigned to its respective node of a new tree. When `func` returns a
    tuple, each element of the tuple is stacked into a separate tree before returning all of them.

    The trees passed to the resulting function must all be isomorphic to one another. Their nodes need not be named
    similarly, but all the output trees will have nodes named in the same way as the first tree passed.

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
        Function will not be applied to any nodes without datasets.
    parallel : bool, default: True
        If True, every per-node call of ``func`` is wrapped in ``dask.delayed`` and all calls are evaluated
        together by a single ``dask.compute``; this requires dask to be installed. If False, ``func`` is called
        eagerly, node by node, and dask is never imported. The default keeps the behaviour this function had
        when the flag was hard-coded.

    Returns
    -------
    mapped : callable
        Wrapped function which returns one or more tree(s) created from results of applying ``func`` to the dataset
        at each node. It accepts:

        *args : tuple, optional
            Positional arguments passed on to `func`. Any DataTree argument is walked node by node and each node
            is converted to a Dataset via ``.to_dataset()``; other arguments are passed unchanged to every call.
        **kwargs : Any
            Keyword arguments passed on to `func`, with the same handling of DataTree values as for ``*args``.

        The wrapped function raises ``TypeError`` if none of its arguments is a DataTree.

    See also
    --------
    DataTree.map_over_subtree
    DataTree.map_over_subtree_inplace
    DataTree.subtree
    """

    # TODO examples in the docstring

    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    @functools.wraps(func)
    def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]:
        """Internal function which maps func over every node in tree, returning a tree of the results."""
        from .datatree import DataTree

        if parallel:
            # Imported lazily so that dask is only needed when parallel execution is requested.
            import dask

            func_ = dask.delayed(func)
        else:
            func_ = func

        all_tree_inputs = [a for a in args if isinstance(a, DataTree)] + [
            a for a in kwargs.values() if isinstance(a, DataTree)
        ]

        if not all_tree_inputs:
            raise TypeError("Must pass at least one tree object")
        first_tree, *other_trees = all_tree_inputs

        for other_tree in other_trees:
            # isomorphism is transitive so this is enough to guarantee all trees are mutually isomorphic
            check_isomorphic(
                first_tree,
                other_tree,
                require_names_equal=False,
                check_from_root=False,
            )

        # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
        # We don't know which arguments are DataTrees so we zip all arguments together as iterables
        # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
        out_data_objects = {}
        args_as_tree_length_iterables = [
            a.subtree if isinstance(a, DataTree) else repeat(a) for a in args
        ]
        n_args = len(args_as_tree_length_iterables)
        kwargs_as_tree_length_iterables = {
            k: v.subtree if isinstance(v, DataTree) else repeat(v)
            for k, v in kwargs.items()
        }
        kwarg_names = list(kwargs_as_tree_length_iterables)

        for node_of_first_tree, *all_node_args in zip(
            first_tree.subtree,
            *args_as_tree_length_iterables,
            *kwargs_as_tree_length_iterables.values(),
        ):
            # The first n_args entries belong to the positional arguments, the rest to the keyword arguments
            node_args_as_datasets = [
                a.to_dataset() if isinstance(a, DataTree) else a
                for a in all_node_args[:n_args]
            ]
            node_kwargs_as_datasets = {
                k: v.to_dataset() if isinstance(v, DataTree) else v
                for k, v in zip(kwarg_names, all_node_args[n_args:])
            }

            # Now we can call func on the data in this particular set of corresponding nodes
            # (when parallel, this only records a delayed call; nothing is computed yet)
            results = (
                func_(*node_args_as_datasets, **node_kwargs_as_datasets)
                if not node_of_first_tree.is_empty
                else None
            )

            # TODO implement mapping over multiple trees in-place using if conditions from here on?
            out_data_objects[node_of_first_tree.path] = results

        if parallel:
            # Evaluate every delayed call in one go; None placeholders for empty nodes pass through unchanged.
            # Only the values need computing - the path keys are plain strings.
            (computed_values,) = dask.compute(list(out_data_objects.values()))
            out_data_objects = dict(zip(out_data_objects, computed_values))

        # Find out how many return values we received
        num_return_values = _check_all_return_values(out_data_objects)

        # Reconstruct 1+ subtrees from the dict of results, by filling in all nodes of all result trees
        original_root_path = first_tree.path
        result_trees = []
        for i in range(num_return_values):
            out_tree_contents = {}
            for n in first_tree.subtree:
                p = n.path
                # Nodes without a recorded result map to None
                node_result = out_data_objects.get(p)
                if isinstance(node_result, tuple):
                    # func returned several objects: the i-th one belongs to the i-th output tree
                    output_node_data = node_result[i]
                else:
                    output_node_data = node_result

                # Discard parentage so that new trees don't include parents of input nodes
                relative_path = str(NodePath(p).relative_to(original_root_path))
                relative_path = "/" if relative_path == "." else relative_path
                out_tree_contents[relative_path] = output_node_data

            new_tree = DataTree.from_dict(
                out_tree_contents,
                name=first_tree.name,
            )
            result_trees.append(new_tree)

        # If only one result then don't wrap it in a tuple
        if len(result_trees) == 1:
            return result_trees[0]
        else:
            return tuple(result_trees)

    return _map_over_subtree
I'm a little unsure how to pass a `parallel` argument down to `map_over_subtree`, though.