
Add a filter option to stack

Open · JulienBrn opened this issue 1 year ago · 1 comment

Is your feature request related to a problem?

I currently have a dataset where one of my dimensions (let's call it x) is of size 10^5. Later in my analysis, I want to consider pairs of values of that dimension but not all of them. Note that considering all of them would lead to 10^10 entries (not viable memory usage), when in practice I only want to consider around 10^6 of them.

Therefore, the final dataset should have a dimension x_pair which is the stacking of dimensions x_1 and x_2.

However, it seems I have no straightforward way of using `stack` for that purpose: whatever I do, it first materializes the full 10^10 array, which I can then filter using `where(drop=True)`.

Should this problem be unclear, I could provide a minimal example, but hopefully the explanation of the issue is enough (and my current code is provided as additional context).
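To make the problem concrete, here is a toy sketch (dimension of size 4 instead of 10^5; the names `x_1`, `x_2`, `x_pair` are illustrative):

```python
import numpy as np
import xarray as xr

# Toy dataset with one dimension x of size 4.
ds = xr.Dataset({"v": ("x", np.arange(4.0))}, coords={"x": np.arange(4)})

# Duplicate the dimension to build pairs, as in the real code below.
pairs = xr.merge([ds.rename(x="x_1", v="v_1"), ds.rename(x="x_2", v="v_2")])

# stack always materializes the full Cartesian product: 4 * 4 = 16 entries,
# even if only a handful of pairs are actually wanted.
stacked = pairs.stack(x_pair=("x_1", "x_2"))
assert stacked.sizes["x_pair"] == 16

# Filtering afterwards works, but only once the full product exists in memory.
kept = stacked.where(stacked["x_1"] < stacked["x_2"], drop=True)
assert kept.sizes["x_pair"] == 6
```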

Describe the solution you'd like

Add a filter parameter to `stack`. The filter function would take a dataset and return the set of elements that should appear in the final MultiIndex.
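A hypothetical usage of the proposed parameter might look like this (nothing below exists in xarray today; the signature is purely an illustration):

```
# Proposed (hypothetical) API: stack would only materialize index entries
# for which the filter callback returns True.
filtered = dataset.stack(
    x_pair=("x_1", "x_2"),
    filter=lambda ds: ds["x_1"] < ds["x_2"],  # keep ~1e6 of the 1e10 pairs
)
```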

Describe alternatives you've considered

Currently, I have solved my problem by dividing the dataset into many smaller datasets, stacking and filtering each of these datasets separately and then merging the filtered datasets together.

Note: the stacking time for all the smaller datasets, without any parallelization, still feels very long (almost 2 h). I do not know whether this is expected.
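Another workaround sketch, in case it is useful for comparison (toy sizes, hypothetical names): when the wanted pairs can be enumerated up front, pointwise (vectorized) indexing with `isel` builds only those pairs and never materializes the full product:

```python
import numpy as np
import xarray as xr

n = 4
ds = xr.Dataset({"v": ("x", np.arange(float(n)))}, coords={"x": np.arange(n)})
pairs = xr.merge([ds.rename(x="x_1", v="v_1"), ds.rename(x="x_2", v="v_2")])

# Enumerate only the wanted pairs up front (here: strictly ordered pairs).
wanted = [(i, j) for i in range(n) for j in range(n) if i < j]
i_idx = xr.DataArray([i for i, _ in wanted], dims="x_pair")
j_idx = xr.DataArray([j for _, j in wanted], dims="x_pair")

# Pointwise indexing selects exactly these pairs without ever building
# the full n * n product.
subset = pairs.isel(x_1=i_idx, x_2=j_idx)
assert subset.sizes["x_pair"] == len(wanted)

# Optionally turn the two pair coordinates into a proper MultiIndex,
# so the result looks like a stacked dimension.
subset = subset.set_index(x_pair=("x_1", "x_2"))
```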

Additional context

Currently, my code looks like the following. My dataset has three initial dimensions, Contact, sig_preprocessing, and f; both Contact and sig_preprocessing should be turned into pairs.

import numpy as np
import tqdm
import xarray as xr

# Duplicate every coordinate (except f) and data variable under _1/_2 suffixes.
signal_pairs = xr.merge([
    signals.rename(**{x: f"{x}_1" for x in signals.coords if x != "f"}, **{x: f"{x}_1" for x in signals.data_vars}),
    signals.rename(**{x: f"{x}_2" for x in signals.coords if x != "f"}, **{x: f"{x}_2" for x in signals.data_vars}),
])

def stack_dataset(dataset):
    dataset = dataset.copy()
    # Overlap length of the two recordings: min(end) - max(start).
    dataset["common_duration"] = xr.where(
        dataset["start_time_1"] > dataset["start_time_2"],
        xr.where(
            dataset["end_time_1"] > dataset["end_time_2"],
            dataset["end_time_2"] - dataset["start_time_1"],
            dataset["end_time_1"] - dataset["start_time_1"],
        ),
        xr.where(
            dataset["end_time_1"] > dataset["end_time_2"],
            dataset["end_time_2"] - dataset["start_time_2"],
            dataset["end_time_1"] - dataset["start_time_2"],
        ),
    )
    # Boolean mask of the pairs we actually want to keep.
    dataset["relevant_pair"] = (
        (dataset["Session_1"] == dataset["Session_2"])
        & (dataset["Contact_1"] != dataset["Contact_2"])
        & (dataset["Structure_1"] == dataset["Structure_2"])
        & (dataset["sig_type_1"] == "bua")
        & (dataset["sig_type_2"] == "spike_times")
        & (~dataset["resampled_continuous_path_1"].isnull())
        & (~dataset["resampled_continuous_path_2"].isnull())
        & (dataset["common_duration"] > 10)
    )
    # Stack (materializing the full product for this block), then drop stacked
    # entries where no combination along the other pair dimension is relevant.
    dataset = dataset.stack(
        sig_preprocessing_pair=("sig_preprocessing_1", "sig_preprocessing_2"),
        Contact_pair=("Contact_1", "Contact_2"),
    )
    dataset = dataset.where(dataset["relevant_pair"].any("sig_preprocessing_pair"), drop=True)
    dataset = dataset.where(dataset["relevant_pair"].any("Contact_pair"), drop=True)
    return dataset


# Split into 100x100 Contact blocks, stack and filter each block in parallel,
# then merge the filtered results back together.
stack_size = 100
signal_pairs_split = [
    signal_pairs.isel(
        Contact_1=slice(stack_size * i, stack_size * (i + 1)),
        Contact_2=slice(stack_size * j, stack_size * (j + 1)),
    )
    for i in range(int(np.ceil(signal_pairs.sizes["Contact_1"] / stack_size)))
    for j in range(int(np.ceil(signal_pairs.sizes["Contact_2"] / stack_size)))
]

import concurrent.futures

with concurrent.futures.ProcessPoolExecutor(max_workers=30) as executor:
    futures = [executor.submit(stack_dataset, dataset) for dataset in signal_pairs_split]
    signal_pairs_split_stacked = [
        future.result()
        for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Stacking")
    ]
signal_pairs = xr.merge(signal_pairs_split_stacked)

JulienBrn · Jan 31 '24 12:01

What happens if we filter out values with .where prior to stacking — does it still create huge arrays? (An MCVE would help here...)
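One reading of this suggestion, sketched on toy data (hypothetical names; `isel` is used here rather than `where` so that variables are not broadcast against the mask): pruning candidates per dimension before stacking shrinks the product, though it only helps when the filter factorizes per dimension — a condition coupling both dimensions still needs the full product.

```python
import numpy as np
import xarray as xr

n = 6
ds = xr.Dataset(
    {"v": ("x", np.arange(float(n))), "keep": ("x", np.arange(n) % 2 == 0)},
    coords={"x": np.arange(n)},
)
pairs = xr.merge([
    ds.rename(x="x_1", v="v_1", keep="keep_1"),
    ds.rename(x="x_2", v="v_2", keep="keep_2"),
])

# Drop per-dimension candidates *before* stacking: the product shrinks
# from 6 * 6 = 36 to 3 * 3 = 9 pairs.
pairs = pairs.isel(
    x_1=np.flatnonzero(pairs["keep_1"].values),
    x_2=np.flatnonzero(pairs["keep_2"].values),
)
stacked = pairs.stack(x_pair=("x_1", "x_2"))
assert stacked.sizes["x_pair"] == 9
```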

max-sixty · Jan 31 '24 18:01