datatree
API for filtering / subsetting
So far we've only really implemented dictionary-like get/setitem syntax, but we should add a variety of other ways to select nodes from a tree too. Here are some suggestions:
```python
from __future__ import annotations

from typing import Callable, Iterator, Sequence

from xarray import DataArray


class DataTree:
    ...

    def __getitem__(self, key: str) -> DataTree | DataArray:
        """
        Accepts node/variable names, or file-like paths to nodes/variables (including '../var').
        (Also needs to accommodate indexing somehow.)
        """
        ...

    def subset(self, keys: Sequence[str]) -> DataTree:
        """
        Return a new tree containing only nodes with names matching keys.
        (Could probably be combined with `__getitem__`.
        Also unsure what the return type should be.)
        """
        ...

    @property
    def subtree(self) -> Iterator[DataTree]:
        """An iterator over all nodes in this tree, including both self and all descendants."""
        ...

    def filter(self, filterfunc: Callable[[DataTree], bool]) -> Iterator[DataTree]:
        """Filter the subtree, returning only nodes for which `filterfunc(node)` is True."""
        ...
```
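As a rough illustration of the traversal semantics proposed for `subtree` and `filter`, here is a minimal sketch on a hypothetical `Node` stand-in class (not the actual `DataTree`). It assumes a pre-order depth-first traversal and that `filter` yields matches lazily; both are assumptions rather than decided behavior.

```python
from __future__ import annotations

from typing import Callable, Iterator


class Node:
    """Hypothetical stand-in for DataTree: a name, a dict of variables, and child nodes."""

    def __init__(self, name: str, variables: dict | None = None, children: list[Node] | None = None):
        self.name = name
        self.variables = variables or {}
        self.children = children or []

    @property
    def subtree(self) -> Iterator[Node]:
        # Pre-order depth-first traversal: yield self first, then every descendant.
        yield self
        for child in self.children:
            yield from child.subtree

    def filter(self, filterfunc: Callable[[Node], bool]) -> Iterator[Node]:
        # Lazily yield only the nodes in the subtree for which filterfunc(node) is True.
        return (node for node in self.subtree if filterfunc(node))


tree = Node("root", children=[
    Node("posterior", {"x": 1.0}),
    Node("prior", {"y": 2.0}, children=[Node("nested", {"x": 3.0})]),
])
print([node.name for node in tree.subtree])  # ['root', 'posterior', 'prior', 'nested']
print([node.name for node in tree.filter(lambda node: "x" in node.variables)])  # ['posterior', 'nested']
```

Because `filter` here returns a generator, nothing is evaluated until the caller iterates, which keeps it cheap on large trees.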
Are there other types of access that we're missing here? Filtering by regex match? Getting nodes where at least one part of the path matches ("tag-like" access)? Glob?
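The regex, glob and "tag-like" options could all be expressed as ordinary `filterfunc` predicates built from a node's name or path. The sketch below assumes a hypothetical `Node` class with a file-like `path` property (the root's path is `"/"`); the factory names `path_glob`, `name_regex` and `path_has_part` are made up for illustration. It relies only on the standard-library `fnmatch` and `re` modules.

```python
from __future__ import annotations

import fnmatch
import re
from typing import Callable, Iterator


class Node:
    """Hypothetical stand-in for a DataTree node that knows its parent and children."""

    def __init__(self, name: str, parent: Node | None = None):
        self.name = name
        self.parent = parent
        self.children: list[Node] = []
        if parent is not None:
            parent.children.append(self)

    @property
    def path(self) -> str:
        # File-like absolute path: "/" for the root, "/a/b" for descendants.
        if self.parent is None:
            return "/"
        return self.parent.path.rstrip("/") + "/" + self.name

    @property
    def subtree(self) -> Iterator[Node]:
        yield self
        for child in self.children:
            yield from child.subtree


# Hypothetical filterfunc factories; each returns a predicate usable as filterfunc(node).
def path_glob(pattern: str) -> Callable[[Node], bool]:
    # Note: fnmatch's "*" also matches "/", so "/posterior*" matches nested paths too.
    return lambda node: fnmatch.fnmatchcase(node.path, pattern)


def name_regex(pattern: str) -> Callable[[Node], bool]:
    compiled = re.compile(pattern)
    return lambda node: compiled.fullmatch(node.name) is not None


def path_has_part(part: str) -> Callable[[Node], bool]:
    # "Tag-like" access: true if any component of the node's path equals `part`.
    return lambda node: part in node.path.strip("/").split("/")


root = Node("root")
posterior = Node("posterior", root)
Node("diagnostics", posterior)
prior = Node("prior", root)
Node("diagnostics", prior)
Node("posterior_predictive", root)


def select(func: Callable[[Node], bool]) -> list[str]:
    return [node.path for node in root.subtree if func(node)]


print(select(path_glob("/posterior*")))
# ['/posterior', '/posterior/diagnostics', '/posterior_predictive']
print(select(name_regex(r"prior|posterior")))
# ['/posterior', '/prior']
print(select(path_has_part("diagnostics")))
# ['/posterior/diagnostics', '/prior/diagnostics']
```

Framing these as predicate factories means a single `filter(filterfunc)` method could cover all three styles without separate regex/glob/tag methods.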
@oriolabril, do you think these types of functions would be sufficient for ArviZ's use cases? From https://github.com/arviz-devs/arviz/issues/2015#issuecomment-1106957255:
> `dt[["posterior", "posterior_predictive"]]` is not possible
>
> getting a subset of the datatree that consists of multiple groups
This is what I'm suggesting `subset` should do, or `__getitem__`.
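A sketch of what list-based selection in `__getitem__` (or `subset`) might look like, using a hypothetical `Node` class whose children are keyed by name. Returning a new tree of deep copies is just one possible answer to the open return-type question, not a settled design; relative paths such as `'../var'` are not handled here.

```python
from __future__ import annotations

import copy


class Node:
    """Hypothetical stand-in for DataTree with children keyed by name."""

    def __init__(self, name: str, children: list[Node] | None = None):
        self.name = name
        self.children = {child.name: child for child in children or []}

    def __getitem__(self, key: str | list[str]) -> Node:
        if isinstance(key, list):
            # List of names: return a *new* tree whose root holds copies of just those groups.
            return Node(self.name, [copy.deepcopy(self.children[name]) for name in key])
        # String: walk a "/"-separated relative path such as "posterior/diagnostics".
        node = self
        for part in key.strip("/").split("/"):
            node = node.children[part]
        return node


dt = Node("root", [Node("posterior"), Node("posterior_predictive"), Node("observed_data")])
subset = dt[["posterior", "posterior_predictive"]]
print(list(subset.children))  # ['posterior', 'posterior_predictive']
print(list(dt.children))  # ['posterior', 'posterior_predictive', 'observed_data']
```

Copying means that modifying `subset` cannot affect `dt`; returning views instead would be cheaper but would share state with the original tree.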
> applying a function to the variable `x` that is present in 3 out of 5 groups of the datatree.
I'm imagining enabling that via `dt.filter(lambda node: 'x' in node.variables).map_over_subtree(func)`.
Or we could potentially add an optional `filterfunc` argument to `map_over_subtree`.
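The optional-`filterfunc` idea could look roughly like the following, where a tree is modeled (hypothetically) as a flat mapping from node path to that node's variables. The `map_over_subtree(func, tree, filterfunc=None)` signature here is an assumption for illustration, not datatree's actual API.

```python
from typing import Callable, Dict, Optional

# Hypothetical flat representation of a tree: node path -> {variable name: value}.
Variables = Dict[str, float]
Tree = Dict[str, Variables]


def map_over_subtree(
    func: Callable[[Variables], Variables],
    tree: Tree,
    filterfunc: Optional[Callable[[Variables], bool]] = None,
) -> Tree:
    """Hypothetical signature: apply func to each node's variables where filterfunc passes."""
    # Non-matching nodes are copied through unchanged, so the result has
    # exactly the same paths as the input and is still a complete tree.
    return {
        path: func(variables) if filterfunc is None or filterfunc(variables) else dict(variables)
        for path, variables in tree.items()
    }


tree = {
    "/posterior": {"x": 1.0, "y": 2.0},
    "/prior": {"x": 3.0},
    "/observed_data": {"y": 4.0},
}
doubled = map_over_subtree(lambda v: {**v, "x": v["x"] * 2}, tree, filterfunc=lambda v: "x" in v)
print(doubled)
# {'/posterior': {'x': 2.0, 'y': 2.0}, '/prior': {'x': 6.0}, '/observed_data': {'y': 4.0}}
```

Filtering inside the map, rather than filtering first, sidesteps the question of how to reassemble a tree from a subset of its nodes.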
One possible point of (approximate) alignment with the Xarray API is this issue on selecting using an iterable of variable names: https://github.com/pydata/xarray/issues/3894. This seems analogous to selecting nodes using `subset`.
I had not seen that issue, thanks @dcherian.
I think that would cover everything, but I'll try to think of examples so that we can also have things to test on.
We could also provide functions in datatree/xarray/arviz to act as `filterfunc` for common cases. My main question when thinking about using `filter` is how to store the results back. I guess a merge would do it? Maybe with some renaming happening in the process. It will probably be best to discuss this with some examples.
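One way a merge could store filtered results back is sketched below, again modeling a tree as a hypothetical flat mapping of path to variables. `has_variable` is an invented example of a common-case `filterfunc` helper. Because the filtered nodes keep their original paths, their outputs can be merged back into the matching nodes by path; renaming is not handled.

```python
from typing import Callable, Dict

# Hypothetical flat representation of a tree: node path -> {variable name: value}.
Tree = Dict[str, Dict[str, float]]


def has_variable(name: str) -> Callable[[Dict[str, float]], bool]:
    """Hypothetical common-case filterfunc factory: match nodes containing `name`."""
    return lambda variables: name in variables


tree: Tree = {
    "/posterior": {"x": 1.0},
    "/prior": {"x": 3.0, "y": 0.5},
    "/observed_data": {"y": 4.0},
}

# 1. Filter: only the matching nodes, still keyed by their original paths.
selected = {path: v for path, v in tree.items() if has_variable("x")(v)}

# 2. Compute new results for the selected nodes only.
processed = {path: {"x_squared": v["x"] ** 2} for path, v in selected.items()}

# 3. Merge back by path: each node keeps its variables and gains any new ones.
merged = {path: {**v, **processed.get(path, {})} for path, v in tree.items()}
print(merged)
# {'/posterior': {'x': 1.0, 'x_squared': 1.0}, '/prior': {'x': 3.0, 'y': 0.5, 'x_squared': 9.0}, '/observed_data': {'y': 4.0}}
```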
> My main question when thinking about using `filter` is storing the results back.
Yes, that's the tricky bit: if you want to return a tree, then you might need to retain nodes for which `filterfunc(node)` is False in order to still have a valid tree structure afterwards...
For example:
```python
from datatree import DataTree


def name_is_lowercase(node):
    return node.name == node.name.lower()


root = DataTree(name="a")
child = DataTree(name="B", parent=root)
grandchild = DataTree(name="c", parent=child)

root.filter(name_is_lowercase)
```
This would return nodes "a" and "c", but it couldn't automatically reconstruct them into a tree without also preserving node "B".
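One possible way for a tree-returning `filter` to handle this example is to keep every matching node plus all of its ancestors, so that "B" survives as a connecting node even though `name_is_lowercase` is False for it. The sketch uses a hypothetical `Node` class and a made-up `filter_to_tree` function; an extra node "D" with no matching descendants is added to show pruning.

```python
from __future__ import annotations

from typing import Callable, Iterator, Optional


class Node:
    """Hypothetical stand-in for DataTree (names and parent/child links only)."""

    def __init__(self, name: str, parent: Optional[Node] = None):
        self.name = name
        self.parent = parent
        self.children: list[Node] = []
        if parent is not None:
            parent.children.append(self)

    @property
    def subtree(self) -> Iterator[Node]:
        yield self
        for child in self.children:
            yield from child.subtree

    @property
    def ancestors(self) -> Iterator[Node]:
        node = self.parent
        while node is not None:
            yield node
            node = node.parent


def filter_to_tree(root: Node, filterfunc: Callable[[Node], bool]) -> Optional[Node]:
    """Return a new tree of matching nodes plus any ancestors needed to connect them."""
    keep: set[Node] = set()
    for node in root.subtree:
        if filterfunc(node):
            keep.add(node)
            keep.update(node.ancestors)  # e.g. retain "B" so that "c" stays attached to "a"

    def rebuild(node: Node, new_parent: Optional[Node]) -> Node:
        # Copy the node, then recurse only into children that were kept.
        new_node = Node(node.name, new_parent)
        for child in node.children:
            if child in keep:
                rebuild(child, new_node)
        return new_node

    # If nothing matched, there is no valid tree to return.
    return rebuild(root, None) if root in keep else None


def name_is_lowercase(node: Node) -> bool:
    return node.name == node.name.lower()


root = Node("a")
child = Node("B", parent=root)
grandchild = Node("c", parent=child)
Node("D", parent=root)  # no lowercase descendants, so it is pruned

result = filter_to_tree(root, name_is_lowercase)
print([node.name for node in result.subtree])  # ['a', 'B', 'c']
```

A drawback of this approach is that the returned tree does not record which nodes actually matched and which were retained only to keep the tree connected.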
If `.filter` just returned an iterator of nodes, then you wouldn't need to be able to rebuild a tree, but this might not be the most convenient option for the user. This is why I would like to build these functions with some desired usage patterns in mind.