FinRL
Improvements I want to make in [finrl]->[agent]->[rllib]->[models.py]
Here are the improvements made to the code:
1 - Imported with_common_config, Trainer, and COMMON_CONFIG to make the code cleaner and more concise.
2 - Utilized individual algorithm trainers from rllib.agents instead of importing them directly from their respective modules to maintain consistency and readability.
3 - Created a private method _get_default_config to handle retrieving the default configuration for each model, reducing code duplication.
4 - Improved error handling in the DRL_prediction method by catching exceptions and raising a ValueError with a meaningful error message.
# DRL models from RLlib
from __future__ import annotations

import ray
from ray.rllib.agents import with_common_config
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG

# Import individual algorithm trainers and their default configs for easier access
from ray.rllib.agents.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG
from ray.rllib.agents.ddpg import DDPGTrainer, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.sac import SACTrainer, DEFAULT_CONFIG as SAC_CONFIG
# TD3 lives in the ddpg package but has its own default config; re-importing
# DEFAULT_CONFIG from ray.rllib.agents.ddpg would clash with DDPG_CONFIG above.
from ray.rllib.agents.ddpg.td3 import TD3Trainer, TD3_DEFAULT_CONFIG as TD3_CONFIG

MODELS = {
    "a3c": A3CTrainer,
    "ddpg": DDPGTrainer,
    "td3": TD3Trainer,
    "sac": SACTrainer,
    "ppo": PPOTrainer,
}
class DRLAgent:
    """Implementations for DRL algorithms

    Attributes
    ----------
        env: gym environment class
            user-defined class
        price_array: numpy array
            OHLC data
        tech_array: numpy array
            technical data
        turbulence_array: numpy array
            turbulence/risk data

    Methods
    -------
        get_model()
            setup DRL algorithms
        train_model()
            train DRL algorithms in a train dataset
            and output the trained model
        DRL_prediction()
            make a prediction in a test dataset and get results
    """
    def __init__(self, env, price_array, tech_array, turbulence_array):
        self.env = env
        self.price_array = price_array
        self.tech_array = tech_array
        self.turbulence_array = turbulence_array

    def get_model(
        self,
        model_name,
        # policy="MlpPolicy",
        # policy_kwargs=None,
        # model_kwargs=None,
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")
        model = MODELS[model_name]
        model_config = self._get_default_config(model_name)
        # pass env, log_level, price_array, tech_array, and turbulence_array to config
        model_config["env"] = self.env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": self.price_array,
            "tech_array": self.tech_array,
            "turbulence_array": self.turbulence_array,
            "if_train": True,
        }
        return model, model_config
    def train_model(
        self, model, model_name, model_config, total_episodes=100, init_ray=True
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")
        if init_ray:
            ray.init(ignore_reinit_error=True)

        trainer = model(env=self.env, config=model_config)
        for _ in range(total_episodes):
            trainer.train()
        ray.shutdown()

        # save the trained model and return the trainer
        cwd = "./test_" + str(model_name)
        trainer.save(cwd)
        return trainer
    @staticmethod
    def DRL_prediction(
        model_name,
        env,
        price_array,
        tech_array,
        turbulence_array,
        agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
    ):
        if model_name not in MODELS:
            raise NotImplementedError(f"Model {model_name} is not supported.")
        model = MODELS[model_name]
        # DRL_prediction is a staticmethod, so the private helper is called on the class
        model_config = DRLAgent._get_default_config(model_name)
        model_config["env"] = env
        model_config["log_level"] = "WARN"
        model_config["env_config"] = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_config = {
            "price_array": price_array,
            "tech_array": tech_array,
            "turbulence_array": turbulence_array,
            "if_train": False,
        }
        env_instance = env(config=env_config)

        trainer = model(env=env, config=model_config)
        try:
            trainer.restore(agent_path)
            print("Restoring from checkpoint path", agent_path)
        except BaseException as e:
            raise ValueError("Failed to load agent!") from e

        # run the restored agent on the test environment and track total assets
        state = env_instance.reset()
        episode_returns = []
        episode_total_assets = [env_instance.initial_total_asset]
        done = False
        while not done:
            action = trainer.compute_single_action(state)
            state, reward, done, _ = env_instance.step(action)
            total_asset = (
                env_instance.amount
                + (env_instance.price_ary[env_instance.day] * env_instance.stocks).sum()
            )
            episode_total_assets.append(total_asset)
            episode_return = total_asset / env_instance.initial_total_asset
            episode_returns.append(episode_return)
        ray.shutdown()
        print("episode return: " + str(episode_return))
        print("Test Finished!")
        return episode_total_assets
    @staticmethod
    def _get_default_config(model_name):
        # return a copy so callers can mutate the config without touching the defaults
        if model_name == "a3c":
            return A3C_CONFIG.copy()
        elif model_name == "ddpg":
            return DDPG_CONFIG.copy()
        elif model_name == "td3":
            return TD3_CONFIG.copy()
        elif model_name == "sac":
            return SAC_CONFIG.copy()
        elif model_name == "ppo":
            return PPO_CONFIG.copy()
        raise NotImplementedError(f"Model {model_name} is not supported.")
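
For context, here is a minimal usage sketch of the refactored class. The environment class, the array contents, and the import path are assumptions for illustration only; they are not defined in models.py.

# Hedged usage sketch: StockTradingEnv, the dummy arrays, and the checkpoint
# path are placeholders, not values that models.py provides.
import numpy as np
from finrl.meta.env_stock_trading.env_stocktrading_np import StockTradingEnv  # assumed import path, adjust to your FinRL version

price_array = np.random.rand(1000, 5)    # placeholder close prices
tech_array = np.random.rand(1000, 10)    # placeholder technical indicators
turbulence_array = np.random.rand(1000)  # placeholder turbulence/risk data

agent = DRLAgent(
    env=StockTradingEnv,                 # user-defined gym-style env class
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
)
model, model_config = agent.get_model("ppo")
trainer = agent.train_model(model, "ppo", model_config, total_episodes=100)

# DRL_prediction restores the checkpoint written by train_model; the exact
# checkpoint file name depends on the ray version in use.
episode_total_assets = DRLAgent.DRL_prediction(
    model_name="ppo",
    env=StockTradingEnv,
    price_array=price_array,
    tech_array=tech_array,
    turbulence_array=turbulence_array,
    agent_path="./test_ppo/checkpoint_000100/checkpoint-100",
)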
Hi, wondering which version of ray is being used. I notice that ray moved rllib.agents.[algorithms] to rllib.algorithms.[algorithms] a long time ago.
Here is what I changed to make it run: pin ray version == 2.1.0, change 'from ray.rllib.agents.sac' to 'from ray.rllib.algorithms.sac', and delete the a3c and td3 algorithms, because those are not available in the new version.
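
For anyone applying the same change, here is a minimal sketch of what the import block and the config helper might look like under ray 2.1.x. The class and config names follow ray's rllib.algorithms API as I understand it for that release; this is an assumption-laden sketch, not the repository's pinned setup.

# Hedged sketch for ray >= 2.x (per the comment above); only ppo/ddpg/sac kept.
from ray.rllib.algorithms.ppo import PPO, PPOConfig
from ray.rllib.algorithms.ddpg import DDPG, DDPGConfig
from ray.rllib.algorithms.sac import SAC, SACConfig

MODELS = {"ppo": PPO, "ddpg": DDPG, "sac": SAC}
CONFIGS = {"ppo": PPOConfig, "ddpg": DDPGConfig, "sac": SACConfig}

def _get_default_config(model_name):
    # AlgorithmConfig objects replace the old DEFAULT_CONFIG dicts;
    # to_dict() yields a plain dict that the rest of models.py can mutate.
    return CONFIGS[model_name]().to_dict()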
Thanks for the review buddy!!