lightning-Habana

LightningCLI support for external accelerators

Open · ankitgola005 opened this issue 2 years ago · 3 comments

🚀 Feature

LightningCLI support for external accelerators

Motivation

LightningCLI helps avoid boilerplate code for command-line tools. The current implementation does not appear to support external accelerators: it only accepts the accelerators that ship with the Lightning source.

Pitch

Extend LightningCLI to support external accelerators, such as HPUAccelerator from lightning-habana.
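One way such support could work is for the CLI to resolve accelerator strings through a registry that external packages can extend. The sketch below illustrates that idea in plain Python only; AcceleratorRegistry, register, resolve and the toy accelerator classes are all hypothetical and are not part of Lightning's or lightning-habana's API.

```python
from typing import Dict, Type


class Accelerator:
    """Hypothetical stand-in for an accelerator base class."""


class BuiltinHPUAccelerator(Accelerator):
    """Plays the role of an accelerator that ships with the framework."""


class ExternalHPUAccelerator(Accelerator):
    """Plays the role of an accelerator provided by an external package."""


class AcceleratorRegistry:
    """Hypothetical mapping from CLI strings to accelerator classes."""

    def __init__(self) -> None:
        self._entries: Dict[str, Type[Accelerator]] = {}

    def register(self, name: str, cls: Type[Accelerator], override: bool = False) -> None:
        # Refuse silent name collisions unless the caller explicitly opts in.
        if name in self._entries and not override:
            raise KeyError(f"accelerator {name!r} is already registered")
        self._entries[name] = cls

    def resolve(self, name: str) -> Accelerator:
        # Instantiate whichever class is currently registered under `name`.
        cls = self._entries.get(name)
        if cls is None:
            raise ValueError(f"unknown accelerator {name!r}; known: {sorted(self._entries)}")
        return cls()


registry = AcceleratorRegistry()
registry.register("hpu", BuiltinHPUAccelerator)  # the framework's own entry
registry.register("hpu", ExternalHPUAccelerator, override=True)  # an external package takes over
print(type(registry.resolve("hpu")).__name__)  # ExternalHPUAccelerator
```

The explicit override flag is one possible design choice: it lets an external package replace a built-in name such as "hpu" deliberately, rather than the string silently resolving to whichever class the framework registered first.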

Alternatives

Additional context

First mentioned in #54

To reproduce:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.cli import LightningCLI
from lightning_habana import HPUAccelerator

class BMAccelerator(BoringModel):
    def on_fit_start(self):
        assert isinstance(self.trainer.accelerator, HPUAccelerator), self.trainer.accelerator

model = BMAccelerator
accelerator = HPUAccelerator()

if __name__ == "__main__":

    # Run one method at a time; comment out the other.

    # Method 1: pass an instance of a supported accelerator class from an external library
    cli = LightningCLI(model, trainer_defaults={'accelerator': accelerator})

    # Method 2: pass the accelerator as a string
    # cli = LightningCLI(model, trainer_defaults={'accelerator': 'hpu'})

Running each method separately gives the following tracebacks:

Method 1: passing an instance of a supported accelerator class from an external library

Traceback (most recent call last):
  File "temp.py", line 34, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': HPUAccelerator()})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 893, in _run
    self.strategy.setup_environment()
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/strategies/strategy.py", line 127, in setup_environment
    self.accelerator.setup_device(self.root_device)
  File "/home/agola/lightning-habana-fork/src/lightning_habana/pytorch/accelerator.py", line 50, in setup_device
    raise MisconfigurationException(f"Device should be HPU, got {device} instead.")
lightning.fabric.utilities.exceptions.MisconfigurationException: Device should be HPU, got cpu instead.
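The Method 1 traceback shows Strategy.setup_environment() calling accelerator.setup_device(self.root_device), and the HPU accelerator rejecting a cpu root device. The toy model below mimics that call path in plain Python, under the assumption (suggested by the error message, not confirmed) that the strategy's root device falls back to cpu when it is not configured for the external accelerator. ToyStrategy, ToyHPUAccelerator and Device are hypothetical names, and ValueError stands in for MisconfigurationException.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    """Minimal stand-in for torch.device: only the device type matters here."""
    type: str


class ToyHPUAccelerator:
    def setup_device(self, device: Device) -> None:
        # Mirrors the check that fails in the traceback: only HPU devices are accepted.
        if device.type != "hpu":
            raise ValueError(f"Device should be HPU, got {device.type} instead.")


class ToyStrategy:
    def __init__(self, accelerator: ToyHPUAccelerator, root_device: Device = Device("cpu")) -> None:
        # Assumption: without explicit configuration the root device defaults to CPU.
        self.accelerator = accelerator
        self.root_device = root_device

    def setup_environment(self) -> None:
        # Same call shape as Strategy.setup_environment in the traceback.
        self.accelerator.setup_device(self.root_device)


try:
    ToyStrategy(ToyHPUAccelerator()).setup_environment()
except ValueError as err:
    print(err)  # Device should be HPU, got cpu instead.

# With an HPU root device the same accelerator accepts the setup.
ToyStrategy(ToyHPUAccelerator(), root_device=Device("hpu")).setup_environment()
```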

Method 2: passing the accelerator as a string

Traceback (most recent call last):
  File "temp.py", line 33, in <module>
    cli = LightningCLI(model, trainer_defaults={'accelerator': "hpu"})
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 916, in _run
    call._call_lightning_module_hook(self, "on_fit_start")
  File "/home/agola/anaconda3/envs/plt_3.8/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "temp.py", line 15, in on_fit_start
    assert isinstance(self.trainer.accelerator,
AssertionError: <lightning.pytorch.accelerators.hpu.HPUAccelerator object at 0x7f37f62917c0>
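In the Method 2 traceback, the object that fails the isinstance check is a lightning.pytorch.accelerators.hpu.HPUAccelerator, the accelerator bundled with Lightning, while on_fit_start checks against lightning_habana's HPUAccelerator. Two classes that share a name but come from different modules are unrelated types, so isinstance returns False. The self-contained sketch below shows this with classes built via type(); the module strings only label the classes and nothing is imported from either package.

```python
# Same class name, different defining modules (labels taken from the traceback).
BuiltinHPU = type("HPUAccelerator", (), {"__module__": "lightning.pytorch.accelerators.hpu"})
ExternalHPU = type("HPUAccelerator", (), {"__module__": "lightning_habana"})


def qualified_name(cls: type) -> str:
    # Module plus class name is what actually distinguishes the two types.
    return f"{cls.__module__}.{cls.__qualname__}"


acc = BuiltinHPU()  # plays the role of what the "hpu" string resolved to
print(BuiltinHPU.__name__ == ExternalHPU.__name__)  # True
print(isinstance(acc, ExternalHPU))  # False
print(qualified_name(type(acc)))  # lightning.pytorch.accelerators.hpu.HPUAccelerator
```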

Env

lightning                     2.0.0
lightning-fabric              2.0.3
lightning-habana              1.0.0
lightning-utilities           0.9.0
pytorch-lightning             2.0.5

ankitgola005 · Jul 19 '23

Hi! Thanks for your contribution, great first issue!

github-actions[bot] · Jul 19 '23

cc @Borda

jerome-habana · Jul 19 '23

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] · Sep 17 '23