
Feature Request: Add learning_rate_schedule_fn argument to ppo.train

Open JummerCloth opened this issue 6 months ago • 0 comments

Hi Brax team,

JummerCloth again. Thanks for adding the parametric action distribution! I’d like to suggest another small but useful enhancement to ppo.train: a new learning_rate_schedule_fn argument that gives more flexible control over the learning rate schedule. This would make it easier to plug in schedules like optax.cosine_decay_schedule, or any custom schedule, without modifying the core training logic.

Motivation

Right now, ppo.train uses a plain Adam optimizer with a constant learning rate. Although Adam adapts its per-parameter step sizes, the fixed base learning rate limits experimentation and fine-tuning in practice, especially for users who want to explore different scheduling strategies (e.g., cosine decay, piecewise constant, exponential decay). We benchmarked brax against rsl_rl and found that the adaptive learning rate in rsl_rl helped a lot on our tasks and made rsl_rl outperform brax. An adaptive learning rate like the one in rsl_rl may be challenging to implement in brax, but we found that a simple learning_rate_schedule_fn, such as cosine decay, is enough to significantly improve performance on our tasks with brax.

Proposed Change

Add a new optional argument to ppo.train:

def train(..., learning_rate_schedule_fn: Optional[Callable[[int], float]] = None, ...):

Then modify the learning rate schedule setup as follows:

if learning_rate_schedule_fn is None:
    # Default to a constant schedule if no custom schedule is provided
    learning_rate_schedule_fn = optax.constant_schedule(value=learning_rate)

optimizer = optax.adam(learning_rate=learning_rate_schedule_fn)

This change would be backwards compatible and would not affect existing users.

Example Usage

learning_rate_schedule_fn = optax.cosine_decay_schedule(
    init_value=0.001,
    decay_steps=100_000,
    alpha=0.1,
)

ppo.train(..., learning_rate_schedule_fn=learning_rate_schedule_fn)

Let me know if this would be a welcome addition. I’d be happy to contribute a PR if that’s helpful!

Thanks for the awesome work on Brax!

JummerCloth

JummerCloth, Jun 27 '25 06:06