DeepSpeed
[BUG] lr scheduler get_last_lr() does not work with fp16 enabled
**Describe the bug**
I am trying to get the current learning rate value during training for logging purposes using lr_scheduler.get_last_lr(). It works well when I am not using fp16 training, but when I enable fp16 I get this error:

```
Traceback (most recent call last):
  ...
    aml_log_metrics({"training_loss": loss.item(), "learning_rate": lr_scheduler.get_last_lr()[0]}, distributed_args[0])
  File "/azureml-envs/azureml_5c60934469eb9a6d2cc091063a233732/lib/python3.8/site-packages/deepspeed/runtime/lr_schedules.py", line 766, in get_last_lr
    assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
AssertionError: need to call step() first
```
I am calling step() before get_last_lr(), and the call works when fp16 training is disabled.
**To Reproduce**
Steps to reproduce the behavior: I am using this type of loop, with fp16 enabled:

```python
model_engine, optimizer, trainloader, lr_scheduler = deepspeed.initialize(...)

for epoch in range(config['epochs']):
    for step, data in enumerate(loop):
        logits = model_engine(...)
        loss = compute_loss(...)
        model_engine.backward(loss)
        model_engine.step()
        print("learning_rate" + str(lr_scheduler.get_last_lr()[0]))
```
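In the meantime, a possible workaround is to log the learning rate straight from the optimizer's param groups, which does not depend on the scheduler having set _last_lr. A minimal sketch, reusing the loop above; current_lr is a hypothetical helper:

```python
def current_lr(optimizer):
    # Each param group carries its own "lr" entry; log the first group's value.
    return optimizer.param_groups[0]["lr"]

# inside the training loop, after model_engine.step():
print("learning_rate: " + str(current_lr(optimizer)))
```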
**Expected behavior**
get_last_lr() should work with or without fp16.
**System info (please complete the following information):**
- OS: Ubuntu 18.04
- GPU count and types: two machines with x8 A100s each
- Python version: 3.8
- Any other relevant info about your setup: OpenMPI 4.1.0, CUDA 11.1, cuDNN 8, PyTorch 1.9
**Launcher context**: launching the experiment with Azure MPI
**Docker context**: mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.1-cudnn8-ubuntu18.04
**DeepSpeed config file**: distillation_deepspeed_config_fp16.txt
@mbetser, thanks for reporting this error. Can you please share a simple script and steps to reproduce this issue?
@tjruwase I have the same issue when using this repo to train the model; the only change is setting the precision to FP16.
Hi @mbetser, I was unable to reproduce the issue occurring only with fp16; I ran into the error regardless of data type. The solution is to pass in an lr scheduler either during initialization or through the ds_config, as shown here.
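For concreteness, a minimal sketch of both options, assuming a recent DeepSpeed version; MyModel, ds_config.json, and the scheduler parameters are illustrative placeholders, not taken from the original setup:

```python
import torch
import deepspeed

# Hypothetical model/optimizer setup; MyModel and ds_config.json stand in for
# your own model and DeepSpeed config file.
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Option 1: build the scheduler yourself and hand it to initialize().
# DeepSpeed then steps it inside model_engine.step(), and the returned
# lr_scheduler supports get_last_lr().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)
model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    config="ds_config.json",  # older DeepSpeed versions take config_params instead
)

# Option 2: declare the scheduler in the ds_config instead and omit the
# lr_scheduler argument; initialize() then builds and returns it. For example:
# {
#   "train_batch_size": 32,
#   "fp16": { "enabled": true },
#   "scheduler": {
#     "type": "WarmupLR",
#     "params": { "warmup_min_lr": 0, "warmup_max_lr": 1e-4, "warmup_num_steps": 1000 }
#   }
# }
```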
Closing because it's such an old issue.