sgkit
Add documentation on how to persist computations in a workflow
Because our library may define many precursor variables for any one calculation, it will become crucial for users to be able to persist/cache some of those variables. This means that instead of letting the variable be defined automatically as a dask array, they would have to define it first themselves, call .compute() or .persist(), and then call the original function with that precursor variable already defined. There should be a section in the docs on this.
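For example, a minimal sketch of that pattern (hedged: it assumes count_cohort_alleles can reuse a pre-computed call_allele_count, and ds = ... stands in for a dask-backed dataset):

import sgkit as sg

ds = ...  # original dataset (dask-backed)
# define the precursor variable explicitly rather than letting the
# downstream call create it implicitly
ds = sg.count_call_alleles(ds)
# persist just that variable in memory so it is only computed once
ds["call_allele_count"] = ds["call_allele_count"].persist()
# the downstream call now reuses the persisted precursor
ds = sg.count_cohort_alleles(ds)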
One pattern that may be useful is to compute new variables in a new dataset (via merge=False), save only the new variables, then merge the datasets. This avoids having to re-persist the whole dataset to disk - only the new arrays are written.
import xarray as xr

ds = ...  # original dataset
cac = sg.count_cohort_alleles(ds, merge=False)  # dataset containing only the new variable
cac.to_zarr(...)  # compute and persist to disk
# later
ds = xr.open_zarr(...)   # reload the original dataset from disk
cac = xr.open_zarr(...)  # reload the new variables from disk
ds = xr.merge([ds, cac])  # merge into one dataset
I think this could be used in conjunction with #298.
I've been trying out a checkpoint function that I think is a better way of doing what I was trying to achieve in my previous comment.
from dask.diagnostics import ProgressBar

ds = ...  # original dataset
ds = sg.count_cohort_alleles(ds)  # new variable
with ProgressBar():
    # compute and persist to disk
    ds = sg.checkpoint_dataset(ds, path, [sg.variables.cohort_allele_count])
This works by using Xarray's ability to append a subset of variables to an existing Zarr store, which avoids re-writing existing variables and keeps the whole dataset together with no merging required. It's also possible to use the same function to overwrite existing variables. (The load_dataset and save_dataset functions are from #298/#392.)
def checkpoint_dataset(
    ds: Dataset,
    path: PathType,
    data_vars: Optional[Sequence[Hashable]] = None,
) -> Dataset:
    if data_vars is not None:
        # keep only the variables being checkpointed
        ds = ds.drop_vars(set(ds.data_vars) - set(data_vars))
    # append the new variables to the existing Zarr store
    save_dataset(ds, path, mode="a")
    # reopen so the returned variables are backed by the store rather than the task graph
    return load_dataset(path)
I used a conditional form of this pattern in a notebook, so that it can be rerun without having to recompute variables:
ds = ...  # original dataset
if sg.variables.cohort_allele_count not in ds:
    ds = sg.count_cohort_alleles(ds)  # new variable
    with ProgressBar():
        # compute and persist to disk
        ds = sg.checkpoint_dataset(ds, path, [sg.variables.cohort_allele_count])
Some more ideas from @eric-czech and @alimanfoo:
- Add an option for checkpoint to force an overwrite (Hail does this)
- Rather than checking for the existence of the variable in the dataset, check for a flag that shows the computation completed successfully.
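Here is a rough sketch combining both of those ideas (hedged: the overwrite argument, the checkpoint_complete attribute, and the guard logic are all hypothetical, not part of sgkit):

def checkpoint_dataset(
    ds: Dataset,
    path: PathType,
    data_vars: Optional[Sequence[Hashable]] = None,
    overwrite: bool = False,
) -> Dataset:
    if data_vars is not None:
        ds = ds.drop_vars(set(ds.data_vars) - set(data_vars))
    if not overwrite:
        # hypothetical Hail-style guard: refuse to silently replace variables
        # that are already in the store (assumes the store already exists)
        already_saved = set(load_dataset(path).data_vars) & set(ds.data_vars)
        if already_saved:
            raise ValueError(
                f"{already_saved} already checkpointed; pass overwrite=True to replace"
            )
    # hypothetical per-variable flag so a rerun can check that the computation
    # completed successfully, rather than just that the variable exists
    for var in ds.data_vars:
        ds[var].attrs["checkpoint_complete"] = True
    save_dataset(ds, path, mode="a")
    return load_dataset(path)

A rerun would then test ds[var].attrs.get("checkpoint_complete") instead of var in ds.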
What about an update_dataset that automatically detects the new variables? Something like this (untested):
def update_dataset(
    ds: Dataset,
    path: PathType,
) -> Dataset:
    saved_ds = load_dataset(path)
    # drop the variables that are already in the store, leaving only the new ones
    ds = ds.drop_vars(set(ds.data_vars) & set(saved_ds.data_vars))
    save_dataset(ds, path, mode="a")
    return load_dataset(path)
Or, maybe better, add a keyword arg update=False to save_dataset that does this?
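A rough sketch of what that could look like (hedged: this is not the actual save_dataset from #298/#392; it assumes the function ultimately calls Dataset.to_zarr, and the update argument is hypothetical):

def save_dataset(
    ds: Dataset,
    store: PathType,
    update: bool = False,
    **kwargs: Any,
) -> None:
    if update:
        # drop variables that are already in the store so only new ones are appended
        saved_ds = load_dataset(store)
        ds = ds.drop_vars(set(ds.data_vars) & set(saved_ds.data_vars))
        kwargs.setdefault("mode", "a")
    ds.to_zarr(store, **kwargs)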
I find "checkpoint" slightly odd because it implies that you're part-way through a calculation, but really you're saving complete variables. "append" feels like the wrong idea, because that implies to me that you're adding more data to some dimension.