
Fix for using Pruner trainer methods on GPU

Open shenoynikhil opened this pull request 2 years ago • 3 comments

Description

Resolves #4911

@J-shang Does this PR make sense?

Checklist

  • [ ] test case
  • [ ] doc

How to test

shenoynikhil · Jun 03 '22 19:06

Hello @shenoynikhil, could you show us an example to reproduce your problem? I think weight_mask should be on CUDA if you use the GPU.

You can check this after the pruner is initialized:

pruner = TaylorFOWeightPruner(model, config_list, ...)
for name, buffer in model.named_buffers():
    print(name, buffer.device)
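
If the buffers print as cpu even though training runs on the GPU, one common cause is constructing the pruner before moving the model to the device. Below is a minimal sketch of that ordering; model, config_list, trainer, traced_optimizer, and criterion are placeholders rather than code from this thread:

import torch
from nni.compression.pytorch.pruning import TaylorFOWeightPruner

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model first so the mask buffers register on the same device

pruner = TaylorFOWeightPruner(model, config_list, trainer,
                              traced_optimizer, criterion,
                              training_batches=20)
for name, buffer in model.named_buffers():
    print(name, buffer.device)  # each weight_mask/bias_mask buffer should report the GPU device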

J-shang · Jun 04 '22 02:06

Hi @J-shang

I use PyTorch Lightning for modelling. My setup is as follows:

pruner = TaylorFOWeightPruner(
    None,  # model is supplied later; the scheduler resets it per task
    None,  # config_list is likewise supplied per task
    training_batches=training_batches,
    criterion=None,
    traced_optimizer=self._get_traced_optimizer(),  # will define this below
    trainer=self.taylorfo_trainer,  # will define this below
)
task_generator = AGPTaskGenerator(
    self.total_iteration,
    self.model,  # pl.LightningModule
    oc.to_container(self.config_list),
    log_dir=self.log_dir,
    keep_intermediate_result=self.keep_intermediate_result,
)
scheduler = PruningScheduler(
    pruner,
    task_generator,
    finetuner=self.finetuner,
    speedup=False,
    dummy_input=self.dummy_input,
    evaluator=self.evaluator if self.use_evaluator else None,
    reset_weight=False,
)

scheduler.compress()  # run the compression

# Helper functions: all of these are methods of a class used for pruning, so ignore the self
    def taylorfo_trainer(self, model, optimizer, criterion):
        """Helper trainer to be used with TaylorFOWeightPruner"""
        training_batches = self.additional_pruner_args.get("training_batches", 20)
        trainer = pl.Trainer(
            gpus=self.gpus, # 1 because I want to use GPU
            max_epochs=1,
            default_root_dir=None,
            logger=False,
            limit_train_batches=training_batches,
            num_sanity_val_steps=0,
        )
        log.info(f"Running Taylor Optimization with training batch(es) : {training_batches}")
        trainer.fit_loop.epoch_loop.batch_loop.connect(
            optimizer_loop=TaylorOptimizationLoop(optimizer, self.start_ckpt_path)
        )
        trainer.fit(model, self.train_dataloader)

class TaylorOptimizationLoop(OptimizerLoop):
    """TaylorOptimizationLoop using the nni traced optimizer"""

    def __init__(
        self, traced_optimizer: OptimizerConstructHelper, ckpt_path, device="cpu", *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.optimizer = traced_optimizer
        self.ckpt_path = ckpt_path
        self.device = device
        self._load_optimizer_state_dict()

    def _load_optimizer_state_dict(self):
        try:
            checkpoint = torch.load(self.ckpt_path)
            optimizer_states = checkpoint["optimizer_states"][0]
            self.optimizer.load_state_dict(optimizer_states)
        except Exception as e:
            print(f"Error loading optimizer state dict: {e}")
            raise e

    def advance(self, batch: Any, *args: Any, **kwargs: Any):
        loss = self.trainer.lightning_module.step(batch)["loss"]

        # Manual optimization step
        self.optimizer.zero_grad()
        loss.backward()
        # self.optimizer.step()  # not updating weights

        # Update progress
        self.optim_progress.optimizer_position += 1

def _get_traced_optimizer(self):
    # dict containing the optimizer class path and kwargs (I use Hydra :p)
    dict_without_target = oc.to_container(self.model.model_config.optimizer.copy())
    del dict_without_target["_target_"]
    return OptimizerConstructHelper.from_trace(
        self.model,
        nni.trace(eval(self.model.model_config.optimizer["_target_"]))(
            params=self.model.parameters(), **dict_without_target
        ),
    )

shenoynikhil98 · Jun 23 '22 14:06

Hello @shenoynikhil, thanks for your PR. We systematically support Lightning models in NNI v2.9, so I think this issue has been fixed there. Please give it a try, and I will close this PR.
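
For reference, a minimal sketch of the evaluator-style API that v2.9 documents for Lightning users; MyDataModule, model, and config_list are placeholder names, so check the exact signatures against the release docs:

import nni
import pytorch_lightning as pl
from nni.compression.pytorch import LightningEvaluator
from nni.compression.pytorch.pruning import TaylorFOWeightPruner

# trace the trainer and data module so NNI can re-create them internally
pl_trainer = nni.trace(pl.Trainer)(accelerator="gpu", devices=1, max_epochs=1)
pl_data = nni.trace(MyDataModule)()  # placeholder pl.LightningDataModule

evaluator = LightningEvaluator(pl_trainer, pl_data)
pruner = TaylorFOWeightPruner(model, config_list, evaluator, training_steps=20)
_, masks = pruner.compress()

With this interface there is no hand-written trainer callable or custom OptimizerLoop: NNI drives the traced Lightning Trainer itself, and device placement is meant to follow the trainer's accelerator settings.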

J-shang · Sep 07 '22 07:09