
anndata DataLoader for PyTorch without DistributedSampler

Status: open. mbuttner opened this issue 3 years ago (5 comments).

Hi there, I have been trying to implement an MLP that predicts cell-type labels using PyTorch Lightning and the AnnLoader class. For the implementation, I followed the AnnLoader tutorial on interfacing with PyTorch models and the PyTorch Lightning tutorial. I aim to implement the training, test, and prediction methods and run them on a GPU. I tested my code on a Google Colab instance; the error message is the same for the GPU and CPU runtimes.
When I try to predict cell-type labels with the predict function, PyTorch Lightning tries to inject a DistributedSampler as the sampler. AnnLoader does not support this, and I could not figure out how to disable the sampler injection, even though I pass replace_sampler_ddp=False to the Trainer.

Here's my code:

import gdown
import pytorch_lightning as pl
import torch
import torch.nn as nn

import numpy as np
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from torchmetrics.functional import accuracy
from anndata.experimental.pytorch import AnnLoader

#define model class 
class MLP(pl.LightningModule):
  
  def __init__(self, input_dim, hidden_dims, out_dim):
    super().__init__()
    modules = []
    for in_size, out_size in zip([input_dim]+hidden_dims, hidden_dims):
        modules.append(nn.Linear(in_size, out_size))
        modules.append(nn.LayerNorm(out_size))
        modules.append(nn.ReLU())
        modules.append(nn.Dropout(p=0.05))
    modules.append(nn.Linear(hidden_dims[-1], out_dim))
    self.layers = nn.Sequential(*modules)
    
    self.ce = nn.CrossEntropyLoss()
    
  def forward(self, x):
    return self.layers(x)
  
  def training_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    self.log('train_loss', loss)
    return loss
  
  def test_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    y_hat = torch.argmax(y_hat, dim=1)
    acc = accuracy(y_hat, y)
    metrics = dict({
        'test_loss': loss.clone().detach(),
        'test_acc': acc.clone().detach(),
    })
    self.log_dict(metrics, batch_size=len(y))
    return metrics

  def predict_step(self, batch, batch_idx, dataloader_idx=0):
    # here, a batch has data (x); labels are not needed for prediction
    x = batch.X
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)  # forward pass through the layers defined in __init__
    return y_hat

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return optimizer

#load normalized pancreas data from AnnLoader tutorial 
from google.colab import drive
drive.mount('/content/drive')
file_path = '/content/drive/My Drive/pancreas_normalized.h5ad'

adata = sc.read(file_path)
adata.X = adata.raw.X # put raw counts to .X

#prepare AnnLoader 
encoder_study = OneHotEncoder(sparse=False, dtype=np.float32)
encoder_study.fit(adata.obs['study'].to_numpy()[:, None])

encoder_celltype = LabelEncoder()
encoder_celltype.fit(adata.obs['cell_type'])

use_cuda = torch.cuda.is_available()

encoders = {
    'obs': {
        'study': lambda s: encoder_study.transform(s.to_numpy()[:, None]),
        'cell_type': encoder_celltype.transform
    }
}

# Load data with AnnLoader, split into train and test data by study
dataloader = AnnLoader(adata[adata.obs['study']!='Pancreas Fluidigm C1'], batch_size=128, shuffle=True, convert=encoders, use_cuda=use_cuda)
dataloader_test = AnnLoader(adata[adata.obs['study']=='Pancreas Fluidigm C1'], batch_size=128, #sampler = sampler,  
                            shuffle=False, convert=encoders, use_cuda=use_cuda)

#create MLP model, configure pytorch lightning trainer 
mlp = MLP(input_dim = adata.n_vars, hidden_dims = [128,128], out_dim=8)
trainer = pl.Trainer(auto_scale_batch_size='power', gpus=1, deterministic=True, 
                     max_epochs=5, replace_sampler_ddp=False) 
# Train the model
trainer.fit(mlp, dataloader)

# Perform evaluation
trainer.test(mlp, dataloader_test)
## output [{'test_acc': 0.9620253443717957, 'test_loss': 0.12309074401855469}]

# Return predictions
trainer.predict(mlp, dataloader_test)
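To make the layer sizes built in `MLP.__init__` concrete: the `zip` pairs each hidden block's input width with the next hidden width, and a final `nn.Linear` maps the last hidden width to `out_dim`. The helper `layer_shapes` below is a hypothetical, pure-Python sketch that only lists those `(in_features, out_features)` pairs; the input width of 2000 genes is an assumed value, not taken from the pancreas data.

```python
def layer_shapes(input_dim, hidden_dims, out_dim):
    """Hypothetical helper: list the (in_features, out_features) of each
    nn.Linear that MLP.__init__ would create."""
    # Hidden blocks: pair [input_dim, h1, h2, ...] with [h1, h2, ...].
    shapes = list(zip([input_dim] + hidden_dims, hidden_dims))
    # Final projection from the last hidden width to the number of classes.
    shapes.append((hidden_dims[-1], out_dim))
    return shapes


# Assumed input width of 2000 genes; hidden_dims and out_dim as in the issue.
print(layer_shapes(2000, [128, 128], 8))
```

Because `nn.CrossEntropyLoss` expects class indices in the range `[0, out_dim)`, `out_dim` should equal the number of classes seen by `encoder_celltype` (`len(encoder_celltype.classes_)`) rather than being hard coded.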

Here is the error message from the prediction step:

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:489: PossibleUserWarning: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  category=PossibleUserWarning,
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
<ipython-input-18-0ce04d8b9a10> in <module>()
      1 # Return predictions
----> 2 trainer.predict(mlp, dataloader_test)

11 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/data.py in _get_dataloader_init_kwargs(dataloader, sampler, mode)
    240         dataloader_cls_name = dataloader.__class__.__name__
    241         raise MisconfigurationException(
--> 242             f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
    243             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
    244             f"The missing attributes are {required_args}. "

MisconfigurationException: Trying to inject `DistributedSampler` into the `AnnLoader` instance. This would fail as some of the `__init__` arguments are not available as instance attributes. The missing attributes are ['adatas']. HINT: If you wrote the `AnnLoader` class, define `self.missing_arg_name` or manually add the `DistributedSampler` as: `AnnLoader(dataset, sampler=DistributedSampler(dataset))`.
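To see why the message names `adatas`, here is a toy model of the check it describes. Going only by the message text, it assumes the loader is rebuilt by calling its class again with the original `__init__` arguments, which are read back from same-named instance attributes. `ToyDataLoader`, `ToyAnnLoader`, `FixedToyAnnLoader`, and `reinstantiate_with_sampler` are hypothetical stand-ins, not the PyTorch Lightning or anndata code, and the real implementation differs in detail. The fixed variant follows the HINT by storing `self.adatas`.

```python
import inspect


class ToyDataLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, dataset, batch_size=1, sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler


class ToyAnnLoader(ToyDataLoader):
    """Hypothetical loader whose `adatas` argument is not kept under that name."""

    def __init__(self, adatas, batch_size=1, sampler=None):
        # `adatas` is consumed and only stored as `dataset`.
        super().__init__(dataset=list(adatas), batch_size=batch_size, sampler=sampler)


class FixedToyAnnLoader(ToyAnnLoader):
    """Hypothetical fix following the HINT: keep `adatas` as an attribute."""

    def __init__(self, adatas, batch_size=1, sampler=None):
        super().__init__(adatas, batch_size=batch_size, sampler=sampler)
        self.adatas = adatas


def reinstantiate_with_sampler(loader, sampler):
    """Toy re-injection: rebuild `loader` with a new sampler, reading the
    __init__ arguments back from same-named instance attributes."""
    params = inspect.signature(type(loader).__init__).parameters
    # Required arguments are those without a default value.
    required = [
        name for name, p in params.items()
        if name != "self"
        and p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                       inspect.Parameter.KEYWORD_ONLY)
    ]
    missing = [name for name in required if not hasattr(loader, name)]
    if missing:
        raise RuntimeError(f"The missing attributes are {missing}.")
    # Reuse every recoverable argument except the sampler being replaced.
    kwargs = {
        name: getattr(loader, name)
        for name in params
        if name not in ("self", "sampler") and hasattr(loader, name)
    }
    return type(loader)(sampler=sampler, **kwargs)
```

In this toy, `reinstantiate_with_sampler(ToyAnnLoader(["c1"]), "dist")` raises because `adatas` cannot be recovered, while the same call on `FixedToyAnnLoader` succeeds. Whether storing the attribute alone would be enough for the real `AnnLoader` is an assumption.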

Versions:

pyasn1                                      0.4.8
pyasn1_modules                              0.2.8
pydev_ipython                               NA
pydevconsole                                NA
pydevd                                      2.0.0
pydevd_concurrency_analyser                 NA
pydevd_file_utils                           NA
pydevd_plugins                              NA
pydevd_tracing                              NA
pydot_ng                                    2.0.0
pygments                                    2.6.1
pyparsing                                   3.0.7
pytorch_lightning                           1.6.0
pytz                                        2018.9
regex                                       2.5.72
requests                                    2.23.0
rsa                                         4.8
scipy                                       1.4.1
session_info                                1.0.0
setuptools                                  57.4.0
simplegeneric                               NA
sitecustomize                               NA
six                                         1.15.0
sklearn                                     1.0.2
socks                                       1.7.1
sphinxcontrib                               NA
storemagic                                  NA
tblib                                       1.7.0
tensorboard                                 2.8.0
tensorflow                                  2.8.0
termcolor                                   1.1.0
threadpoolctl                               3.1.0
toolz                                       0.11.2
torch                                       1.10.0+cu111
torchmetrics                                0.7.3
torchtext                                   0.11.0
torchvision                                 0.11.1+cu111
tornado                                     5.1.1
tqdm                                        4.63.0
traitlets                                   5.1.1
typing_extensions                           NA
uritemplate                                 3.0.1
urllib3                                     1.24.3
wcwidth                                     0.2.5
webencodings                                0.5.1
wrapt                                       1.14.0
yaml                                        6.0
zipp                                        NA
zmq                                         22.3.0
-----
IPython             5.5.0
jupyter_client      5.3.5
jupyter_core        4.9.2
notebook            5.3.1
-----
Python 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0]
Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
-----
Session information updated at 2022-04-13 15:04

mbuttner, Apr 13 '22 15:04

Yes, the way PyTorch Lightning tries to inject the sampler is not supported; you probably need to disable this somehow. But I will check what can be done to fix this.

Koncopd, Apr 13 '22 15:04

It honestly sounds like a bug in PyTorch Lightning, as the code works for fit. It should remember that you didn't want to replace samplers. I would open an issue there.

adamgayoso, Apr 14 '22 20:04

Yes, I thought so too, because the fit step works just fine. Thanks also for linking the issue to the PyTorch Lightning repo.

mbuttner, Apr 15 '22 06:04

Hi there, I got a reply from the PyTorch Lightning developers, who suggested a change to the DataLoader (see the issue posted there). Is this something that can be implemented without much hassle? Thank you!

mbuttner, May 03 '22 10:05

@mbuttner Thank you, I will check their proposed changes.

Koncopd, May 03 '22 16:05