Feature Request: Add learning_rate_schedule_fn argument to ppo.train
Hi Brax team,
JummerCloth again. Thanks for adding the parametric action distribution! I’d like to suggest another small but useful feature enhancement to the ppo.train function: adding a new argument learning_rate_schedule_fn to allow more flexible control over the learning rate schedule. This would make it easier to plug in schedules like optax.cosine_decay_schedule or any other custom schedule without modifying the core training logic.
Motivation
Right now, ppo.train builds its optimizer with optax.adam and a single fixed learning_rate. Adam adapts per-parameter step sizes, but the base learning rate stays constant for the whole run, which limits experimentation and fine-tuning, especially for users who want to explore different scheduling strategies (e.g., cosine decay, piecewise constant, or exponential decay). We benchmarked brax against rsl_rl and found that the adaptive learning rate in rsl_rl helped a lot on our tasks and made rsl_rl outperform brax. A fully adaptive learning rate may be challenging to implement, but we found that a simple learning_rate_schedule_fn, such as cosine decay, is already enough for a significant improvement in our tasks' performance with brax.
Proposed Change
Add a new optional argument to ppo.train:
def train(..., learning_rate_schedule_fn: Optional[Callable[[int], float]] = None, ...):
Then modify the learning rate schedule setup as follows:
if learning_rate_schedule_fn is None:
    # Default to a constant schedule if no custom schedule is provided
    learning_rate_schedule_fn = optax.constant_schedule(value=learning_rate)
optimizer = optax.adam(learning_rate=learning_rate_schedule_fn)
This change would be backwards compatible and would not affect existing users.
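To illustrate the backwards-compatibility claim, here is a minimal standalone sketch of the fallback logic. It is plain Python, so it runs without optax. The helper name resolve_schedule is hypothetical and not part of Brax, and the constant lambda is only a stand-in for optax.constant_schedule.

```python
from typing import Callable, Optional


def resolve_schedule(
    learning_rate: float,
    learning_rate_schedule_fn: Optional[Callable[[int], float]] = None,
) -> Callable[[int], float]:
    """Hypothetical helper (not part of Brax) mirroring the proposed fallback."""
    if learning_rate_schedule_fn is None:
        # Stand-in for optax.constant_schedule(value=learning_rate):
        # every step gets the same learning rate, as in the current behaviour.
        return lambda step: learning_rate
    return learning_rate_schedule_fn


# Existing callers pass no schedule and keep a constant learning rate.
default_fn = resolve_schedule(3e-4)

# New callers can pass any callable mapping a step count to a learning rate,
# here a made-up schedule that halves the rate every 1000 steps.
custom_fn = resolve_schedule(3e-4, lambda step: 3e-4 * 0.5 ** (step // 1000))

for step in (0, 1000, 2000):
    print(step, default_fn(step), custom_fn(step))
```

With no schedule passed, the returned function ignores the step count, so existing training runs would see the same learning rate as before.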
Example Usage
learning_rate_schedule_fn = optax.cosine_decay_schedule(
    init_value=0.001,
    decay_steps=100_000,
    alpha=0.1,
)
ppo.train(..., learning_rate_schedule_fn=learning_rate_schedule_fn)
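For reference, here is a rough plain-Python sketch of what this schedule evaluates to. It assumes the cosine decay formula as I understand it from the optax docs, with the default exponent of 1. The function cosine_decay below is an illustrative stand-in, not the optax implementation. Note also that, as far as I understand, optax evaluates a schedule against the optimizer's update count, so decay_steps would be counted in gradient updates rather than environment steps.

```python
import math


def cosine_decay(step, init_value=0.001, decay_steps=100_000, alpha=0.1):
    """Illustrative stand-in for the schedule in the example above."""
    # Clamp progress so the rate stays at init_value * alpha after decay_steps.
    progress = min(step, decay_steps) / decay_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # alpha sets the floor as a fraction of init_value.
    return init_value * ((1.0 - alpha) * cosine + alpha)


for step in (0, 50_000, 100_000, 200_000):
    print(step, cosine_decay(step))
```

Under these assumptions the learning rate starts at 0.001, passes through roughly 0.00055 at the halfway point, and settles at 0.0001 (init_value * alpha) from step 100_000 onward.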
Let me know if this would be a welcome addition — I’d be happy to contribute a PR if that’s helpful!
Thanks for the awesome work on Brax!
JummerCloth