zarrdataset
Combined usage of `MasksDatasetSpecs` and `LabelsDatasetSpecs` slow?
Hi @fercer, me again ^^"
I've noticed that if I add a `MasksDatasetSpecs` to my dataset, retrieving the patches becomes awfully slow.
This is how I am building the dataset:
```python
import zarrdataset as zds

# `filenames` (list of zarr paths) and `patch_sampler` are defined elsewhere
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="0",
    source_axes="CZYX",
    padding_mode='reflect'
)

labels_specs = zds.LabelsDatasetSpecs(
    filenames=filenames,
    data_group="labels/annotations/0",
    source_axes="ZYX",
    padding_mode='reflect',
)

masks_specs = zds.MasksDatasetSpecs(
    filenames=filenames,
    data_group="masks/annotations/4",
    source_axes="ZYX",
    padding_mode='reflect',
)

ds = zds.ZarrDataset([image_specs, labels_specs, masks_specs],
                     patch_sampler=patch_sampler,
                     shuffle=True, return_positions=False,
                     return_worker_id=False)
```
where `masks_specs` is a pre-computed segmentation of the data used to avoid drawing patches from dark regions too often, which I suppose is what `MasksDatasetSpecs` is designed for. I noticed that adding the mask spec to the dataset makes drawing the patches notably slower: without the mask spec it takes <10 s for 460 tiles, with the mask spec almost 2 minutes. I wonder where this comes from?
I don't know the inner workings very well and I may just be using the mask spec the wrong way; in that case a pointer in the right direction would be much appreciated :)
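To compare the two setups reproducibly, a small timing helper like the one below could be used. It is a sketch that assumes `ds` can be iterated directly (as with a PyTorch `IterableDataset`); `time_first_n` is a hypothetical name, not part of zarrdataset.

```python
import itertools
import time


def time_first_n(iterable, n):
    """Draw up to `n` items from `iterable`; return (items drawn, seconds)."""
    start = time.perf_counter()
    # islice stops after n items, so an endless or shuffled stream is fine too
    count = sum(1 for _ in itertools.islice(iterable, n))
    return count, time.perf_counter() - start


# Stand-in iterable for the demo; with zarrdataset this would be `ds`
count, seconds = time_first_n(range(1000), 460)
print(f"{count} patches in {seconds:.2f} s")
```

Calling `time_first_n(ds, 460)` once with and once without `masks_specs` in the specs list would give the two numbers being compared.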
Hi again @jo-mueller and thanks for raising this issue!
The `MasksDatasetSpecs` are meant to be at a lower resolution than your inputs and labels.
I don't think that is stated anywhere in the documentation yet, so it is very easy to overlook.
Internally, the `PatchSampler` uses all non-zero coordinates in the mask array to determine the samplable positions.
But if the mask is close to the resolution of the input image, that operation becomes computationally expensive.
Could it be that your mask is at the same resolution as your input image and labels (1 image pixel = 1 mask pixel)?
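A toy NumPy sketch of why the mask resolution matters: collecting all non-zero coordinates means scanning every mask voxel, so the cost grows with the mask size. `samplable_positions` is a hypothetical helper, not zarrdataset's implementation, and the shapes (a mask downsampled 16x in Y and X only, with Z kept) are assumptions for illustration.

```python
import numpy as np


def samplable_positions(mask, image_shape):
    """Hypothetical toy: map each non-zero mask voxel to the corner of the
    image region it covers. Not zarrdataset's actual implementation."""
    # Integer scale factor between image and mask along each axis
    scale = np.array(image_shape) // np.array(mask.shape)
    # np.nonzero scans every element, so its cost grows with mask.size
    coords = np.stack(np.nonzero(mask), axis=1)
    return coords * scale


image_shape = (16, 256, 256)   # ZYX, full resolution
mask_shape = (16, 16, 16)      # assumed: 16x downsampled in Y and X only

mask = np.ones(mask_shape, dtype=bool)
positions = samplable_positions(mask, image_shape)

print("voxels scanned at full resolution:", int(np.prod(image_shape)))  # 1048576
print("voxels scanned in low-res mask:", mask.size)                     # 4096
print("samplable positions:", len(positions))                           # 4096
```

Note that if Z is not downsampled, a 3D mask stays comparatively large even at a coarse pyramid level.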
Hi @fercer,
thanks for the quick reply. It is stated for `WSITissueMaskGenerator` that, because the input image (zarr group “1”) is large, computing the mask directly on it could require high computational resources. I wasn't sure how that translates to the `PatchSampler`, though.
In my case above, I thought I was already taking the 4th pyramid level rather than the raw resolution:
```python
masks_specs = zds.MasksDatasetSpecs(
    filenames=filenames,
    data_group="masks/annotations/4",  # <---
    source_axes="ZYX",
    padding_mode='reflect',
)
```
or am I maybe specifying it incorrectly?
Yes @jo-mueller , you are actually specifying the data group for the mask correctly.
I have an idea of what could cause the sampling time to increase when a mask is specified.
Basically, the `PatchSampler` loads the full mask (`mask[:]`) on every iteration to determine the samplable patches.
That is the reason why smaller masks were preferred.
But because I only tested with 2D images, I didn't consider more extensive tests with 3D data, where that operation is not very efficient even with a relatively small mask.
I'll fix the way the mask is used when computing the samplable patches to reduce the amount of data loaded from the mask zarr.
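One possible direction for reducing the data loaded from the mask, sketched under assumptions: instead of `mask[:]`, read only the slab of the low-resolution mask that overlaps the image chunk currently being sampled. `mask_region_for_chunk` is a hypothetical helper, not the planned fix; it assumes the image shape is an integer multiple of the mask shape along each axis. With a `zarr.Array`, basic slicing like this reads only the chunks it touches.

```python
import numpy as np


def mask_region_for_chunk(mask, chunk_slices, image_shape):
    """Hypothetical helper: return only the part of a low-resolution mask
    that overlaps one image chunk, instead of loading the whole mask."""
    # Integer downsampling factor between image and mask along each axis
    scale = [i // m for i, m in zip(image_shape, mask.shape)]
    mask_slices = tuple(
        slice(s.start // f, -(-s.stop // f))  # floor the start, ceil the stop
        for s, f in zip(chunk_slices, scale)
    )
    return mask[mask_slices]


image_shape = (16, 256, 256)  # ZYX
mask = np.arange(16 * 16 * 16).reshape(16, 16, 16)  # stand-in for a zarr mask
chunk = (slice(0, 8), slice(0, 128), slice(128, 256))

region = mask_region_for_chunk(mask, chunk, image_shape)
print(region.shape)  # (8, 8, 8)
```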
In the meantime, you could try using the argument `draw_same_chunk=True` to reduce the number of calls to the patch computation method of the `PatchSampler`.
```python
ds = zds.ZarrDataset([image_specs, labels_specs, masks_specs],
                     patch_sampler=patch_sampler,
                     shuffle=True, return_positions=False,
                     draw_same_chunk=True,  # <----
                     return_worker_id=False)
```
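A toy model of why `draw_same_chunk=True` can help: if each computation of samplable positions reads the mask, reusing one computation for several patches from the same chunk cuts the number of mask reads. This is a hypothetical illustration of the idea described above, not zarrdataset's internals; the "one computation per patch" case without the flag and the `patches_per_chunk` value are assumptions.

```python
def count_position_computations(n_patches, patches_per_chunk, draw_same_chunk):
    """Toy model: count how often the expensive 'compute samplable
    positions' step runs while drawing n_patches patches."""
    calls = 0
    drawn = 0
    while drawn < n_patches:
        calls += 1  # positions computed for the current chunk (reads the mask)
        if draw_same_chunk:
            # reuse the computed positions for several patches of this chunk
            drawn += min(patches_per_chunk, n_patches - drawn)
        else:
            # assumed worst case: one patch per computation
            drawn += 1
    return calls


print(count_position_computations(460, 16, draw_same_chunk=False))  # 460
print(count_position_computations(460, 16, draw_same_chunk=True))   # 29
```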
Please let me know if you try this approach and notice any reduction of the sampling time.