Transcript assignment fails to match nucleus boundaries
Hi Segger team,
Thanks for the amazing tool! I'm using Segger on my own dataset and followed the standard training pipeline.
```python
from pathlib import Path
from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_hetero
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from segger.data.utils import SpatialTranscriptomicsDataset
# Segger and LitSegger are used below but were not imported; these module paths
# follow the segger_dev repo layout and may differ in your installed version.
from segger.models.segger_model import Segger
from segger.training.train import LitSegger

# ==== CONFIGURATION ====
train_dir = Path("resolve_reanalysis_2025/segger_dataset/train_tiles")
val_dir = Path("resolve_reanalysis_2025/segger_dataset/val_tiles")

# Model params
num_tx_tokens = 128
init_emb = 8
hidden_channels = 64
out_channels = 1
heads = 4
mid_layers = 2

# Training params
accelerator = "gpu"
strategy = "auto"
precision = "16-mixed"
devices = 1
epochs = 100
batch_size_train = 2
batch_size_val = 2
log_dir = "./models/segger_xenium"

# ==== Load Datasets ====
train_ds = SpatialTranscriptomicsDataset(root=train_dir)
val_ds = SpatialTranscriptomicsDataset(root=val_dir)

# ==== Define Model ====
model = Segger(
    num_tx_tokens=num_tx_tokens,
    init_emb=init_emb,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    heads=heads,
    num_mid_layers=mid_layers,
)
# Convert the homogeneous GNN into a heterogeneous one over the tx/bd graph
model = to_hetero(
    model,
    (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]),
)

# Dry run to initialize shapes
_ = model(train_ds[0].x_dict, train_ds[0].edge_index_dict)

# ==== Wrap in Lightning ====
litsegger = LitSegger(model=model, learning_rate=1e-3)

# ==== DataLoaders ====
train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size_val)

# ==== Logger ====
logger = CSVLogger(log_dir)

# ==== Trainer ====
trainer = Trainer(
    accelerator=accelerator,
    strategy=strategy,
    precision=precision,
    devices=devices,
    max_epochs=epochs,
    default_root_dir=log_dir,
    logger=logger,
)

# ==== Train ====
trainer.fit(litsegger, train_loader, val_loader)
```
However, I've noticed something odd in the training dynamics: the training loss is unstable.
Because of this, I suspect the model is not converging well. As a result, at prediction time it fails to assign enough transcripts to nuclei. In the visualization, the red outlines are my Cellpose nucleus segmentation and the dots are individual transcripts (colored by segger_cell_id). Ideally, Segger should also assign transcripts that lie slightly beyond the nucleus boundary and align more closely with the Cellpose segmentation. Instead, the current predictions are too conservative and miss many nearby transcripts.
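With batch_size_train = 2, per-step losses tend to be noisy, so it can help to smooth the curve that CSVLogger writes before deciding whether training is actually diverging. A minimal sketch, assuming the default CSVLogger layout (`<log_dir>/lightning_logs/version_<n>/metrics.csv`) and a hypothetical loss column name `train_loss` (use whatever key LitSegger actually logs); `smooth_loss` is a hypothetical helper:

```python
import pandas as pd


def smooth_loss(metrics: pd.DataFrame, column: str, window: int = 50) -> pd.DataFrame:
    """Return step, raw loss, and a rolling mean/std for one logged column.

    `column` is an assumption: replace it with the loss key LitSegger logs.
    Rows where that column is empty (e.g. validation-only rows) are dropped.
    """
    df = metrics[["step", column]].dropna().reset_index(drop=True).copy()
    roll = df[column].rolling(window, min_periods=1)
    df["rolling_mean"] = roll.mean()
    df["rolling_std"] = roll.std()
    return df
```

Usage: `smooth_loss(pd.read_csv("./models/segger_xenium/lightning_logs/version_0/metrics.csv"), "train_loss")`. If the rolling mean still trends downward while the raw values jump around, the problem looks more like noise than a failure to converge.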
❓ Questions:
How can I improve the model (e.g., optimizer, hyperparameters)?
Should I consider tuning the receptive field (k_tx, dist_tx, k_bd, dist_bd) to capture wider spatial associations?
Hey @xu-ziwei, sorry for the late response, and thanks for the detailed issue. Could you let me know which parameters you used for data generation, and how many genes are in your panel?
As for your first suspicion, I think num_tx_tokens = 128 might be a bit low, because Xenium panels usually contain more genes than that.
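One quick way to compare the panel size with num_tx_tokens is to count the distinct genes in transcripts.parquet. This sketch assumes each gene needs its own token (as the comment above implies); how Segger actually maps genes to tokens is not shown in this thread, and `gene_token_budget` is a hypothetical helper:

```python
import pandas as pd


def gene_token_budget(feature_names: pd.Series, num_tx_tokens: int) -> dict:
    """Compare the number of distinct genes with num_tx_tokens.

    Assumption: one token per distinct gene, so the panel must not exceed
    num_tx_tokens.
    """
    n_genes = feature_names.astype(str).nunique()
    return {
        "n_genes": n_genes,
        "num_tx_tokens": num_tx_tokens,
        "fits": n_genes <= num_tx_tokens,
    }
```

Usage: `gene_token_budget(pd.read_parquet("resolve_reanalysis_2025/segger_data/transcripts.parquet", columns=["feature_name"])["feature_name"], 128)`.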
Hi @EliHei2, thanks for the reply.
Nucleus boundaries: I use Cellpose segmentation for my nuclei and save the boundaries to segger_data/nucleus_boundaries.parquet.
Transcripts: my raw transcript data contains x, y, z, and gene (feature name) columns.
I added the following columns required by Segger's format and saved the result to segger_data/transcripts.parquet:
- feature_name: from the raw gene column.
- transcript_id: unique uint64 ID for each row.
- qv: constant quality value (set to 30.0).
- x_location, y_location, z_location: float32 positions from the original coordinates.
- cell_id: derived using pixel overlap with the Cellpose mask (formatted as cell_
- overlaps_nucleus: binary column (1 if the transcript overlaps a nucleus).
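The column conversion described in the list above can be sketched with pandas and NumPy. This assumes x/y are in the same pixel units as a 2D Cellpose label image (0 = background, indexed [row=y, col=x]). The cell_id format in the list is cut off, so the `cell_<label>` / `UNASSIGNED` values here are hypothetical placeholders, as is the helper name `to_segger_transcripts`:

```python
import numpy as np
import pandas as pd


def to_segger_transcripts(raw: pd.DataFrame, nuclei_mask: np.ndarray) -> pd.DataFrame:
    """Build the columns listed above from raw (x, y, z, gene) transcripts.

    Assumes x/y share pixel units with `nuclei_mask` (Cellpose labels, 0 = background).
    """
    n = len(raw)
    x = raw["x"].to_numpy(dtype=np.float32)
    y = raw["y"].to_numpy(dtype=np.float32)

    # Pixel containing each transcript; points outside the image count as background.
    rows = np.floor(y).astype(np.int64)
    cols = np.floor(x).astype(np.int64)
    h, w = nuclei_mask.shape
    inside = (rows >= 0) & (rows < h) & (cols >= 0) & (cols < w)
    labels = np.zeros(n, dtype=np.int64)
    labels[inside] = nuclei_mask[rows[inside], cols[inside]]

    # Hypothetical id format: the exact cell_id convention is truncated in the issue.
    cell_id = np.array(
        [f"cell_{lab}" if lab > 0 else "UNASSIGNED" for lab in labels], dtype=object
    )

    return pd.DataFrame(
        {
            "transcript_id": np.arange(n, dtype=np.uint64),
            "feature_name": raw["gene"].astype(str).to_numpy(),
            "x_location": x,
            "y_location": y,
            "z_location": raw["z"].to_numpy(dtype=np.float32),
            "qv": np.full(n, 30.0, dtype=np.float32),
            "cell_id": cell_id,
            "overlaps_nucleus": (labels > 0).astype(np.uint8),
        }
    )
```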
Then I generated the dataset with:
```python
from pathlib import Path
from segger.data.parquet.sample import STSampleParquet

# Set up your dataset path
my_data_dir = Path("resolve_reanalysis_2025/segger_data")

# Use xenium_v2 if you don't have an 'overlaps_nucleus' column
sample = STSampleParquet(
    base_dir=my_data_dir,
    n_workers=4,
    sample_type="xenium",  # ← adjust to match your schema
)

sample.save(
    data_dir=Path("resolve_reanalysis_2025/segger_dataset"),
    k_bd=3,  # Number of boundary points to connect
    dist_bd=15,  # Maximum distance for boundary connections
    k_tx=5,  # Use calculated optimal transcript neighbors
    dist_tx=5,  # Use calculated optimal search radius
    tile_width=1000,  # Tile size for processing
    tile_height=1000,
    neg_sampling_ratio=0.6,
    frac=1.0,  # Use all data
    val_prob=0.3,  # 30% validation set
    test_prob=0.1,  # 10% test set
)
```
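On the receptive-field question (k_tx, dist_tx, k_bd, dist_bd), one way to see whether the current values are limiting is to count how many neighbours each transcript actually gets under a "k nearest, but at most dist away" rule. This sketch uses scipy's cKDTree; it assumes Segger applies k and dist this way and that nucleus (bd) positions can be approximated by centroids, neither of which is confirmed in this thread. `knn_radius_stats` is a hypothetical helper:

```python
import numpy as np
from scipy.spatial import cKDTree


def knn_radius_stats(query_xy, ref_xy, k, dist, exclude_self=False):
    """Summarize neighbour counts per query point under a k-nearest-within-dist rule.

    Set exclude_self=True when query_xy and ref_xy are the same points (tx -> tx).
    """
    query_xy = np.asarray(query_xy, dtype=float)
    ref_xy = np.asarray(ref_xy, dtype=float)
    extra = 1 if exclude_self else 0
    tree = cKDTree(ref_xy)
    # Neighbours beyond `dist` (or beyond the number of ref points) come back as inf.
    d, _ = tree.query(query_xy, k=k + extra, distance_upper_bound=dist)
    d = np.asarray(d).reshape(len(query_xy), -1)
    counts = np.isfinite(d).sum(axis=1) - extra
    return {
        "mean_neighbors": float(counts.mean()),
        "frac_isolated": float((counts == 0).mean()),
        "frac_saturated": float((counts == k).mean()),
    }
```

dist_tx and dist_bd are in the same units as x_location/y_location. Xenium coordinates are in microns, so if these coordinates are in pixels, dist_tx=5 and dist_bd=15 cover a different physical radius than in a Xenium run. A high `frac_saturated` suggests k rather than dist is the binding limit; a high `frac_isolated` for transcripts vs. nucleus centroids means many transcripts have no nucleus within reach at all.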
I then checked the train_tiles/processed folder, and the graph construction looks fine.
BTW, there are 99 genes in my panel.
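To go beyond visually inspecting the processed tiles, a per-edge-type summary can show how many transcripts are connected at all. This is a sketch in plain NumPy on an edge_index array with source nodes in row 0, as in PyG; `edge_coverage` is a hypothetical helper:

```python
import numpy as np


def edge_coverage(edge_index, num_src_nodes):
    """Fraction of source nodes with at least one edge, plus mean/max out-degree."""
    src = np.asarray(edge_index, dtype=np.int64).reshape(2, -1)[0]
    deg = np.bincount(src, minlength=num_src_nodes)
    return {
        "frac_with_edge": float((deg > 0).mean()),
        "mean_degree": float(deg.mean()),
        "max_degree": int(deg.max()) if num_src_nodes else 0,
    }
```

Usage, with `data = train_ds[0]`: `edge_coverage(data["tx", "neighbors", "tx"].edge_index.numpy(), data["tx"].num_nodes)`, and the same for `("tx", "belongs", "bd")`.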