transcript assignment fails to match nuclei boundaries

Open xu-ziwei opened this issue 8 months ago • 2 comments

Hi Segger team,

Thanks for the amazing tool! I'm using Segger on my own dataset and followed the standard training pipeline.

from pathlib import Path
from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_hetero
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

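from segger.data.utils import SpatialTranscriptomicsDataset
from segger.models.segger_model import Segger  # import path assumed; needed for Segger(...) below
from segger.training.train import LitSegger  # import path assumed; needed for LitSegger(...) below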


# ==== CONFIGURATION ====
train_dir = Path("resolve_reanalysis_2025/segger_dataset/train_tiles")
val_dir = Path("resolve_reanalysis_2025/segger_dataset/val_tiles")

# Model params
num_tx_tokens = 128
init_emb = 8
hidden_channels = 64
out_channels = 1
heads = 4
mid_layers = 2


# Training params
accelerator = "gpu"
strategy = "auto"
precision = "16-mixed"
devices = 1
epochs = 100
batch_size_train = 2
batch_size_val = 2
log_dir = "./models/segger_xenium"

# ==== Load Datasets ====
train_ds = SpatialTranscriptomicsDataset(root=train_dir)
val_ds = SpatialTranscriptomicsDataset(root=val_dir)

# ==== Define Model ====
model = Segger(
    num_tx_tokens=num_tx_tokens,
    init_emb=init_emb,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    heads=heads,
    num_mid_layers=mid_layers,
)

model = to_hetero(
    model,
    (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]),
)

# Dry run to initialize shapes
_ = model(train_ds[0].x_dict, train_ds[0].edge_index_dict)

# ==== Wrap in Lightning ====
litsegger = LitSegger(model=model, learning_rate=1e-3)

# ==== DataLoaders ====
train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size_val)

# ==== Logger ====
logger = CSVLogger(log_dir)

# ==== Trainer ====
trainer = Trainer(
    accelerator=accelerator,
    strategy=strategy,
    precision=precision,
    devices=devices,
    max_epochs=epochs,
    default_root_dir=log_dir,
    logger=logger,
)

# ==== Train ====
trainer.fit(litsegger, train_loader, val_loader)

However, I’ve noticed something a bit odd with the training dynamics.

Image

Because the training loss is unstable, I suspect the model is converging poorly. As a result, at prediction time it assigns too few transcripts to nuclei. In the visualization, the red outlines are my Cellpose nuclei segmentation and the dots are individual transcripts (colored by segger_cell id). Ideally, Segger would also assign transcripts lying slightly beyond the nucleus boundary, so that the assigned regions align closely with the Cellpose segmentation. Instead, the current predictions are too conservative and miss many nearby transcripts.

Image
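The gap described above can be quantified before changing any hyperparameters. The sketch below assumes each transcript is a record with an `overlaps_nucleus` flag and a `segger_cell_id` that is `None` when the transcript is unassigned. The record layout, the `None` convention, and the helper name `assignment_summary` are assumptions for illustration, not Segger's documented output format.

```python
def assignment_summary(transcripts):
    """Summarize how many transcripts were assigned, split by nucleus overlap."""
    total = len(transcripts)
    assigned = sum(1 for t in transcripts if t["segger_cell_id"] is not None)
    # Transcripts that fall inside a nucleus mask (hypothetical 0/1 flag).
    in_nucleus = [t for t in transcripts if t["overlaps_nucleus"] == 1]
    in_nucleus_assigned = sum(
        1 for t in in_nucleus if t["segger_cell_id"] is not None
    )
    return {
        # Share of all transcripts that received any cell assignment.
        "assigned_fraction": assigned / total if total else 0.0,
        # Share of all transcripts lying inside a nucleus mask.
        "nuclear_fraction": len(in_nucleus) / total if total else 0.0,
        # Share of nuclear transcripts that were assigned.
        "nuclear_recall": in_nucleus_assigned / len(in_nucleus) if in_nucleus else 0.0,
    }


# Toy records, made up for illustration only.
toy = [
    {"overlaps_nucleus": 1, "segger_cell_id": "a"},
    {"overlaps_nucleus": 1, "segger_cell_id": None},
    {"overlaps_nucleus": 0, "segger_cell_id": "a"},
    {"overlaps_nucleus": 0, "segger_cell_id": None},
]
summary = assignment_summary(toy)
```

An assigned fraction at or below the nuclear fraction, or a low nuclear recall, would support the impression that the predictions are more conservative than the nuclei masks themselves.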

❓ Questions:

How can I improve the model (e.g., optimizer, hyperparameters)?

Should I consider tuning the receptive field (k_tx, dist_tx, k_bd, dist_bd) to capture wider spatial associations?

xu-ziwei, May 12 '25 09:05

Hey @xu-ziwei, sorry for the late response, and thanks for the detailed issue. Could you let me know which parameters you used for data generation, and how many genes are in your panel?

My first suspicion is that 128 tokens (num_tx_tokens) might be a bit low, because Xenium panels usually contain more genes than that.
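One quick way to check this is to count the distinct gene names in the transcript table and compare the count with `num_tx_tokens`. The sketch below works on any iterable of gene names. The helper name `check_tx_tokens` is hypothetical, and the rule that the token count should be at least the number of genes is an assumption about how the gene embedding is indexed.

```python
def check_tx_tokens(feature_names, num_tx_tokens):
    """Return (number of distinct genes, whether num_tx_tokens covers them)."""
    n_genes = len(set(feature_names))
    return n_genes, n_genes <= num_tx_tokens


# Toy gene list, made up for illustration only.
genes = ["ACTB", "GAPDH", "ACTB", "CD3E"]
n_genes, enough = check_tx_tokens(genes, num_tx_tokens=128)
```

With a parquet file, the names could come from the `feature_name` column, for example after loading it with `pandas.read_parquet`.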

EliHei2, May 19 '25 08:05

Hi @EliHei2, thanks for the reply.

Nucleus boundaries: I use Cellpose segmentation for my nuclei and save the result to segger_data/nucleus_boundaries.parquet.

Transcripts: My raw transcript data contains x, y, z, and gene (feature name). I added the following required columns to match the Segger-compatible format, saved at segger_data/transcripts.parquet:

  • feature_name: from the raw gene column.

  • transcript_id: unique uint64 ID for each row.

  • qv: constant quality value (set to 30.0).

  • x_location, y_location, z_location: float32 positions from the original coordinates.

  • cell_id: derived using pixel overlap with the Cellpose mask (formatted as cell_

  • overlaps_nucleus: binary column (1 if the transcript overlaps a nucleus).
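The pixel-overlap step behind `cell_id` and `overlaps_nucleus` can be sketched as a label lookup in the segmentation mask. This is a minimal sketch, not the code actually used: it assumes the mask is a 2D grid indexed as `mask[row][col]` with `row` from y and `col` from x, that label 0 means background, and that dividing coordinates by `pixel_size` yields pixel indices. The names `label_at`, `annotate`, and `nucleus_label` are hypothetical, and formatting the label into a `cell_id` string is left to the caller.

```python
def label_at(mask, x, y, pixel_size=1.0):
    """Return the mask label under (x, y), or 0 if the point is off the image."""
    col = int(x // pixel_size)
    row = int(y // pixel_size)
    if 0 <= row < len(mask) and 0 <= col < len(mask[0]):
        return mask[row][col]
    return 0


def annotate(transcripts, mask, pixel_size=1.0):
    """Add the nucleus label and a 0/1 overlap flag to each transcript record."""
    annotated = []
    for t in transcripts:
        label = label_at(mask, t["x_location"], t["y_location"], pixel_size)
        annotated.append(
            {**t, "nucleus_label": label, "overlaps_nucleus": 1 if label > 0 else 0}
        )
    return annotated


# Toy 3x3 mask with one nucleus (label 7) in the lower-right corner.
mask = [
    [0, 0, 0],
    [0, 7, 7],
    [0, 7, 7],
]
transcripts = [
    {"x_location": 1.5, "y_location": 1.2},  # inside the nucleus
    {"x_location": 0.2, "y_location": 2.9},  # background pixel
    {"x_location": 5.0, "y_location": 1.0},  # outside the image
]
annotated = annotate(transcripts, mask)
```

A NumPy label array from Cellpose can be indexed the same way with `mask[row, col]`. If the x/y axes or the pixel size are swapped or wrong at this step, the nuclear labels used for training would be shifted relative to the transcripts, so this lookup is worth checking on a few known points.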

Then I generate the dataset with:

from segger.data.parquet.sample import STSampleParquet
from pathlib import Path

# Set up your dataset path
my_data_dir = Path("resolve_reanalysis_2025/segger_data")

# Use xenium_v2 if you don’t have 'overlaps_nucleus' column
sample = STSampleParquet(
    base_dir=my_data_dir,
    n_workers=4,
    sample_type="xenium",  # ← adjust to match your schema
)


sample.save(
    data_dir=Path("resolve_reanalysis_2025/segger_dataset"),
    k_bd=3,  # Number of boundary points to connect
    dist_bd=15,  # Maximum distance for boundary connections
    k_tx=5,  # Number of transcript neighbors (hardcoded here)
    dist_tx=5,  # Search radius for transcript neighbors (hardcoded here)
    tile_width=1000,  # Tile size for processing
    tile_height=1000,
    neg_sampling_ratio=0.6,
    frac=1.0,  # Use all data
    val_prob=0.3,  # 30% validation set
    test_prob=0.1,  # 10% test set
)
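To see how a neighbor count and a distance cutoff such as `k_tx` and `dist_tx` jointly bound the neighborhood of a transcript, here is a minimal brute-force sketch. It assumes edges go to at most `k` nearest neighbors that also lie within `dist`; that reading of the parameters is an assumption, not taken from the Segger source, and `radius_knn` is a hypothetical helper.

```python
import math


def radius_knn(points, k, dist):
    """For each point, list indices of up to k nearest other points within dist."""
    neighbors = []
    for i, (xi, yi) in enumerate(points):
        candidates = []
        for j, (xj, yj) in enumerate(points):
            if i == j:
                continue
            d = math.hypot(xi - xj, yi - yj)
            if d <= dist:
                candidates.append((d, j))
        candidates.sort()  # nearest first
        neighbors.append([j for _, j in candidates[:k]])
    return neighbors


# Toy coordinates on a line: three close points and one 8 units from the rest.
points = [(0.0, 0.0), (3.0, 0.0), (4.0, 0.0), (12.0, 0.0)]
tight = radius_knn(points, k=5, dist=5)   # the last point gets no neighbors
wide = radius_knn(points, k=5, dist=15)   # the last point is now connected
```

Under this assumption, a transcript farther than the cutoff from every other node stays isolated regardless of `k`, which is one way a small radius could leave peripheral transcripts unassigned. The double loop is quadratic and only meant for illustration; a KD-tree would be the usual choice on real data.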

Then I checked the train_tiles/processed folder, and the graph construction looks fine.

Image

Image

Image

By the way, my panel contains 99 genes.

xu-ziwei, May 26 '25 13:05