[Tune] PBT/PB2 doesn't work correctly with Ray Lightning
When using PBT/PB2, I received the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
This issue happens after a trial is paused and resumed. I was able to reproduce it with some modifications to the example provided by Ray Lightning:
"""Simple example using RayAccelerator and Ray Tune"""
import os
import tempfile
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
import pytorch_lightning as pl
import ray
from ray import tune
from ray_lightning.tune import TuneReportCheckpointCallback, get_tune_resources
from ray_lightning import RayPlugin
from ray_lightning.tests.utils import LightningMNISTClassifier
from ray.tune.schedulers.pb2 import PB2
def train_mnist(config,
checkpoint_dir=None,
data_dir=None,
num_epochs=10,
num_workers=1,
use_gpu=False,
callbacks=None):
# Make sure data is downloaded on all nodes.
def download_data():
from filelock import FileLock
with FileLock(os.path.join(data_dir, ".lock")):
MNISTDataModule(data_dir=data_dir).prepare_data()
model = LightningMNISTClassifier(config, data_dir)
callbacks = callbacks or []
checkpoint_path = None
if checkpoint_dir is not None:
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')
trainer = pl.Trainer(
max_epochs=num_epochs,
callbacks=callbacks,
progress_bar_refresh_rate=0,
plugins=[
RayPlugin(
num_workers=num_workers,
use_gpu=use_gpu,
init_hook=download_data)
])
dm = MNISTDataModule(
data_dir=data_dir, num_workers=1, batch_size=config["batch_size"])
trainer.fit(model, dm, ckpt_path=checkpoint_path)
def tune_mnist(data_dir,
num_samples=10,
num_epochs=10,
num_workers=1,
use_gpu=False):
config = {
"layer_1": tune.choice([32, 64, 128]),
"layer_2": tune.choice([64, 128, 256]),
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
}
scheduler = PB2(
hyperparam_bounds= {
"lr": [1e-4, 1e-1]
}
)
# Add Tune callback.
metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
callbacks = [TuneReportCheckpointCallback(metrics, on="validation_end", filename="checkpoint")]
trainable = tune.with_parameters(
train_mnist,
data_dir=data_dir,
num_epochs=num_epochs,
num_workers=num_workers,
use_gpu=use_gpu,
callbacks=callbacks)
analysis = tune.run(
trainable,
scheduler=scheduler,
metric="loss",
mode="min",
config=config,
num_samples=num_samples,
resources_per_trial=get_tune_resources(
num_workers=num_workers, use_gpu=use_gpu),
name="tune_mnist")
print("Best hyperparameters found were: ", analysis.best_config)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--num-workers",
type=int,
help="Number of training workers to use.",
default=1)
parser.add_argument(
"--use-gpu", action="store_true", help="Use GPU for training.")
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to tune.")
parser.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs to train for.")
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
parser.add_argument(
"--address",
required=False,
type=str,
help="the address to use for Ray")
args, _ = parser.parse_known_args()
num_epochs = 1 if args.smoke_test else args.num_epochs
num_workers = 1 if args.smoke_test else args.num_workers
use_gpu = False if args.smoke_test else args.use_gpu
num_samples = 1 if args.smoke_test else args.num_samples
if args.smoke_test:
ray.init(num_cpus=2)
else:
ray.init(address=args.address)
data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
tune_mnist(data_dir, num_samples, num_epochs, num_workers, use_gpu)
The args I passed in: python3 test_ray_lightning.py --use-gpu --num-workers 2 --num-samples 4
Versions:
pytorch-lightning==1.5.10
torch==1.10.2
ray==1.12.0
ray-lightning==0.2.0
Hi @yinweisu, thanks for reporting. I believe this is a valid issue; I was able to reproduce it on my setup as well. After a bit of digging, it seems this is a known issue with ptl 1.5: https://github.com/PyTorchLightning/pytorch-lightning/discussions/11435 https://github.com/PyTorchLightning/pytorch-lightning/issues/12327
The solution is basically to upgrade to ptl 1.6. @amogkam Another data point that we should do the upgrade sooner rather than later.
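For anyone blocked on the upgrade in the meantime: the linked threads point to the restored optimizer state sitting on CPU while the model lives on the GPU after a trial resumes. A possible stopgap is to push that state back onto the model's device before training continues. The sketch below is untested with Ray Lightning, and the MoveOptimizerStateToDevice callback is a hypothetical helper (not part of either library), so treat it as an illustration only:

import torch
import pytorch_lightning as pl


class MoveOptimizerStateToDevice(pl.Callback):
    """Stopgap for the ptl 1.5 resume bug: move any optimizer state tensors
    that were restored on CPU onto the LightningModule's device."""

    def on_train_start(self, trainer, pl_module):
        device = pl_module.device
        for optimizer in trainer.optimizers:
            for state in optimizer.state.values():
                for key, value in state.items():
                    if torch.is_tensor(value):
                        state[key] = value.to(device)

The callback would be appended to the callbacks list passed to train_mnist. This only papers over the symptom, though; upgrading to ptl 1.6 is still the proper fix.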
Thanks! And yes, upgrading to ptl 1.6 soon would be awesome!
+1 for upgrading to PyTorch Lightning 1.6! Is there an estimate for when that work might occur?