# Compare agents1.py (Python)
# Pasted from a GitHub issue opened by vicks4u (opened 3 months ago; 0 comments).

"""Compare PPO, A2C, and DQN agents on an MT5 trading environment.

  • Each agent is trained separately on the same historical bars.
  • Results (mean and standard deviation of episode reward) are logged."""

import time

import gym
import MetaTrader5 as mt5
import numpy as np
import pandas as pd
from gym import spaces
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import DummyVecEnv

# -------------------------
# CONFIG
# -------------------------

# Instrument whose bars are fetched and traded.
SYMBOL = "EURUSD"
# Bar timeframe requested from the terminal (M5 per MT5 constant naming).
TIMEFRAME = mt5.TIMEFRAME_M5
# Number of past closes in each observation window.
LOOKBACK = 50
# Environment steps each agent is trained for.
TRAIN_TIMESTEPS = 20000
# Number of most recent bars fetched from the terminal.
N_BARS = 5000
# Seed passed to every stable-baselines3 model.
SEED = 42

# -------------------------
# MT5 Connection
# -------------------------

def mt5_connect():
    """Initialize the MetaTrader 5 terminal connection and select ``SYMBOL``.

    Raises:
        RuntimeError: If the terminal cannot be initialized, or if ``SYMBOL``
            cannot be selected. In the latter case the connection is shut
            down first, so a failed connect never leaves it open.
    """
    if not mt5.initialize():
        raise RuntimeError(f"MT5 init failed: {mt5.last_error()}")
    if not mt5.symbol_select(SYMBOL, True):
        # initialize() succeeded, so release the connection before bailing out;
        # callers that see the exception have no reason to call mt5_shutdown().
        mt5.shutdown()
        raise RuntimeError(f"Could not select {SYMBOL}")

def mt5_shutdown():
    """Close the connection to the MetaTrader 5 terminal."""
    terminal_shutdown = mt5.shutdown
    terminal_shutdown()

def fetch_bars(symbol, timeframe, n_bars):
    """Fetch up to ``n_bars`` bars of ``symbol`` starting at position 0.

    Args:
        symbol: Instrument name understood by the terminal.
        timeframe: An ``mt5.TIMEFRAME_*`` constant.
        n_bars: Number of bars to request.

    Returns:
        A DataFrame of the returned rates, with the ``time`` column converted
        from epoch seconds to pandas datetimes.

    Raises:
        RuntimeError: If the terminal returns no data (``None`` or an empty
            array). An empty result is rejected here instead of surfacing
            later as an obscure indexing error in the environment.
    """
    rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, n_bars)
    if rates is None or len(rates) == 0:
        raise RuntimeError(f"Failed to fetch data: {mt5.last_error()}")
    df = pd.DataFrame(rates)
    df["time"] = pd.to_datetime(df["time"], unit="s")
    return df

# -------------------------
# Custom Gym Env
# -------------------------

