rlkit
Investigate super-convergence on RL algorithms
I have been using these two routines to figure out the best learning rate, with awesome results on SAC. However, changes in the temperature alter those values along the way, so it would probably be a good idea to extend this further into some sort of 'automatic' rediscovery of the LR every x epochs. Note that this version also messes up the gradients, so you cannot keep using the policy after you run it.
```python
import math  # needed for math.log10 below

def find_policy_lr_step(self, loss):
    self.find_lr_batch_num += 1
    if self.find_lr_batch_num == 1:
        self.find_lr_avg_loss = 0.0
        self.find_lr_worst_loss = loss.item()
        self.find_lr_best_loss = loss.item()
        self.find_lr_best_lr = self.policy_optimizer.param_groups[0]['lr']
        self.find_lr_worst_lr = self.policy_optimizer.param_groups[0]['lr']
    # Bias-corrected exponential moving average of the loss.
    self.find_lr_avg_loss = self.find_lr_beta * self.find_lr_avg_loss + (1 - self.find_lr_beta) * loss.item()
    smoothed_loss = self.find_lr_avg_loss / (1 - self.find_lr_beta ** self.find_lr_batch_num)
    # Record the best loss (skip the first 10% of batches so the average can settle).
    if self.find_lr_batch_num > self.find_lr_batches // 10 and smoothed_loss < self.find_lr_best_loss:
        self.find_lr_best_lr = self.find_lr_current_lr
        self.find_lr_best_loss = smoothed_loss
    # We only record the worst loss at the start (we don't care about the divergent part).
    if self.find_lr_batch_num < self.find_lr_batches // 5:
        self.find_lr_worst_loss = max(smoothed_loss, self.find_lr_worst_loss)
    # Stop once the sweep has run for the configured number of batches.
    if self.find_lr_batch_num > self.find_lr_batches:
        import matplotlib.pyplot as plt
        plt.plot(self.find_lr_log_lrs, self.find_lr_losses)
        plt.show()
        # TODO: This is a simplistic heuristic until we do it properly with gradient analysis.
        printout(f'The best learning rate for the network could be around: {self.find_lr_best_lr / 10}')
        printout('The process will exit because finding the learning rate degenerates your gradients')
        exit(0)
    # Store the values unless we are already diverging.
    if smoothed_loss <= self.find_lr_worst_loss:
        self.find_lr_losses.append(smoothed_loss)
        self.find_lr_log_lrs.append(math.log10(self.policy_optimizer.param_groups[0]['lr']))
    # Update the optimizer with the new learning rate.
    self.find_lr_current_lr *= self.find_lr_multiplier
    self.policy_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr
```
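For context, `smoothed_loss` above is a bias-corrected exponential moving average, the same smoothing the fastai LR finder uses. A standalone sketch of just that computation (`beta` here stands in for `find_lr_beta`; 0.98 is an illustrative default, not a value taken from the code above):

```python
def smoothed_losses(losses, beta=0.98):
    """Bias-corrected exponential moving average of a loss sequence.

    The raw average starts at 0, which biases early values toward 0;
    dividing by (1 - beta**i) corrects that bias.
    """
    avg = 0.0
    out = []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        out.append(avg / (1 - beta ** i))
    return out
```

With the bias correction, a constant loss sequence stays constant in the smoothed view instead of ramping up from zero.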
```python
def find_qfunc_lr_step(self, qf1_loss, qf2_loss):
    self.find_lr_batch_num += 1
    if self.find_lr_batch_num == 1:
        self.find_lr_avg_loss = 0.0
        self.find_lr_worst_loss = min(qf1_loss.item(), qf2_loss.item())
        self.find_lr_best_loss = min(qf1_loss.item(), qf2_loss.item())
        self.find_lr_best_lr = self.qf1_optimizer.param_groups[0]['lr']
        self.find_lr_worst_lr = self.qf1_optimizer.param_groups[0]['lr']
    # Bias-corrected exponential moving average over the smaller of the two Q losses.
    self.find_lr_avg_loss = self.find_lr_beta * self.find_lr_avg_loss + (1 - self.find_lr_beta) * min(qf1_loss.item(), qf2_loss.item())
    smoothed_loss = self.find_lr_avg_loss / (1 - self.find_lr_beta ** self.find_lr_batch_num)
    # Record the best loss (skip the first 10% of batches so the average can settle).
    if self.find_lr_batch_num > self.find_lr_batches // 10 and smoothed_loss < self.find_lr_best_loss:
        self.find_lr_best_lr = self.find_lr_current_lr
        self.find_lr_best_loss = smoothed_loss
    # We only record the worst loss at the start (we don't care about the divergent part).
    if self.find_lr_batch_num < self.find_lr_batches // 5:
        self.find_lr_worst_loss = max(smoothed_loss, self.find_lr_worst_loss)
    # Stop once the sweep has run for the configured number of batches.
    if self.find_lr_batch_num > self.find_lr_batches:
        import matplotlib.pyplot as plt
        plt.plot(self.find_lr_log_lrs, self.find_lr_losses)
        plt.show()
        # TODO: This is a simplistic heuristic until we do it properly with gradient analysis.
        printout(f'The best learning rate for the Q-function approximator could be around: {self.find_lr_best_lr / 10}')
        printout('The process will exit because finding the learning rate degenerates your gradients')
        exit(0)
    # Store the values unless we are already diverging.
    if smoothed_loss <= self.find_lr_worst_loss:
        self.find_lr_losses.append(smoothed_loss)
        self.find_lr_log_lrs.append(math.log10(self.qf1_optimizer.param_groups[0]['lr']))
    # Update both Q-function optimizers with the new learning rate.
    self.find_lr_current_lr *= self.find_lr_multiplier
    self.qf1_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr
    self.qf2_optimizer.param_groups[0]['lr'] = self.find_lr_current_lr
```
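The code above doesn't show how `find_lr_multiplier` gets set. One way (a sketch, not the actual rlkit code) is to derive it so the learning rate sweeps geometrically from a start to an end value over the budgeted number of batches; the `lr_start`/`lr_end` names here are illustrative assumptions:

```python
def lr_sweep(lr_start=1e-7, lr_end=10.0, num_batches=100):
    """Return the per-batch multiplier and the resulting LR schedule
    so the learning rate grows geometrically from lr_start to lr_end."""
    mult = (lr_end / lr_start) ** (1.0 / num_batches)
    lrs = [lr_start * mult ** i for i in range(num_batches + 1)]
    return mult, lrs
```

A geometric (rather than linear) ramp is what makes the log-LR vs. loss plot informative: each batch contributes an equally sized step on the log axis.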
Thanks for the post! I'm a bit unsure what you are asking. Are you asking that others or I try this out, or merge this code in? Or were you asking for feedback?
Also, if you have example plots for the performance of this on specific environments, it would help.
I don't do research; there are probably a lot of things left to do before this is worth publishing. I'm just letting you know that it shows promising results in my limited trials, so this is more of an observation.
I see. Thanks for sharing! Would you mind posting your results here?
Sorry, I would really like to, but I am under NDA for this stuff. What I can say (which is general enough) is that even though the source data is very difficult to make converge using general methods (in the supervised case too), with super-convergence effects I was able to steer the policy quite rapidly, in the same way I can in the supervised case. I am now training supervised neural networks in under 100 minutes to the same accuracy that took multiple days six months ago.