Add example of reading learning rate from optimizer state
Closes #312.
I extended the Optax 101 notebook to also include an example of extracting the learning rate, based on #206.
This example uses the inject_hyperparams wrapper.
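For reference, here is a minimal sketch of how the wrapper exposes the learning rate (the schedule choice, optimizer, and parameter shapes below are illustrative placeholders, not the exact ones from the notebook):

import jax.numpy as jnp
import optax

# Any optax schedule works here; exponential decay is just an illustrative choice.
schedule = optax.exponential_decay(init_value=1e-2, transition_steps=100, decay_rate=0.99)

# Wrapping the optimizer factory with inject_hyperparams stores the resolved
# hyperparameters, including the per-step learning rate, in the optimizer state.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=schedule)

params = {'w': jnp.zeros(8), 'b': jnp.zeros(1)}
opt_state = optimizer.init(params)

# The current learning rate can then be read by name, here right after init()
# and likewise after every update() call during training.
print(opt_state.hyperparams['learning_rate'])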
Some additional changes include:
- better formatting of print outputs
- extracting the step() function so it can be reused later on (a sketch of it follows this list)
- making param generation deterministic, so outputs stay the same
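For context, the extracted step() function is roughly the following. This is a sketch that assumes the notebook's loss(params, batch, labels) function and the optimizer defined above; the exact code in the notebook may differ slightly:

import jax

@jax.jit
def step(params, opt_state, batch, labels):
  # Compute the loss and its gradients with respect to the parameters.
  loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
  # Transform the gradients and apply the resulting updates to the parameters.
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value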
I also tried passing the count from inner_state to the scheduler, but I found that approach quite hacky because it requires indexing into inner_state[0], so I have not included this version:
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # Hacky: relies on the position of the schedule's state inside
      # inner_state, which is not obvious from reading the code.
      count = opt_state.inner_state[0].count
      lr = schedule(count)
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.9f}')
  return params

params = fit(initial_params, optimizer)
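For comparison, with the inject_hyperparams wrapper the printing inside the training loop becomes a lookup by name rather than positional indexing. Roughly (a hypothetical variant, not the exact notebook code; the function name is made up for illustration):

def fit_with_injected_lr(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      # inject_hyperparams keeps the resolved learning rate in the state,
      # so there is no need to know the layout of inner_state.
      lr = opt_state.hyperparams['learning_rate']
      print(f'Step {i:3}, Loss: {loss_value:.3f}, Learning rate: {lr:.9f}')
  return params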
Please let me know what you think. Happy to make any changes!
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
Huge apologies for the long delay here. If you can sync, can we get this submitted?
No worries, synced!
@mtthss: this looks good to me. Do we have your green light to merge?