xarray
Add a filter option to stack
Is your feature request related to a problem?
I currently have a dataset where one of my dimensions (let's call it x) is of size 10^5. Later in my analysis, I want to consider pairs of values of that dimension but not all of them. Note that considering all of them would lead to 10^10 entries (not viable memory usage), when in practice I only want to consider around 10^6 of them.
Therefore, the final dataset should have a dimension x_pair which is the stacking of dimensions x_1 and x_2.
However, it seems I have no straightforward way of using stack for that purpose: whatever I do, it first creates the full product array, which I can then only filter after the fact using where(drop=True).
Should this problem be unclear, I could provide a minimal example, but hopefully the explanation of the issue is enough (and my current code is provided as additional context).
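Since a minimal example was offered, here is a small sketch of the issue (tiny sizes stand in for the real 10^5): stacking two renamed copies of a dimension always materializes the full Cartesian product, even when only a handful of pairs are wanted.

```python
import numpy as np
import xarray as xr

n = 4  # stands in for ~1e5 in the real dataset
ds = xr.Dataset({"v": ("x", np.arange(n))})

# Duplicate the dimension as x_1 / x_2, as in the code further below.
pairs = xr.merge([ds.rename(x="x_1", v="v_1"),
                  ds.rename(x="x_2", v="v_2")])

# stack always builds the full product: n * n entries.
stacked = pairs.stack(x_pair=("x_1", "x_2"))
print(stacked.sizes["x_pair"])  # 16, i.e. quadratic in n
```

With n = 10^5 this product has 10^10 entries, which is exactly the memory problem described above.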
Describe the solution you'd like
Add a filter parameter to stack. The filter function would take the dataset and return the set of elements that should appear in the final MultiIndex, so that the full product is never materialized.
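If stack accepted such a parameter, usage might look like the following pseudocode. This is purely hypothetical: filter is the proposed parameter, not an existing xarray API.

```
# Hypothetical usage of the proposed parameter (not an existing API):
stacked = dataset.stack(
    Contact_pair=("Contact_1", "Contact_2"),
    filter=lambda ds: ds["relevant_pair"],  # which entries to keep in the MultiIndex
)
```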
Describe alternatives you've considered
Currently, I have solved my problem by dividing the dataset into many smaller datasets, stacking and filtering each of these datasets separately and then merging the filtered datasets together.
Note: even so, the total stacking time over all the smaller datasets, without any parallelization, still feels very long (almost 2h). I do not know whether this is expected.
Additional context
Currently, my code looks like the following. My dataset has three initial dimensions: Contact, sig_preprocessing, and f. Both Contact and sig_preprocessing should be transformed into pairs.
import numpy as np
import tqdm
import xarray as xr

signal_pairs = xr.merge([
    signals.rename(**{x: f"{x}_1" for x in signals.coords if x != "f"},
                   **{x: f"{x}_1" for x in signals.data_vars}),
    signals.rename(**{x: f"{x}_2" for x in signals.coords if x != "f"},
                   **{x: f"{x}_2" for x in signals.data_vars}),
])
def stack_dataset(dataset):
    dataset = dataset.copy()
    # Overlap duration of the two signals: min(end) - max(start)
    dataset["common_duration"] = xr.where(
        dataset["start_time_1"] > dataset["start_time_2"],
        xr.where(
            dataset["end_time_1"] > dataset["end_time_2"],
            dataset["end_time_2"] - dataset["start_time_1"],
            dataset["end_time_1"] - dataset["start_time_1"],
        ),
        xr.where(
            dataset["end_time_1"] > dataset["end_time_2"],
            dataset["end_time_2"] - dataset["start_time_2"],
            dataset["end_time_1"] - dataset["start_time_2"],
        ),
    )
    dataset["relevant_pair"] = (
        (dataset["Session_1"] == dataset["Session_2"])
        & (dataset["Contact_1"] != dataset["Contact_2"])
        & (dataset["Structure_1"] == dataset["Structure_2"])
        & (dataset["sig_type_1"] == "bua")
        & (dataset["sig_type_2"] == "spike_times")
        & (~dataset["resampled_continuous_path_1"].isnull())
        & (~dataset["resampled_continuous_path_2"].isnull())
        & (dataset["common_duration"] > 10)
    )
    dataset = dataset.stack(
        sig_preprocessing_pair=("sig_preprocessing_1", "sig_preprocessing_2"),
        Contact_pair=("Contact_1", "Contact_2"),
    )
    dataset = dataset.where(dataset["relevant_pair"].any("sig_preprocessing_pair"), drop=True)
    dataset = dataset.where(dataset["relevant_pair"].any("Contact_pair"), drop=True)
    return dataset
stack_size = 100
signal_pairs_split = [
    signal_pairs.isel(dict(
        Contact_1=slice(stack_size * i, stack_size * (i + 1)),
        Contact_2=slice(stack_size * j, stack_size * (j + 1)),
    ))
    for i in range(int(np.ceil(signal_pairs.sizes["Contact_1"] / stack_size)))
    for j in range(int(np.ceil(signal_pairs.sizes["Contact_2"] / stack_size)))
]

import concurrent.futures
with concurrent.futures.ProcessPoolExecutor(max_workers=30) as executor:
    futures = [executor.submit(stack_dataset, dataset) for dataset in signal_pairs_split]
    signal_pairs_split_stacked = [
        future.result()
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures),
                                total=len(futures), desc="Stacking")
    ]
signal_pairs = xr.merge(signal_pairs_split_stacked)
What happens if we filter out values with .where prior to stacking — does it still create huge arrays? (An MCVE would help here...)