celldetection
Scalable Instance Segmentation using PyTorch & PyTorch Lightning.
Cell Detection
⭐ Showcase
- NeurIPS 22 Cell Segmentation Competition: https://openreview.net/forum?id=YtgRjBw-7GJ
- Nuclei of U2OS cells in a chemical screen: https://bbbc.broadinstitute.org/BBBC039 (CC0)
- P. vivax (malaria) infected human blood: https://bbbc.broadinstitute.org/BBBC041 (CC BY-NC-SA 3.0)
🛠 Install
Make sure you have PyTorch installed.
PyPI
pip install -U celldetection
GitHub
pip install git+https://github.com/FZJ-INM1-BDA/celldetection.git
💾 Trained models
model = cd.fetch_model(model_name, check_hash=True)
model name | training data | link |
---|---|---|
ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c | BBBC039, BBBC038, Omnipose, Cellpose, Sartorius - Cell Instance Segmentation, Livecell, NeurIPS 22 CellSeg Challenge | 🔗 |
Run a demo with a pretrained model
import torch, cv2, celldetection as cd
from skimage.data import coins
from matplotlib import pyplot as plt
# Load pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = cd.fetch_model('ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c', check_hash=True).to(device)
model.eval()
# Load input
img = coins()
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
print(img.dtype, img.shape, (img.min(), img.max()))
# Run model
with torch.no_grad():
    x = cd.to_tensor(img, transpose=True, device=device, dtype=torch.float32)
    x = x / 255  # ensure 0..1 range
    x = x[None]  # add batch dimension: Tensor[3, h, w] -> Tensor[1, 3, h, w]
    y = model(x)

# Show results for each batch item
contours = y['contours']
for n in range(len(x)):
    cd.imshow_row(x[n], x[n], figsize=(16, 9), titles=('input', 'contours'))
    cd.plot_contours(contours[n])
    plt.show()
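To go from contours to a label image (e.g. for downstream measurements), one option is to rasterize the predicted contours with OpenCV. This is a minimal sketch, not part of the demo above; it assumes the contours are tensors of (x, y) pixel coordinates, as used by cd.plot_contours:
import numpy as np
# Sketch: rasterize predicted contours of the first batch item into an integer label image
labels = np.zeros(img.shape[:2], dtype=np.int32)
for i, con in enumerate(contours[0], start=1):
    pts = con.detach().cpu().numpy().round().astype(np.int32)  # (num_points, 2), (x, y) order assumed
    cv2.drawContours(labels, [pts], -1, color=int(i), thickness=-1)  # fill object i with label i
print('objects found:', int(labels.max()))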
🔬 Architectures
import celldetection as cd
Contour Proposal Networks
- cd.models.CPN
- cd.models.CpnU22
- cd.models.CPNCore
- cd.models.CpnResUNet
- cd.models.CpnSlimU22
- cd.models.CpnWideU22
- cd.models.CpnResNet18FPN
- cd.models.CpnResNet34FPN
- cd.models.CpnResNet50FPN
- cd.models.CpnResNeXt50FPN
- cd.models.CpnResNet101FPN
- cd.models.CpnResNet152FPN
- cd.models.CpnResNet18UNet
- cd.models.CpnResNet34UNet
- cd.models.CpnResNet50UNet
- cd.models.CpnResNeXt101FPN
- cd.models.CpnResNeXt152FPN
- cd.models.CpnResNeXt50UNet
- cd.models.CpnResNet101UNet
- cd.models.CpnResNet152UNet
- cd.models.CpnResNeXt101UNet
- cd.models.CpnResNeXt152UNet
- cd.models.CpnWideResNet50FPN
- cd.models.CpnWideResNet101FPN
- cd.models.CpnMobileNetV3LargeFPN
- cd.models.CpnMobileNetV3SmallFPN
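These can be used like any other model in cd.models. A hedged sketch (only in_channels is mirrored from the examples elsewhere in this README; other constructor arguments and the exact set of output keys are assumptions, see the API documentation):
import torch
import celldetection as cd
# Sketch: build a small CPN and run a dummy forward pass in inference mode
model = cd.models.CpnU22(in_channels=3)
model.eval()
with torch.no_grad():
    y = model(torch.rand(1, 3, 256, 256))
print(list(y.keys()))  # detection-style outputs, including 'contours' as used in the demo above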
PyTorch Image Models (timm)
Also have a look at the timm documentation.
import timm
timm.list_models(filter='*') # explore available models
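As a hedged illustration of exploring timm backbones (plain timm usage; how a timm backbone is attached to a celldetection model is covered in the celldetection documentation):
import timm
# Inspect a timm backbone as a feature extractor (standard timm API)
backbone = timm.create_model('resnet50', features_only=True, pretrained=False)
print(backbone.feature_info.channels())   # per-stage output channels, e.g. [64, 256, 512, 1024, 2048]
print(backbone.feature_info.reduction())  # per-stage strides, e.g. [2, 4, 8, 16, 32]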
Segmentation Models PyTorch (smp)
import segmentation_models_pytorch as smp
smp.encoders.get_encoder_names() # explore available models
encoder = cd.models.SmpEncoder(encoder_name='mit_b5', pretrained='imagenet')
Find a list of available smp encoders in the smp documentation.
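The resulting encoder can presumably be paired with one of the decoders shown further down; this is a hedged sketch following the MA-Net example below, not a documented recipe (compatibility of SmpEncoder with MaNet is an assumption here):
import celldetection as cd
# Sketch: pair an smp encoder with a decoder, mirroring the MA-Net example below
encoder = cd.models.SmpEncoder(encoder_name='mit_b5', pretrained='imagenet')
model = cd.models.MaNet(encoder, out_channels=1)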
U-Nets
# U-Nets are available in 2D and 3D
import celldetection as cd
model = cd.models.ResNeXt50UNet(in_channels=3, out_channels=1, nd=3)
- cd.models.U22
- cd.models.U17
- cd.models.U12
- cd.models.UNet
- cd.models.WideU22
- cd.models.SlimU22
- cd.models.ResUNet
- cd.models.UNetEncoder
- cd.models.ResNet50UNet
- cd.models.ResNet18UNet
- cd.models.ResNet34UNet
- cd.models.ResNet152UNet
- cd.models.ResNet101UNet
- cd.models.ResNeXt50UNet
- cd.models.ResNeXt152UNet
- cd.models.ResNeXt101UNet
- cd.models.WideResNet50UNet
- cd.models.WideResNet101UNet
- cd.models.MobileNetV3SmallUNet
- cd.models.MobileNetV3LargeUNet
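All listed variants follow the constructor pattern from the example above; a small sketch (class and parameter names are taken from this README, nothing else is assumed):
import celldetection as cd
# Sketch: 2D by default; pass nd=3 for the 3D counterpart, as in the example above
model_2d = cd.models.ResNet18UNet(in_channels=3, out_channels=1)
model_3d = cd.models.ResNet18UNet(in_channels=1, out_channels=1, nd=3)
print(sum(p.numel() for p in model_2d.parameters()), 'parameters in the 2D model')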
MA-Nets
# Many MA-Nets are available in 2D and 3D
import celldetection as cd
encoder = cd.models.ConvNeXtSmall(in_channels=3, nd=3)
model = cd.models.MaNet(encoder, out_channels=1, nd=3)
Feature Pyramid Networks
- cd.models.FPN
- cd.models.ResNet18FPN
- cd.models.ResNet34FPN
- cd.models.ResNet50FPN
- cd.models.ResNeXt50FPN
- cd.models.ResNet101FPN
- cd.models.ResNet152FPN
- cd.models.ResNeXt101FPN
- cd.models.ResNeXt152FPN
- cd.models.WideResNet50FPN
- cd.models.WideResNet101FPN
- cd.models.MobileNetV3LargeFPN
- cd.models.MobileNetV3SmallFPN
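A minimal sketch, assuming the FPN variants accept the same in_channels argument as the U-Net examples above:
import celldetection as cd
# Sketch: instantiate an FPN backbone (constructor arguments assumed)
model = cd.models.ResNet50FPN(in_channels=3)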
ConvNeXt Networks
# ConvNeXt Networks are available in 2D and 3D
import celldetection as cd
model = cd.models.ConvNeXtSmall(in_channels=3, nd=3)
Residual Networks
# Residual Networks are available in 2D and 3D
import celldetection as cd
model = cd.models.ResNet50(in_channels=3, nd=3)
Mobile Networks
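By analogy with the ConvNeXt and ResNet sections above, a minimal sketch (the class name cd.models.MobileNetV3Small is an assumption here; see the API documentation for the exact names):
import celldetection as cd
# Sketch: instantiate a MobileNetV3 backbone (class name and arguments assumed)
model = cd.models.MobileNetV3Small(in_channels=3)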
🐳 Docker
Find us on Docker Hub: https://hub.docker.com/r/ericup/celldetection
You can pull the latest version of celldetection via:
docker pull ericup/celldetection:latest
CPN inference via Docker with GPU
docker run --rm \
  -v $PWD/docker/outputs:/outputs/ \
  -v $PWD/docker/inputs/:/inputs/ \
  -v $PWD/docker/models/:/models/ \
  --gpus="device=0" \
  ericup/celldetection:latest /bin/bash -c \
  "python cpn_inference.py --tile_size=1024 --stride=768 --precision=32-true"
CPN inference via Docker with CPU
docker run --rm \
  -v $PWD/docker/outputs:/outputs/ \
  -v $PWD/docker/inputs/:/inputs/ \
  -v $PWD/docker/models/:/models/ \
  ericup/celldetection:latest /bin/bash -c \
  "python cpn_inference.py --tile_size=1024 --stride=768 --precision=32-true --accelerator=cpu"
Apptainer
You can also pull our Docker images for use with Apptainer (formerly Singularity) with this command:
apptainer pull --dir . --disable-cache docker://ericup/celldetection:latest
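To run inference with the pulled image, a hedged sketch (the .sif filename follows apptainer pull's default naming; the bind mounts, script location, and flags simply mirror the Docker examples above and may need adjusting):
apptainer exec --nv \
  -B $PWD/docker/inputs/:/inputs/ \
  -B $PWD/docker/outputs/:/outputs/ \
  -B $PWD/docker/models/:/models/ \
  celldetection_latest.sif \
  python cpn_inference.py --tile_size=1024 --stride=768 --precision=32-true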
🤗 Hugging Face Spaces
Find us on Hugging Face and upload your own images for segmentation: https://huggingface.co/spaces/ericup/celldetection
There's also an API (Python & JavaScript) that lets you use community GPUs (currently an NVIDIA A100) remotely!
Hugging Face API
Python
from gradio_client import Client
# Define inputs (local filename or URL)
inputs = 'https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/coins.png'
# Set up client
client = Client("ericup/celldetection")
# Predict
overlay_filename, img_filename, h5_filename, csv_filename = client.predict(
    inputs,  # str: Local filepath or URL of your input image
    # Model name
    'ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c',
    # Custom Score Threshold (numeric value between 0 and 1)
    False, .9,  # bool: Whether to use custom setting; float: Custom setting
    # Custom NMS Threshold
    False, .3142,  # bool: Whether to use custom setting; float: Custom setting
    # Custom Number of Sample Points
    False, 128,  # bool: Whether to use custom setting; int: Custom setting
    # Overlapping objects
    True,  # bool: Whether to allow overlapping objects
    # API name (keep as is)
    api_name="/predict"
)
# Example usage: the code below only shows how to use the results
from matplotlib import pyplot as plt
from skimage.io import imread
import celldetection as cd
import pandas as pd

# Read results from local temporary files
img = imread(img_filename)
overlay = imread(overlay_filename)  # random colors per instance; transparent overlap
properties = pd.read_csv(csv_filename)
contours, scores, label_image = cd.from_h5(h5_filename, 'contours', 'scores', 'labels')
# Optionally display overlay
cd.imshow_row(img, img, figsize=(16, 9))
cd.imshow(overlay)
plt.show()
# Optionally display contours with text
cd.imshow_row(img, img, figsize=(16, 9))
cd.plot_contours(contours, texts=['score: %d%%\narea: %d' % s for s in zip((scores * 100).round(), properties.area)])
plt.show()
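The HDF5 file also contains a label image; it can be displayed with the same helpers (a small sketch based on the variables loaded above):
# Optionally display the label image from the HDF5 file
cd.imshow_row(img, label_image, figsize=(16, 9), titles=('input', 'label image'))
plt.show()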
JavaScript
import { client } from "@gradio/client";
const response_0 = await fetch("https://raw.githubusercontent.com/scikit-image/scikit-image/main/skimage/data/coins.png");
const exampleImage = await response_0.blob();
const app = await client("ericup/celldetection");
const result = await app.predict("/predict", [
    exampleImage,  // blob: Your input image
    // Model name (hosted model or URL)
    "ginoro_CpnResNeXt101UNet-fbe875f1a3e5ce2c",
    // Custom Score Threshold (numeric value between 0 and 1)
    false, .9,  // bool: Whether to use custom setting; float: Custom setting
    // Custom NMS Threshold
    false, .3142,  // bool: Whether to use custom setting; float: Custom setting
    // Custom Number of Sample Points
    false, 128,  // bool: Whether to use custom setting; int: Custom setting
    // Overlapping objects
    true,  // bool: Whether to allow overlapping objects
]);
🧑‍💻 Napari Plugin
Find our Napari Plugin here: https://github.com/FZJ-INM1-BDA/celldetection-napari
Find out more about Napari here: https://napari.org
You can install it via pip:
pip install git+https://github.com/FZJ-INM1-BDA/celldetection-napari.git
🏆 Awards
- NeurIPS 2022 Cell Segmentation Challenge: Winner Finalist Award
📝 Citing
If you find this work useful, please consider giving a star ⭐️ and a citation:
@article{UPSCHULTE2022102371,
title = {Contour proposal networks for biomedical instance segmentation},
journal = {Medical Image Analysis},
volume = {77},
pages = {102371},
year = {2022},
issn = {1361-8415},
doi = {https://doi.org/10.1016/j.media.2022.102371},
url = {https://www.sciencedirect.com/science/article/pii/S136184152200024X},
author = {Eric Upschulte and Stefan Harmeling and Katrin Amunts and Timo Dickscheid},
keywords = {Cell detection, Cell segmentation, Object detection, CPN},
}