
Trainer with GPU based Model fails while creating Masks

Open · shenoynikhil opened this issue 3 years ago · 1 comment

Bug with GPU Model

Currently, when using pruning methods like TaylorFOWeightPruner, if I use a model on GPU for computing the metrics (used to generate the masks), it fails at the mask-creation step. The reason it fails is a device mismatch:

- metrics: Dict[str, Tensor], where each tensor is on cuda
- wrapper.weight_mask: a tensor on cpu

What would you like to be added:

# something like this, not necessarily the right answer
metric = metric.to(wrapper.weight_mask.device)

Why is this needed: Allows us to use GPU-based fitting (and therefore much faster metric computation) while generating masks.

Without this feature, how does current nni work: Fails when the model is on GPU.

Components that may involve changes: All Sparsity Allocators under nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
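To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not NNI code; `metric` and `weight_mask` are stand-ins for the real tensors). An elementwise op between a cuda tensor and a cpu tensor raises a RuntimeError; moving the metric to the mask's device first, as proposed above, avoids it:

```python
import torch

# Stand-ins for the real tensors: in the reported failure,
# `metric` lives on cuda while `weight_mask` lives on cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
metric = torch.rand(3, 3, device=device)
weight_mask = torch.ones(3, 3)  # always on cpu

if metric.device != weight_mask.device:
    # Proposed fix: move the metric to the mask's device before combining.
    metric = metric.to(weight_mask.device)

masked = metric * weight_mask  # safe now, whatever device `metric` started on
```

Without the `metric.to(...)` line, `metric * weight_mask` fails on a GPU machine with "Expected all tensors to be on the same device".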

shenoynikhil avatar Jun 03 '22 07:06 shenoynikhil

Hi @J-shang

I use PyTorch Lightning for modeling. My setup is as follows:

pruner = TaylorFOWeightPruner(
    None,
    None,
    training_batches=training_batches,
    criterion=None,
    traced_optimizer=self._get_traced_optimizer(),  # defined below
    trainer=self.taylorfo_trainer,  # defined below
)
task_generator = AGPTaskGenerator(
    self.total_iteration,
    self.model,  # a pl.LightningModule
    oc.to_container(self.config_list),
    log_dir=self.log_dir,
    keep_intermediate_result=self.keep_intermediate_result,
)
scheduler = PruningScheduler(
    pruner,
    task_generator,
    finetuner=self.finetuner,
    speedup=False,
    dummy_input=self.dummy_input,
    evaluator=self.evaluator if self.use_evaluator else None,
    reset_weight=False,
)

scheduler.compress() # to run compression

# Helper functions: these are all methods of a class used for pruning, so ignore the self
    def taylorfo_trainer(self, model, optimizer, criterion):
        """Helper trainer to be used with TaylorFOWeightPruner"""
        training_batches = self.additional_pruner_args.get("training_batches", 20)
        trainer = pl.Trainer(
            gpus=self.gpus, # 1 because I want to use GPU
            max_epochs=1,
            default_root_dir=None,
            logger=False,
            limit_train_batches=training_batches,
            num_sanity_val_steps=0,
        )
        log.info(f"Running Taylor Optimization with training batch(es) : {training_batches}")
        trainer.fit_loop.epoch_loop.batch_loop.connect(
            optimizer_loop=TaylorOptimizationLoop(optimizer, self.start_ckpt_path)
        )
        trainer.fit(model, self.train_dataloader)

class TaylorOptimizationLoop(OptimizerLoop):
    """TaylorOptimizationLoop using the nni traced optimizer"""

    def __init__(
        self, traced_optimizer: OptimizerConstructHelper, ckpt_path, device="cpu", *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.optimizer = traced_optimizer
        self.ckpt_path = ckpt_path
        self.device = device
        self._load_optimizer_state_dict()

    def _load_optimizer_state_dict(self):
        try:
            checkpoint = torch.load(self.ckpt_path)
            optimizer_states = checkpoint["optimizer_states"][0]
            self.optimizer.load_state_dict(optimizer_states)
        except Exception as e:
            print(f"Error loading optimizer state dict: {e}")
            raise e

    def advance(self, batch: Any, *args: Any, **kwargs: Any):
        loss = self.trainer.lightning_module.step(batch)["loss"]

        # Manual optimization step
        self.optimizer.zero_grad()
        loss.backward()
        # self.optimizer.step()  # intentionally skipped: weights are not updated

        # Update progress
        self.optim_progress.optimizer_position += 1
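The advance method above performs a gradient-only pass: backward populates the gradients but optimizer.step() is skipped, because the Taylor-FO importance score only needs weight-gradient products, not updated weights. A minimal, self-contained sketch of that idea in plain PyTorch (the toy model and the (weight * grad)^2 formula here are illustrative, not NNI's exact implementation):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)            # toy model, stands in for the real network
x, y = torch.randn(8, 4), torch.randn(8, 2)

before = model.weight.detach().clone()

# Gradient-only pass: backward fills .grad, but with no optimizer.step()
# the weights themselves stay untouched.
model.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# First-order Taylor importance per weight (illustrative formula):
# the effect of zeroing a weight is approximated by (weight * grad)^2.
importance = (model.weight * model.weight.grad) ** 2

assert torch.equal(model.weight.detach(), before)  # weights unchanged
```

This is why the loop only advances `optimizer_position` instead of stepping: the pruner consumes the accumulated gradients to rank weights, and updating them mid-collection would corrupt the metric.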

def _get_traced_optimizer(self):
    # dict containing the optimizer class path (I use Hydra :p)
    dict_without_target = oc.to_container(self.model.model_config.optimizer.copy())
    del dict_without_target["_target_"]
    return OptimizerConstructHelper.from_trace(
        self.model,
        nni.trace(eval(self.model.model_config.optimizer["_target_"]))(
            params=self.model.parameters(), **dict_without_target
        ),
    )
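As an aside, the eval call on the `_target_` string can be replaced by an explicit dotted-path resolver. This is a hypothetical stdlib-only helper, not part of NNI or Hydra (Hydra also ships its own utilities for instantiating from `_target_`):

```python
import importlib

def resolve_class(path: str):
    """Resolve a dotted path like 'torch.optim.Adam' to the class itself."""
    module_path, _, name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), name)

# Demonstrated with a stdlib class; a path like "torch.optim.Adam"
# from the Hydra config would resolve the same way.
cls = resolve_class("collections.OrderedDict")
```

This avoids executing arbitrary config strings with eval while keeping the config-driven optimizer selection.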

shenoynikhil98 avatar Jun 23 '22 14:06 shenoynikhil98