
[Feature Request] Optional BatchNorm integration in NatureCNN

Open Mahsarnzh opened this issue 6 months ago • 0 comments

🚀 Feature

Optional BatchNorm integration in NatureCNN

Motivation

Batch Normalization helps stabilize and accelerate training by reducing internal covariate shift, which is especially important in high-variance pixel-based environments like Atari games. By normalizing the activations after each convolutional layer, we expect smoother gradient flow, improved convergence speed, and reduced sensitivity to hyperparameters.

Alternatives Considered

  • LayerNorm: Normalizes across channels for each sample but doesn't leverage batch statistics; it proved slower to converge in our early trials.

  • GroupNorm: A middle ground between BatchNorm and LayerNorm that normalizes over groups of channels; it improved stability but added implementation complexity with similar runtime overhead.

BatchNorm offered the best trade-off of simplicity, runtime efficiency, and empirical performance.

Early Results

We ran PPO with NatureCNN + BatchNorm on Breakout (A.L.E.) for ~200 K timesteps:

| Iteration | Total Timesteps | Mean Episode Reward |
| --- | --- | --- |
| 1 | 8 192 | 1.85 |
| 4 | 32 768 | 6.87 |
| 8 | 65 536 | 10.10 |
| 16 | 131 072 | 14.70 |
| 25 | 204 800 | 15.10 |
| — | — | 18.40 ± 6.45 |

By 200 K timesteps, the agent achieves an average reward of 18.4 ± 6.5, demonstrating both faster early learning and higher final performance compared to the baseline without BatchNorm.

Proposed Implementation

  • Introduce a new `use_batch_norm: bool = False` argument in `NatureCNN.__init__`.

  • When `use_batch_norm=True`, insert `nn.BatchNorm2d` immediately after each convolutional layer:

    ```python
    layers = []
    layers.append(nn.Conv2d(...))
    if use_batch_norm:
        layers.append(nn.BatchNorm2d(...))
    layers.append(nn.ReLU())
    # repeat for each conv block
    ```
  • Default behavior remains unchanged (`use_batch_norm=False`), ensuring full backward compatibility; a fuller sketch follows after this list.
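
For concreteness, here is a minimal sketch of the proposal written as a custom feature extractor, assuming the current `BaseFeaturesExtractor` API. The conv shapes mirror NatureCNN's defaults; the class name `BatchNormNatureCNN` and the `use_batch_norm` flag are part of this proposal, not existing library code.

```python
import torch as th
import torch.nn as nn
from gymnasium import spaces

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class BatchNormNatureCNN(BaseFeaturesExtractor):
    """NatureCNN variant with an optional BatchNorm2d after each conv layer (proposal sketch)."""

    def __init__(
        self,
        observation_space: spaces.Box,
        features_dim: int = 512,
        use_batch_norm: bool = False,
    ) -> None:
        super().__init__(observation_space, features_dim)
        in_channels = observation_space.shape[0]
        layers = []
        # (out_channels, kernel_size, stride) as in the original NatureCNN
        for out_channels, kernel_size, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride))
            if use_batch_norm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

        # Infer the flattened feature size with one dummy forward pass
        with th.no_grad():
            sample = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        return self.linear(self.cnn(observations))
```

One caveat worth flagging: BatchNorm keeps running statistics and behaves differently in train and eval mode, so rollout collection and gradient updates will not see identical normalization.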

Pitch

Enable an optional BatchNorm toggle in the NatureCNN feature extractor so users can easily turn on/off batch normalization after each convolutional layer, improving training stability and convergence in high-variance, image-based environments.
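
For reference, a hedged usage example with the `BatchNormNatureCNN` sketch above; the environment id and timestep budget are illustrative only.

```python
from stable_baselines3 import PPO

policy_kwargs = dict(
    features_extractor_class=BatchNormNatureCNN,  # hypothetical extractor sketched above
    features_extractor_kwargs=dict(use_batch_norm=True),
)
model = PPO("CnnPolicy", "BreakoutNoFrameskip-v4", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=200_000)
```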

Alternatives

By default, `use_batch_norm` is set to False, so there is zero performance or behavioral impact unless the flag is explicitly turned on. When enabled, BatchNorm leverages batch-level statistics to stabilize and accelerate learning on high-variance, image-based inputs.

Other normalization strategies I evaluated:

LayerNorm: Normalizes per sample across channels; it does not use batch statistics and led to slower convergence in our Atari benchmarks.

GroupNorm: Splits channels into groups for normalization; more stable than LayerNorm but it incurs extra complexity and similar runtime overhead.

Neither alternative matched the simplicity, efficiency, and empirical gains of opt-in BatchNorm, so we opted for a boolean flag that keeps it completely off by default; a small comparison sketch follows below.
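
To make the comparison concrete, all three strategies are drop-in substitutions at the same point in the layer stack. A hypothetical `make_norm` helper (illustrative only, not part of this proposal) could look like:

```python
import torch.nn as nn

def make_norm(kind: str, channels: int) -> nn.Module:
    # Hypothetical helper, only to illustrate the alternatives discussed above.
    if kind == "batch":
        return nn.BatchNorm2d(channels)
    if kind == "group":
        # num_groups=8 is a common default; channels must be divisible by it.
        return nn.GroupNorm(num_groups=8, num_channels=channels)
    if kind == "layer":
        # GroupNorm with a single group normalizes over (C, H, W),
        # i.e. a LayerNorm equivalent that works for variable spatial sizes.
        return nn.GroupNorm(num_groups=1, num_channels=channels)
    raise ValueError(f"unknown norm kind: {kind}")
```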

Additional context

No response

Checklist

  • [x] I have checked that there is no similar issue in the repo
  • [x] If I'm requesting a new feature, I have proposed alternatives

Mahsarnzh · May 06 '25 21:05