# LightningCLI support for external accelerators
## 🚀 Feature
LightningCLI support for external accelerators
## Motivation
LightningCLI helps avoid boilerplate code for command line tools. The current implementation does not appear to support external accelerators; it only accepts the accelerators bundled in the Lightning source.
## Pitch
Extend LightningCLI to support accelerators defined in external packages (e.g. `lightning-habana`).
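One way CLI tools in this style support classes living outside the core package is to accept a dotted class path and instantiate it dynamically (jsonargparse, which LightningCLI builds on, works along these lines for configurable classes). A minimal stdlib sketch of the idea; the `instantiate_class_path` helper is hypothetical, not part of the Lightning API:

```python
import importlib
from typing import Any


def instantiate_class_path(class_path: str, **init_args: Any) -> Any:
    """Resolve a dotted 'package.module.Class' path and instantiate it.

    Hypothetical helper sketching how a CLI could accept accelerators
    from external packages such as lightning_habana.
    """
    module_name, _, cls_name = class_path.rpartition(".")
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**init_args)


# Demonstrated with a stdlib class; an external accelerator would resolve
# the same way, e.g. instantiate_class_path("lightning_habana.HPUAccelerator").
counter = instantiate_class_path("collections.Counter", hpu=8)
```

With such a mechanism, a user could point `--trainer.accelerator` at any importable subclass instead of being limited to the names Lightning registers itself.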
## Alternatives
## Additional context
First mentioned in #54
To reproduce:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.cli import LightningCLI
from lightning_habana import HPUAccelerator


class BMAccelerator(BoringModel):
    def on_fit_start(self):
        assert isinstance(self.trainer.accelerator, HPUAccelerator), self.trainer.accelerator


model = BMAccelerator
accelerator = HPUAccelerator()

if __name__ == "__main__":
    # Method 1: passing a supported accelerator class instance from an external library
    cli = LightningCLI(model, trainer_defaults={"accelerator": accelerator})
    # Method 2: passing the accelerator as a string
    cli = LightningCLI(model, trainer_defaults={"accelerator": "hpu"})
```
This gives the following tracebacks.

Method 1, passing a supported accelerator class instance from an external library:

```
Traceback (most recent call last):
  File "temp.py", line 34, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': HPUAccelerator()})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/agola/lightning-habana-fork/src/lightning_habana/pytorch/accelerator.py", line 50, in setup_device
    raise MisconfigurationException(f"Device should be HPU, got {device} instead.")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be HPU, got cpu instead.
```
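Method 1 fails inside `HPUAccelerator.setup_device`: the externally supplied accelerator instance does reach the strategy, but the root device has been resolved to `cpu`, so the accelerator's device guard rejects it. A minimal stdlib sketch of that kind of guard, modeled on the error message above (the class and device strings are hypothetical stand-ins, not Lightning code):

```python
class HPUAcceleratorSketch:
    """Hypothetical stand-in for an accelerator with a strict device check."""

    def setup_device(self, device: str) -> None:
        # Mirrors the check that raises in lightning_habana's accelerator:
        # the device handed over by the strategy must be an HPU device.
        if not device.startswith("hpu"):
            raise ValueError(f"Device should be HPU, got {device} instead.")


acc = HPUAcceleratorSketch()
acc.setup_device("hpu:0")  # accepted
try:
    acc.setup_device("cpu")  # what the strategy resolves to in Method 1
except ValueError as err:
    print(err)
```

The fix being asked for is that device resolution should follow the accelerator instance passed through `trainer_defaults`, rather than defaulting to CPU.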
Method 2, passing the accelerator as a string:

```
Traceback (most recent call last):
  File "temp.py", line 33, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': "hpu"})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 916, in _run
    call._call_lightning_module_hook(self, "on_fit_start")
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "temp.py", line 15, in on_fit_start
    assert isinstance(self.trainer.accelerator,
AssertionError: <lightning.pytorch.accelerators.hpu.HPUAccelerator object at 0x7f37f62917c0>
```
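The Method 2 assertion fails because the resolved class is `lightning.pytorch.accelerators.hpu.HPUAccelerator` (Lightning's internal one), not `lightning_habana.HPUAccelerator`: the string `"hpu"` appears to be looked up against Lightning's built-in accelerator names, shadowing the external package. A stdlib sketch of that lookup pattern; the registry dicts below are illustrative, not Lightning's actual data structures:

```python
# Hypothetical string -> class-path registries illustrating the shadowing.
BUILTIN_ACCELERATORS = {"hpu": "lightning.pytorch.accelerators.hpu.HPUAccelerator"}
EXTERNAL_ACCELERATORS = {"hpu": "lightning_habana.HPUAccelerator"}


def resolve_accelerator(name: str) -> str:
    """Resolve a string flag the way the traceback suggests happens here:
    the built-in entry wins, so the external class is never selected."""
    if name in BUILTIN_ACCELERATORS:
        return BUILTIN_ACCELERATORS[name]
    return EXTERNAL_ACCELERATORS[name]


resolved = resolve_accelerator("hpu")
```

Supporting external accelerators would mean letting an external registration (or an explicit class path) take precedence over, or replace, the built-in name.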
## Environment

```
lightning            2.0.0
lightning-fabric     2.0.3
lightning-habana     1.0.0
lightning-utilities  0.9.0
pytorch-lightning    2.0.5
```
cc @Borda