
Different results across HPC and PC while training RL using SBX

Open • Deepakgthomas opened this issue 6 months ago • 4 comments

Question

While training my RL algorithm with SBX, I get different results on my HPC cluster and on my PC. However, I did find that results are consistently the same within the same machine; they only diverge across machines. I am limiting all computation to the CPU.

I created a minimal working example to test my hypothesis. Please let me know if it contains a bug, such as a forgotten seed.

# simple_sbx_test.py
import jax
import numpy as np
import random
import os
import gymnasium as gym
from sbx import DQN
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import DummyVecEnv


def set_seed(seed):
   """Set seed for reproducibility."""
   os.environ['PYTHONHASHSEED'] = str(seed)
   random.seed(seed)
   np.random.seed(seed)


def make_env(env_name, seed):
   """Create environment with fixed seed"""
   def _init():
       env = gym.make(env_name)
       env.reset(seed=seed)
       return env
   return _init


def main():
   # Fixed seeds
   AGENT_SEED = 42
   ENV_SEED = 123
   EVAL_SEED = 456
   set_seed(AGENT_SEED)

   print("=== Simple SBX DQN Cross-Platform Test (JAX) ===")
   print(f"JAX: {jax.__version__}")
   print(f"NumPy: {np.__version__}")
   print(f"JAX devices: {jax.devices()}")
   print(f"Agent seed: {AGENT_SEED}, Env seed: {ENV_SEED}, Eval seed: {EVAL_SEED}")
   print("-" * 50)

   # Create environments
   train_env = DummyVecEnv([make_env("CartPole-v1", ENV_SEED)])
   eval_env = DummyVecEnv([make_env("CartPole-v1", EVAL_SEED)])

   # Create model
   model = DQN(
       "MlpPolicy",
       train_env,
       learning_rate=1e-3,
       buffer_size=10000,
       learning_starts=1000,
       batch_size=32,
       gamma=0.99,
       train_freq=4,
       target_update_interval=1000,
       exploration_initial_eps=1.0,
       exploration_final_eps=0.05,
       exploration_fraction=0.1,
       verbose=0,
       seed=AGENT_SEED
   )

   # Print initial model parameters (JAX uses params instead of weights)
   if hasattr(model, 'qf') and hasattr(model.qf, 'params'):
       print("Initial parameters available")
       # JAX parameters are nested dictionaries, harder to inspect directly
       print("  Model initialized successfully")

   # Evaluation callback
   eval_callback = EvalCallback(
       eval_env,
       best_model_save_path=None,
       log_path=None,
       eval_freq=2000,
       n_eval_episodes=10,
       deterministic=True,
       render=False,
       verbose=1  # Enable to see evaluation results
   )

   # Train
   print("\nTraining...")
   model.learn(total_timesteps=10000, callback=eval_callback)

   print("Training completed")

   # Final evaluation
   print("\nFinal evaluation:")
   rewards = []
   for i in range(10):
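       obs = eval_env.reset()
       total_reward = 0.0
       done = False
       while not done:
           action, _ = model.predict(obs, deterministic=True)
           # DummyVecEnv returns one entry per env; index 0 is the only env here
           obs, reward, dones, info = eval_env.step(action)
           total_reward += float(reward[0])
           done = bool(dones[0])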
       rewards.append(total_reward)
       print(f"Episode {i + 1}: {total_reward}")

   print(f"\nFinal Results:")
   print(f"Mean reward: {np.mean(rewards):.2f}")
   print(f"Std reward: {np.std(rewards):.2f}")
   print(f"All rewards: {rewards}")


if __name__ == "__main__":
   main()
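Since the script only reports final rewards, it cannot show where the two machines first diverge. A minimal sketch of one way to check this: hash the parameter tree at fixed points (after initialization, after each evaluation) and compare the digests across machines. The `fingerprint` helper and the stand-in `params` dictionary are hypothetical; where the real parameter tree lives on the SBX model depends on the SBX version and is not shown here.

```python
import hashlib
from collections.abc import Mapping

import numpy as np


def fingerprint(tree):
    """Return a short SHA-256 digest of a nested mapping/list of arrays.

    Hypothetical helper: the comparison is bit-exact, so any difference in
    the last bit of any value changes the digest.
    """
    digest = hashlib.sha256()

    def visit(node):
        if isinstance(node, Mapping):
            # Sort keys so the digest does not depend on insertion order.
            for key in sorted(node, key=str):
                digest.update(str(key).encode())
                visit(node[key])
        elif isinstance(node, (list, tuple)):
            for item in node:
                visit(item)
        else:
            # np.asarray accepts NumPy arrays, Python scalars, and JAX arrays.
            arr = np.ascontiguousarray(np.asarray(node))
            digest.update(str(arr.dtype).encode())
            digest.update(str(arr.shape).encode())
            digest.update(arr.tobytes())

    visit(tree)
    return digest.hexdigest()[:16]


# Stand-in for a real parameter tree (assumption: nested dict of arrays).
params = {
    "Dense_0": {
        "kernel": np.ones((4, 2), dtype=np.float32),
        "bias": np.zeros(2, dtype=np.float32),
    }
}
print("params fingerprint:", fingerprint(params))
```

If the digests already differ right after model creation, the divergence is in initialization; if they match at first and differ only after some training, it is more likely numerical drift in the update step.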

This is the result on my PC:

Final evaluation:
Episode 1: 208.0
Episode 2: 237.0
Episode 3: 200.0
Episode 4: 242.0
Episode 5: 206.0
Episode 6: 334.0
Episode 7: 278.0
Episode 8: 235.0
Episode 9: 248.0
Episode 10: 206.0

And this is the result on my HPC cluster:

Final evaluation:
Episode 1: 201.0
Episode 2: 256.0
Episode 3: 193.0
Episode 4: 218.0
Episode 5: 192.0
Episode 6: 326.0
Episode 7: 239.0
Episode 8: 226.0
Episode 9: 237.0
Episode 10: 201.0
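To quantify the gap between the two runs, the episode returns listed above can be compared directly. This is only a small sketch over the numbers reported in this issue; the variable names are illustrative.

```python
import numpy as np

# Final-evaluation returns copied from the two listings in this issue.
pc = np.array([208, 237, 200, 242, 206, 334, 278, 235, 248, 206], dtype=float)
hpc = np.array([201, 256, 193, 218, 192, 326, 239, 226, 237, 201], dtype=float)

diff = pc - hpc
print(f"PC mean:  {pc.mean():.1f}")
print(f"HPC mean: {hpc.mean():.1f}")
print(f"Episodes that differ: {np.count_nonzero(diff)} of {len(diff)}")
```

Every one of the ten episodes differs, including the first. Because evaluation uses `deterministic=True`, this suggests the two trained policies themselves differ, rather than the evaluation loop alone.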

Checklist

  • [x] I have read the documentation (required)
  • [x] I have checked that there is no similar issue in the repo (required)

Deepakgthomas, Aug 29 '25 16:08

Also, I read the documentation and tried using set_random_seed(AGENT_SEED). The result did not change.

Deepakgthomas, Aug 29 '25 18:08

> However, I did find that results are consistently the same within the same machine.

Then it sounds like a limitation of JAX, probably similar to PyTorch, which does not guarantee fully reproducible results across platforms.

araffin, Aug 30 '25 11:08
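One plausible mechanism behind such a limitation, not confirmed in this thread, is that floating-point addition is not associative. Different CPUs, instruction sets, or thread counts may sum the same values in a different order, and the tiny rounding differences are then amplified over thousands of gradient steps. A minimal sketch of the effect in float32:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# Same three numbers, two different orders of addition.
left = (a + b) + c   # a and b cancel exactly first, then c is added
right = a + (b + c)  # c is lost when added to b, because float32 spacing near 1e8 is 8

print("(a + b) + c =", left)
print("a + (b + c) =", right)
print("equal:", left == right)
```

On a single machine the order of operations is fixed, which is consistent with the runs being identical there while still differing across machines.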

Hi @Deepakgthomas, may I ask what the point is of giving different seeds to the agent and to the environment?

abuibaid, Sep 24 '25 08:09

@abuibaid Just to get a finer degree of control. Moreover, the environment comes from the gymnasium library, whereas the agent comes from sbx, so they do not necessarily share the same source of randomness.

Deepakgthomas, Sep 24 '25 11:09
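The "finer degree of control" can be illustrated with two independent NumPy generators standing in for the agent's and the environment's randomness. This is a toy sketch, not how sbx or gymnasium seed themselves internally; `rollout_noise` and its arguments are hypothetical names.

```python
import numpy as np


def rollout_noise(agent_seed, env_seed, n=5):
    """Draw n random numbers from two independently seeded streams."""
    agent_rng = np.random.default_rng(agent_seed)  # stands in for exploration noise
    env_rng = np.random.default_rng(env_seed)      # stands in for env initial states
    return agent_rng.random(n), env_rng.random(n)


# Change only the agent seed; keep the environment seed fixed.
agent_a, env_a = rollout_noise(agent_seed=42, env_seed=123)
agent_b, env_b = rollout_noise(agent_seed=7, env_seed=123)

print("env stream unchanged:", np.array_equal(env_a, env_b))
print("agent stream changed:", not np.array_equal(agent_a, agent_b))
```

With separate seeds, one source of randomness can be varied while the other is held fixed, which is useful when comparing agents on identical environment conditions.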