class MT5TradingEnv(gym.Env):
    """Single-instrument trading environment replaying a DataFrame of bars.

    Observation (float32, length ``lookback + 1``): the last ``lookback``
    closes, each divided by the newest of them and shifted by -1, followed by
    the current position (-1 short, 0 flat, +1 long).

    Actions (``Discrete(3)``): 0 = hold, 1 = buy, 2 = sell.

    Uses the classic ``gym`` API: ``reset()`` returns only the observation,
    and ``step()`` returns ``(obs, reward, done, info)``.
    """

    def __init__(self, df, lookback=LOOKBACK):
        """Create the environment.

        Args:
            df: Bar data containing a ``"close"`` column. Its index is reset,
                so label-based ``.loc`` lookups below are positional.
            lookback: Number of past closes included in each observation.

        Raises:
            ValueError: If ``lookback`` < 1, or if ``df`` does not have more
                rows than ``lookback``. With too few rows there is no bar left
                to trade on after the first observation window.
        """
        super().__init__()
        if lookback < 1:
            raise ValueError(f"lookback must be >= 1, got {lookback}")
        if len(df) <= lookback:
            raise ValueError(
                f"df must have more than lookback={lookback} rows, got {len(df)}"
            )
        self.df = df.reset_index(drop=True)
        self.lookback = lookback
        self.ptr = lookback  # index of the bar the next action trades at
        self.position = 0
        self.entry_price = 0.0
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(lookback + 1,), dtype=np.float32
        )
        self.action_space = spaces.Discrete(3)  # 0=hold, 1=buy, 2=sell

    def _get_obs(self):
        """Return the normalized close window plus the current position."""
        # .loc slicing is end-inclusive: rows ptr-lookback .. ptr-1 (lookback rows).
        closes = self.df.loc[
            self.ptr - self.lookback:self.ptr - 1, "close"
        ].to_numpy(dtype=np.float32)
        # Scale by the newest close so the window does not depend on price level;
        # the epsilon guards against division by zero.
        norm = closes / (closes[-1] + 1e-9) - 1.0
        # Concatenating with a Python float list upcasts to float64; cast back so
        # the observation matches observation_space's float32 dtype.
        return np.concatenate([norm, [float(self.position)]], axis=0).astype(
            np.float32, copy=False
        )

    def reset(self):
        """Rewind to the first tradable bar, go flat, and return the first obs."""
        self.ptr = self.lookback
        self.position = 0
        self.entry_price = 0.0
        return self._get_obs()

    def step(self, action):
        """Apply ``action`` at the current bar's close, then advance one bar.

        Trade rules:
        - Buying (selling) while flat opens a long (short) at this close.
        - Buying while short (or selling while long) realizes that position's
          PnL and flips to the opposite side at this close.
        - Repeating the current side, or holding, changes nothing.

        After advancing, if data remains, 0.1 x the open position's PnL
        measured from its entry price to the next close is added to the reward.

        Returns:
            ``(obs, reward, done, info)``. Once the data is exhausted, ``done``
            is True and ``obs`` is all zeros.
        """
        done, reward = False, 0.0
        price = float(self.df.loc[self.ptr, "close"])
        if action == 1:  # buy
            if self.position == 0:
                self.position, self.entry_price = 1, price
            elif self.position == -1:
                reward += self.entry_price - price  # realize the short
                self.position, self.entry_price = 1, price
        elif action == 2:  # sell
            if self.position == 0:
                self.position, self.entry_price = -1, price
            elif self.position == 1:
                reward += price - self.entry_price  # realize the long
                self.position, self.entry_price = -1, price

        self.ptr += 1
        if self.ptr >= len(self.df):
            # NOTE(review): a position still open here is neither closed nor
            # rewarded, so its final PnL never appears in the return. Confirm
            # this is intended.
            done = True
        else:
            next_price = float(self.df.loc[self.ptr, "close"])
            # NOTE(review): shaping uses PnL since *entry*, not since the last
            # bar, so it is added again every step the position is held and
            # grows with holding time. Confirm this is the intended shaping.
            if self.position == 1:
                reward += (next_price - self.entry_price) * 0.1
            elif self.position == -1:
                reward += (self.entry_price - next_price) * 0.1

        if done:
            obs = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            obs = self._get_obs()
        return obs, float(reward), done, {}

# -------------------------
# Training & Evaluation
# -------------------------

def run_comparison(n_eval_episodes=5):
    """Train PPO, A2C and DQN on the same MT5 bars and report their rewards.

    Fetches ``N_BARS`` bars of ``SYMBOL`` at ``TIMEFRAME`` once. For each
    algorithm it then trains a fresh model for ``TRAIN_TIMESTEPS`` steps and
    scores it with ``evaluate_policy``.

    Args:
        n_eval_episodes: Episodes used by ``evaluate_policy`` per agent.

    Returns:
        dict mapping algorithm name to ``(mean_reward, std_reward)``.
    """
    try:
        mt5_connect()
        df = fetch_bars(SYMBOL, TIMEFRAME, N_BARS)
    finally:
        # Always release the terminal, even if connecting or fetching fails.
        mt5_shutdown()

    results = {}
    agents = {
        "PPO": PPO,
        "A2C": A2C,
        "DQN": DQN,
    }

    for name, algo in agents.items():
        print(f"\n=== Training {name} ===")
        env = DummyVecEnv([lambda: MT5TradingEnv(df, lookback=LOOKBACK)])
        try:
            model = algo("MlpPolicy", env, verbose=0, seed=SEED)
            model.learn(total_timesteps=TRAIN_TIMESTEPS)
            # NOTE(review): evaluation replays the same bars used for training,
            # so these are in-sample scores.
            mean_reward, std_reward = evaluate_policy(
                model, env, n_eval_episodes=n_eval_episodes
            )
        finally:
            env.close()
        results[name] = (mean_reward, std_reward)
        print(f"{name} → mean reward: {mean_reward:.2f}, std: {std_reward:.2f}")

    print("\n=== Summary ===")
    for name, (mean_reward, std_reward) in results.items():
        print(f"{name}: mean {mean_reward:.2f}, std {std_reward:.2f}")
    return results

# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    run_comparison()

# Posted by vicks4u on Sep 16 '25 at 01:09.