Stream selected channels

Open robmarkcole opened this issue 1 year ago • 1 comment

🚀 Feature

Streaming subsets of channels

Motivation

My GeoTIFF data is typically multispectral, and I run experiments on subsets of the channels. I would like to stream only the required channels to save bandwidth.

Pitch

Allow selecting which channels to stream, e.g. channels 1, 3, and 5.

Alternatives

I can store each channel as a separate file and then access only those I require.
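A minimal sketch of this alternative, assuming each band is serialized on its own so that only the required ones need to be fetched and decoded. The helper names (`split_bands`, `load_bands`) and the `band_N` keys are hypothetical, and in-memory bytes stand in for separate files or objects; this is not an existing litdata API.

```python
import io

import numpy as np


def split_bands(image: np.ndarray) -> dict:
    """Serialize each band of a (bands, height, width) array to its own bytes blob."""
    blobs = {}
    for band_number, band in enumerate(image, start=1):  # 1-based, like Rasterio
        buffer = io.BytesIO()
        np.save(buffer, band)
        blobs[f"band_{band_number}"] = buffer.getvalue()
    return blobs


def load_bands(blobs: dict, band_numbers: list) -> np.ndarray:
    """Decode only the requested bands and stack them into (len(band_numbers), H, W)."""
    return np.stack([np.load(io.BytesIO(blobs[f"band_{n}"])) for n in band_numbers])


# Hypothetical 6-band image; in practice each blob would be a separate file or object.
image = np.arange(6 * 4 * 4, dtype=np.float32).reshape(6, 4, 4)
blobs = split_bands(image)
subset = load_bands(blobs, [1, 3, 5])
print(subset.shape)  # (3, 4, 4)
```

The trade-off is more files (or fields) per sample in exchange for never touching the bands that an experiment does not use.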

Additional context

The equivalent using Rasterio:

import rasterio

# URL to an S3 bucket raster file
s3_url = 's3://your-bucket-name/path-to-your-raster-file.tif'

# Open the raster file
with rasterio.open(s3_url) as src:
    # Read a specific band, for example, band 1
    band1 = src.read(1)  # Reading only the first band

    # You can also read multiple specific bands by passing a tuple
    band1, band3 = src.read((1, 3))  # Reading bands 1 and 3

    # Process or analyze the bands as needed
    print(band1, band3)

My current solution, which still streams every band and selects the subset only after decoding:

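from typing import Callable, List, Optional

import torch
from litdata import StreamingDataset
from rasterio.io import MemoryFile


class SegmentationStreamingDataset(StreamingDataset):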
    """
    Segmentation dataset with streaming.

    Args:
        input_dir (str): Local directory or S3 location of the dataset
        transforms (Optional[Callable]): A transform that takes in an image and returns a transformed version.
        band_indices (Optional[List[int]]): List of band indices to read from the dataset.
    """

    def __init__(self, *args, transforms: Optional[Callable] = None, band_indices: Optional[List[int]] = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.transforms = transforms
        self.band_indices = band_indices

    def __getitem__(self, index) -> dict:
        data = super().__getitem__(index)
        image_name = data["name"]
        image = data["image"]
        mask = data["mask"]

        with MemoryFile(image) as memfile:
            with memfile.open() as dataset:
                image = torch.from_numpy(dataset.read()).float()
                if self.band_indices:
                    image = image[self.band_indices]

        with MemoryFile(mask) as memfile:
            with memfile.open() as dataset:
                mask = torch.from_numpy(dataset.read()).long()

        sample = {"image": image, "mask": mask, "image_name": image_name}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
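
As a possible refinement of the class above, the band selection could be pushed into the Rasterio read call instead of indexing the decoded tensor. The sketch below assumes `dataset` is an open Rasterio dataset (for example from `memfile.open()`); `read_selected_bands` is a hypothetical helper, not part of litdata or Rasterio. Because the sample bytes have already been streamed into memory by that point, this may reduce decoding work but does not save bandwidth, so it does not replace the requested feature.

```python
from typing import List, Optional

import numpy as np


def read_selected_bands(dataset, band_indices: Optional[List[int]] = None) -> np.ndarray:
    """Read only the requested bands from an open Rasterio dataset.

    band_indices are 0-based, matching the tensor indexing in the class above.
    Rasterio band numbers are 1-based, so each index is shifted by one.
    """
    if not band_indices:
        # No selection given: read every band, shape (count, height, width).
        return dataset.read()
    # Pass 1-based band numbers so only these bands are read.
    return dataset.read(indexes=[i + 1 for i in band_indices])
```

Inside `__getitem__`, this would be used as `torch.from_numpy(read_selected_bands(dataset, self.band_indices)).float()`.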

robmarkcole, May 13 '24 12:05

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot], Apr 16 '25 06:04