nni
nni copied to clipboard
Fix for using Pruner trainer methods on GPU
Description
Resolves #4911
@J-shang Does this PR make sense?
Checklist
- [ ] test case
- [ ] doc
How to test
Thank you for your submission, we really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
:x: shenoynikhil sign now
You have signed the CLA already but the status is still pending? Let us recheck it.
Hello @shenoynikhil, could you show us an example to reproduce your problem? I think `weight_mask`
should be on cuda
if you use a GPU.
You could check this after the pruner is initialized by:
pruner = TaylorFOWeightPruner(model, config_list, ...)
for name, buffer in model.named_buffers():
print(name, buffer.device)
Hi @J-shang
I use PyTorch Lightning for modelling. My setup is as follows:
# Reproduction setup: iterative AGP pruning driven by a TaylorFO pruner.
# NOTE(review): model and config_list are passed as None here — presumably the
# AGPTaskGenerator supplies the model/config per iteration; TODO confirm against
# the nni iterative-pruning API.
pruner = TaylorFOWeightPruner(
    None,
    None,
    training_batches=training_batches,
    criterion=None,
    traced_optimizer=self._get_traced_optimizer(),  # will define this below
    trainer=self.taylorfo_trainer,  # will define this below
)
# Task generator that schedules sparsity targets across iterations.
task_generator = AGPTaskGenerator(
    self.total_iteration,
    self.model,  # pl.LightningModule
    oc.to_container(self.config_list),
    log_dir=self.log_dir,
    keep_intermediate_result=self.keep_intermediate_result,
)
# Scheduler ties the pruner and task generator together; speedup is disabled
# and weights are kept (reset_weight=False) between iterations.
scheduler = PruningScheduler(
    pruner,
    task_generator,
    finetuner=self.finetuner,
    speedup=False,
    dummy_input=self.dummy_input,
    evaluator=self.evaluator if self.use_evaluator else None,
    reset_weight=False,
)
scheduler.compress()  # to run compression
# Helper Function: all these are part of a class used for pruning, ignore the self
def taylorfo_trainer(self, model, optimizer, criterion):
    """Run a short PyTorch Lightning fit that feeds gradient data to the
    TaylorFOWeightPruner via a custom optimizer loop.

    The criterion argument is accepted for nni's trainer signature but the
    loss comes from the LightningModule's own ``step``.
    """
    n_batches = self.additional_pruner_args.get("training_batches", 20)
    # Minimal, logger-free trainer: a single epoch capped at n_batches.
    pl_trainer = pl.Trainer(
        gpus=self.gpus,  # 1 because I want to use GPU
        max_epochs=1,
        default_root_dir=None,
        logger=False,
        limit_train_batches=n_batches,
        num_sanity_val_steps=0,
    )
    log.info(f"Running Taylor Optimization with training batch(es) : {n_batches}")
    # Swap in the custom loop so the nni traced optimizer observes gradients.
    taylor_loop = TaylorOptimizationLoop(optimizer, self.start_ckpt_path)
    pl_trainer.fit_loop.epoch_loop.batch_loop.connect(optimizer_loop=taylor_loop)
    pl_trainer.fit(model, self.train_dataloader)
class TaylorOptimizationLoop(OptimizerLoop):
    """Custom PL optimizer loop driving the nni traced optimizer.

    Performs forward/backward only (no weight update) so the
    TaylorFOWeightPruner can collect first-order gradient statistics.

    Args:
        traced_optimizer: nni ``OptimizerConstructHelper`` wrapping the optimizer.
        ckpt_path: path to a Lightning checkpoint holding optimizer state.
        device: device to map the loaded optimizer state onto (default ``"cpu"``).
    """

    def __init__(
        self, traced_optimizer: OptimizerConstructHelper, ckpt_path, device="cpu", *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.optimizer = traced_optimizer
        self.ckpt_path = ckpt_path
        self.device = device
        self._load_optimizer_state_dict()

    def _load_optimizer_state_dict(self):
        """Restore optimizer state from ``self.ckpt_path``.

        Raises:
            Exception: re-raised after logging if the checkpoint cannot be
                loaded or lacks the expected ``optimizer_states`` entry.
        """
        try:
            # map_location avoids a hard failure when the checkpoint was saved
            # on a GPU but is being restored onto a different device — the
            # class of GPU/CPU mismatch this PR addresses.
            checkpoint = torch.load(self.ckpt_path, map_location=self.device)
            optimizer_states = checkpoint["optimizer_states"][0]
            self.optimizer.load_state_dict(optimizer_states)
        except Exception as e:
            print(f"Error loading optimizer state dict: {e}")
            # Bare raise preserves the original traceback unchanged.
            raise

    def advance(self, batch: Any, *args: Any, **kwargs: Any):
        """One manual optimization step: forward, zero-grad, backward.

        NOTE(review): ``self.optimizer.step()`` is intentionally not called —
        only gradients are needed for Taylor pruning, not weight updates.
        """
        loss = self.trainer.lightning_module.step(batch)["loss"]
        self.optimizer.zero_grad()
        loss.backward()
        # Advance loop progress so Lightning records that this optimizer ran.
        self.optim_progress.optimizer_position += 1
def _get_traced_optimizer(self):
    """Build an nni ``OptimizerConstructHelper`` from the Hydra optimizer config.

    Resolves the optimizer class from the config's ``_target_`` dotted path
    and traces it with the model's parameters plus the remaining config kwargs.
    """
    dict_without_target = oc.to_container(self.model.model_config.optimizer.copy())  # dict containing the class (I use Hydra :p)
    # Remove the Hydra target key so the rest can be splatted as optimizer kwargs.
    del dict_without_target["_target_"]
    # SECURITY NOTE(review): eval() on a config-supplied string executes
    # arbitrary code if the config is untrusted — consider
    # hydra.utils.get_class or importlib-based resolution instead.
    return OptimizerConstructHelper.from_trace(
        self.model,
        nni.trace(eval(self.model.model_config.optimizer["_target_"]))(
            params=self.model.parameters(), **dict_without_target
        ),
    )
Hello @shenoynikhil, thanks for your PR. We systematically support Lightning models in nni v2.9, so I think this issue has been fixed in v2.9 — please give it a try, and I will close this PR.