Contrastive Learning Implementation
My understanding of the current classifier and task with 60X res data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.
Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:
- We can apply transformations like rotation, cropping, and color jitter to create different views of the same cell.
- Generate an anchor and a positive sample using different augmentations of the same cell image.
- Select a different cell with a different label (infected vs. uninfected) as the negative sample.
The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.
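For reference, the standard triplet margin loss this sampling feeds (the same objective as torch.nn.TripletMarginLoss used further below) is L(a, p, n) = max(d(a, p) - d(a, n) + margin, 0), where d is the distance between embeddings; the loss reaches zero once the negative is at least margin farther from the anchor than the positive is.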
# TripletTransform takes a base_transform and applies it to a sample to generate anchor and positive samples.
# When __call__ is invoked with a sample, it applies the base_transform twice: once to create the anchor and once to create the positive.
class TripletTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        # two independent stochastic applications yield two different views of the same cell
        anchor = self.transform(sample)
        positive = self.transform(sample)
        return anchor, positive
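For example (assuming a stochastic base transform like the torchvision pipeline defined further below), each call yields two different views of the same cell:

triplet_transform = TripletTransform(base_transform)
anchor_view, positive_view = triplet_transform(sample)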
# The TripletDataset class is initialized with the dataset and a transform function. When __getitem__ is called with an index (idx):
# Anchor and Positive: the same data sample is retrieved for both the anchor and the positive.
# Negative Sampling: a different sample is randomly selected as the negative.
# If a transform is provided:
# - The TripletTransform applies the base_transform to the sample twice, creating augmented anchor and positive views.
# - The base_transform is applied once to the negative sample to create its augmented version (if wanted).
import random

from torch.utils.data import Dataset

class TripletDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # simple negative sampling: any other index; a label-aware version would
        # instead pick a cell with the opposite label (infected vs. uninfected)
        negative_idx = random.choice([i for i in range(len(self.data)) if i != idx])
        negative = self.data[negative_idx]
        if self.transform:
            # self.transform is a TripletTransform: one call returns two augmented views
            anchor, positive = self.transform(sample)
            # apply the underlying base_transform once to the negative
            negative = self.transform.transform(negative)
        else:
            anchor, positive = sample, sample
        return (anchor, positive, negative)
Here the TripletTransform class takes the base transformation (base_transform, defined below) and applies it to create the anchor and positive samples.
Modify HCSDataModule:
class TripletHCSDataModule(HCSDataModule):
    def __init__(
        self,
        data_path: str,
        source_channel: Union[str, Sequence[str]],
        target_channel: Union[str, Sequence[str]],
        z_window_size: int,
        split_ratio: float = 0.8,
        batch_size: int = 16,
        num_workers: int = 8,
        architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
        yx_patch_size: tuple[int, int] = (256, 256),
        normalizations: list[MapTransform] = [],
        augmentations: list[MapTransform] = [],
        caching: bool = False,
        ground_truth_masks: Optional[Path] = None,
    ):
        super().__init__(
            data_path,
            source_channel,
            target_channel,
            z_window_size,
            split_ratio,
            batch_size,
            num_workers,
            architecture,
            yx_patch_size,
            normalizations,
            augmentations,
            caching,
            ground_truth_masks,
        )
        # MONAI's Compose, since normalizations and augmentations are MapTransforms
        self.triplet_transform = TripletTransform(Compose(normalizations + augmentations))

    # update setup to use TripletDataset
    def setup(self, stage: Optional[str] = None):
        super().setup(stage)
        if stage in ("fit", "validate"):
            # assumes the parent datasets expose their underlying data
            self.train_dataset = TripletDataset(self.train_dataset.data, transform=self.triplet_transform)
            self.val_dataset = TripletDataset(self.val_dataset.data, transform=self.triplet_transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // 3,  # each item yields a triplet, so shrink the batch to keep memory comparable
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=bool(self.num_workers),
            prefetch_factor=4 if self.num_workers else None,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=False,
            prefetch_factor=4 if self.num_workers else None,
            persistent_workers=bool(self.num_workers),
        )
# example of what could be included in the augmentations list
base_transform = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2, 0.2)], p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
])
Using this updated dataloader:
data_module = TripletHCSDataModule(
    data_path="...",
    source_channel=["Phase", "Sensor"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.8,
    z_window_size=1,
    architecture="2D",
    num_workers=4,
    batch_size=64,
    normalizations=[
        NormalizeSampled(
            keys=["Sensor", "Phase"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=8,
            spatial_size=[-1, 128, 128],
            keys=["Sensor", "Phase", "Inf_mask"],
            w_key="Inf_mask",
        )
    ],
)
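A quick, hypothetical smoke test of the module (assuming the TripletTransform/TripletDataset sketches above) to confirm the loader yields triplets:

data_module.setup("fit")
anchor, positive, negative = next(iter(data_module.train_dataloader()))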
Model details:
- Use an encoder to generate embeddings; the contrastive learning model computes the triplet loss on these embeddings. Candidate losses: Triplet Margin Loss, AllTripletMiner, NT-Xent.
- Input: The input to the model is the same (e.g., phase and sensor data).
- Output: The model outputs embeddings.
- Loss Function: Triplet loss is used to train the model to minimize the distance between embeddings of similar samples and maximize the distance between embeddings of dissimilar samples.
- Validation: The validation process compares the embeddings using the triplet loss, ensuring that the model learns useful representations of the cells.
Other ideas: compare SimCLR vs. triplet sampling.
- SimCLR: generates positive pairs by applying different augmentations to the same sample; negative samples are implicitly the other samples in the same batch (see the NT-Xent sketch below).
- Triplet Sampling: explicitly forms triplets consisting of an anchor, a positive, and a negative.
Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/
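For reference, a minimal sketch of the NT-Xent loss that SimCLR optimizes (a standard formulation, not code from our repo):

import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    # z1[i] and z2[i] are embeddings of two augmented views of sample i;
    # every other sample in the batch acts as an implicit negative.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    sim = z @ z.T / temperature                         # cosine similarity matrix
    sim.fill_diagonal_(float("-inf"))                   # exclude self-similarity
    n = z1.shape[0]
    # the positive for row i is row i + n (and vice versa)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)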
This code implements triplet contrastive learning training with the following setup:
- Sample the anchor as a random well, a random FOV, and a random cell, loading that data with open_ome_zarr.
- The positive pair is that same image with certain transformations applied.
- The negative pair is also randomly selected, as long as it is not the same as the anchor image.
Updated dataloader code:
import random

import torch
from torch.utils.data import Dataset, DataLoader
from iohub import open_ome_zarr
from monai.transforms import Compose, RandAdjustContrastd, RandAffined, RandGaussianNoised
import pytorch_lightning as pl

class TripletCellDataset(Dataset):
    def __init__(self, base_path, z_slices=None, transforms=None):
        self.base_path = base_path
        self.transforms = transforms
        self.positions = self._get_positions()
        self.z_slices = z_slices if z_slices else slice(None)

    # list of paths to every single cell in the dataset
    def _get_positions(self):
        ds = open_ome_zarr(self.base_path, layout="hcs", mode="r")
        return [path for path, _ in ds.positions()]

    # reopen the store per call so each DataLoader worker gets its own handle
    def _load_data(self, position):
        ds = open_ome_zarr(self.base_path, layout="hcs", mode="r")
        return ds[position]["0"][:, :, self.z_slices, :, :]

    def __len__(self):
        return len(self.positions)

    def __getitem__(self, idx):
        anchor_position = self.positions[idx]
        anchor_data = self._load_data(anchor_position)
        # positive: the same image, with transformations applied
        positive_data = anchor_data.copy()
        if self.transforms:
            positive_data = self.transforms({"image": positive_data})["image"]
        # negative: any other randomly selected position
        negative_position = random.choice(self.positions)
        while negative_position == anchor_position:
            negative_position = random.choice(self.positions)
        negative_data = self._load_data(negative_position)
        # three tensors are returned, each with shape (48, 2, selected_z_slices, 200, 200)
        return (
            torch.as_tensor(anchor_data, dtype=torch.float32),
            torch.as_tensor(positive_data, dtype=torch.float32),
            torch.as_tensor(negative_data, dtype=torch.float32),
        )

class TripletCellDataModule(pl.LightningDataModule):
    def __init__(self, base_path, z_slices=None, batch_size=8, num_workers=4):
        super().__init__()
        self.base_path = base_path
        self.z_slices = z_slices
        self.batch_size = batch_size
        self.num_workers = num_workers
        # other transformations etc. can be set here
        self.transforms = Compose([
            RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(0.5, 1.5)),
            RandAffined(keys=["image"], prob=0.5, rotate_range=(0.1, 0.1), scale_range=(0.1, 0.1)),
            RandGaussianNoised(keys=["image"], prob=0.5, mean=0.0, std=0.1),
        ])

    def setup(self, stage=None):
        # note: all three splits currently see the full dataset; a real train/val/test split is still needed
        self.train_dataset = TripletCellDataset(self.base_path, z_slices=self.z_slices, transforms=self.transforms)
        self.val_dataset = TripletCellDataset(self.base_path, z_slices=self.z_slices)
        self.test_dataset = TripletCellDataset(self.base_path, z_slices=self.z_slices)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
__getitem__ returns three tensors (anchor, positive, negative), each with shape (48, 2, selected_z_slices, 200, 200).
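A quick shape check (hypothetical usage; with z_slices=slice(10, 20), 10 slices are kept):

ds = TripletCellDataset(base_path, z_slices=slice(10, 20))
anchor, positive, negative = ds[0]
print(anchor.shape)  # expected: torch.Size([48, 2, 10, 200, 200])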
Training code:
from pytorch_lightning import Trainer, LightningModule
import torch
from torchvision import models

class TripletNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # can change this to the resnet we use
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, 128)  # can change the embedding output shape here
        self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-6, swap=False, reduction="mean")

    def forward(self, x):
        return self.resnet(x)

    def training_step(self, batch, batch_idx):
        # NOTE: inputs are (batch_size, 48, 2, selected_z_slices, 200, 200);
        # see "A few thoughts" below for reshaping to what ResNet expects
        anchor, positive, negative = batch
        anchor_embed = self(anchor)
        positive_embed = self(positive)
        negative_embed = self(negative)
        loss = self.triplet_loss(anchor_embed, positive_embed, negative_embed)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # validation compares the embeddings with the same triplet loss
        anchor, positive, negative = batch
        loss = self.triplet_loss(self(anchor), self(positive), self(negative))
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

base_path = '/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/6-patches/patch_final.zarr'
z_slices = slice(10, 20)
datamodule = TripletCellDataModule(base_path, z_slices=z_slices)
model = TripletNet()
trainer = Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)
training_step:
- batch: a tuple of three tensors (anchor, positive, negative), each with shape (batch_size, 48, 2, selected_z_slices, 200, 200).
- batch_idx: the index of the current batch.
- Pass each tensor through the ResNet to get embeddings.
- Compute the TripletMarginLoss on the three embeddings.
A few thoughts:
- Currently, three tensors of shape (batch_size, 48, 2, selected_z_slices, 200, 200) are passed into the ResNet. A ResNet instead expects tensors of shape (batch_size, channels, height, width), where channels is usually 3 for RGB images, especially for a ResNet pre-trained on ImageNet. We should look into how to accommodate this when setting up the ResNet (maybe reshaping/resizing, unless there are other ideas); one possible approach is sketched below. Essentially, the batch contains multiple triplets, and each of the anchor, positive, and negative tensors should end up shaped like (batch_size, channels, height, width).
- Some of the timesteps are empty, and each cell's track starts and ends at a different time. Right now, the dataloader takes in all 48 timesteps. Is this fine for the model, given that some of the timesteps will be empty?
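One possible accommodation (a sketch only, assuming we keep a 2D ImageNet backbone; make_two_channel_resnet and embed are hypothetical helpers): swap the 3-channel RGB stem for a 2-channel (Phase, Sensor) one, fold T and Z into the batch dimension, and pool the per-slice embeddings into one embedding per cell.

import torch
from torchvision import models

def make_two_channel_resnet(embedding_dim=128):
    resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # replace the 3-channel RGB stem with a 2-channel (Phase, Sensor) stem
    resnet.conv1 = torch.nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
    resnet.fc = torch.nn.Linear(resnet.fc.in_features, embedding_dim)
    return resnet

def embed(resnet, x):
    # x: (B, T, C, Z, Y, X) -> treat each (t, z) slice as an independent 2-channel image
    b, t, c, z, y, w = x.shape
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(b * t * z, c, y, w)  # (B*T*Z, C, Y, X)
    emb = resnet(x)                                              # (B*T*Z, embedding_dim)
    # average the per-slice embeddings into one embedding per cell
    return emb.reshape(b, t * z, -1).mean(dim=1)                 # (B, embedding_dim)

Averaging over T and Z is just one pooling choice; it would also quietly average in the empty timesteps, so the second thought above (masking or trimming empty frames) still applies.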