Rllib custom models
Hey Haris,
I am having trouble making RLlib work with a custom model. Here is a link to their documentation: https://docs.ray.io/en/master/rllib-models.html?highlight=tfmodelv2#custom-models-implementing-your-own-forward-logic
They also provide an example: https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py
I am completely lost in this haha so I can't really provide any useful info on what precisely is not working. If you could give it a look it would be greatly appreciated :)
Edit: It is working properly now, but what I am having trouble with is passing multiple inputs to the model. For example, I would like to pass each pokemon to an embedding layer and then concatenate the results with the rest of the observation before sending everything to the dense layers. For that I suppose I need to do something like: return all_obs, mon1, mon2, ... , mon12, where all_obs is a concat of all the information from the state apart from the mons. I got this working properly with keras-rl, but I can't manage to make it work with RLlib.
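To make the idea above concrete, here is a minimal, framework-agnostic sketch of one way to carry `all_obs` and the twelve pokemon ids through a single flat observation and split them again on the model side. It uses only NumPy and is not RLlib's API. The names `pack_observation`, `unpack_and_embed`, `N_SPECIES` and `EMBED_DIM` are hypothetical, and the random `embedding_table` merely stands in for a trainable embedding layer.

```python
import numpy as np

N_MONS = 12        # mon1 ... mon12, as in the post
ALL_OBS_DIM = 10   # assumed size of the non-pokemon part of the state
N_SPECIES = 1000   # hypothetical number of distinct pokemon ids
EMBED_DIM = 8      # hypothetical embedding width

rng = np.random.default_rng(0)
# Stand-in for a trainable embedding layer: one row per pokemon id.
embedding_table = rng.normal(size=(N_SPECIES, EMBED_DIM)).astype(np.float32)


def pack_observation(all_obs, mon_ids):
    """Environment side: flatten everything into one 1-D float vector."""
    return np.concatenate([all_obs, mon_ids]).astype(np.float32)


def unpack_and_embed(flat_obs):
    """Model side: split the vector, embed each pokemon id, concatenate."""
    all_obs = flat_obs[:ALL_OBS_DIM]
    mon_ids = flat_obs[ALL_OBS_DIM:].astype(np.int64)
    mon_vectors = embedding_table[mon_ids]  # shape (N_MONS, EMBED_DIM)
    return np.concatenate([all_obs, mon_vectors.reshape(-1)])


all_obs = np.zeros(ALL_OBS_DIM, dtype=np.float32)
mon_ids = np.arange(N_MONS)
dense_input = unpack_and_embed(pack_observation(all_obs, mon_ids))
print(dense_input.shape)  # (106,) = 10 + 12 * 8
```

In a real model the table lookup would be replaced by the framework's own embedding layer, and the declared observation space would need bounds wide enough to hold the raw ids.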
Hey @mancho2000,
Could you share the version of your code that works? If you don't want to make it public, we can talk via Discord.
In the past few days I updated RLlib, and now it's not even starting any battles; they must have changed something on their end. I can't figure out how to get this working again, which I need to do before taking on the custom model stuff mentioned in my comment above. This is the code you provided us more than a year ago, which I have been building on all this time, and which no longer works:
import asyncio
import numpy as np
import ray
import ray.rllib.agents.ppo as ppo
import tensorflow as tf
from poke_env.player.player import Player
from asyncio import ensure_future, new_event_loop, set_event_loop
from gym.spaces import Box, Discrete
from poke_env.player.env_player import Gen8EnvSinglePlayer
from poke_env.player.random_player import RandomPlayer


class SimpleRLPlayer(Gen8EnvSinglePlayer):
    def __init__(self, *args, **kwargs):
        # RLlib passes its env config here; it is not forwarded to poke-env
        Gen8EnvSinglePlayer.__init__(self)
        self.observation_space = Box(low=-10, high=10, shape=(10,))

    @property
    def action_space(self):
        return Discrete(22)

    # We define our RL player
    # It needs a state embedder and a reward computer, hence these two methods
    def embed_battle(self, battle):
        # -1 indicates that the move does not have a base power
        # or is not available
        moves_base_power = -np.ones(4)
        moves_dmg_multiplier = np.ones(4)
        for i, move in enumerate(battle.available_moves):
            moves_base_power[i] = (
                move.base_power / 100
            )  # Simple rescaling to facilitate learning
            if move.type:
                moves_dmg_multiplier[i] = move.type.damage_multiplier(
                    battle.opponent_active_pokemon.type_1,
                    battle.opponent_active_pokemon.type_2,
                )

        # We count how many pokemons have not fainted in each team
        remaining_mon_team = (
            len([mon for mon in battle.team.values() if not mon.fainted]) / 6
        )
        remaining_mon_opponent = (
            len([mon for mon in battle.opponent_team.values() if not mon.fainted]) / 6
        )

        # Final vector with 10 components
        return np.concatenate(
            [
                moves_base_power,
                moves_dmg_multiplier,
                [remaining_mon_team, remaining_mon_opponent],
            ]
        )

    def compute_reward(self, battle) -> float:
        return self.reward_computing_helper(
            battle, fainted_value=2, hp_value=1, victory_value=30,
        )

    def observation_space(self):
        return np.array


class MaxDamagePlayer(RandomPlayer):
    def choose_move(self, battle):
        # If the player can attack, it will
        if battle.available_moves:
            # Finds the best move among available ones
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)
        # If no attack is available, a random switch will be made
        return self.choose_random_move(battle)


config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0 # Training will not work with poke-env if this value != 0
config["framework"] = "tfe"
trainer = ppo.PPOTrainer(config=config, env=SimpleRLPlayer)


def ray_training_function(player):
    for i in range(1000):
        result = trainer.train()
    checkpoint = trainer.save()
    print("checkpoint saved at", checkpoint)


def ray_evaluating_function(player):
    for _ in range(100):
        done = False
        obs = player.reset()
        while not done:
            action = trainer.compute_action(obs)
            obs, _, done, _ = player.step(action)
    print(
        "PPO Evaluation: %d victories out of %d episodes"
        % (player.n_won_battles, 100)
    )


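env_player = trainer.workers.local_worker().env
first_opponent = RandomPlayer()
second_opponent = MaxDamagePlayer(battle_format="gen8randombattle")

print("\nTRAINING against random player:")
env_player.play_against(
    env_algorithm=ray_training_function,
    opponent=first_opponent,
)

print("\nResults against random player:")
env_player.play_against(
    env_algorithm=ray_evaluating_function,
    opponent=first_opponent,
)

print("\nResults against max player:")
env_player.play_against(
    env_algorithm=ray_evaluating_function,
    opponent=second_opponent,
)
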
Could you please look into this first? :)
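One check that can be run without RLlib or a Showdown server is whether the vector built in `embed_battle` actually matches the declared `Box(low=-10, high=10, shape=(10,))`. The sketch below is only an illustration under assumptions: `build_observation` and `fits_space` are hypothetical helpers, the inputs are made-up numbers rather than a real battle, and passing this check says nothing about why battles fail to start.

```python
import numpy as np

# Mirrors Box(low=-10, high=10, shape=(10,)) from the code above.
LOW, HIGH, SIZE = -10.0, 10.0, 10


def build_observation(base_powers, multipliers, remaining_team, remaining_opp):
    """Hypothetical stand-in for embed_battle, fed plain numbers."""
    moves_base_power = -np.ones(4)  # -1 marks a missing move
    moves_dmg_multiplier = np.ones(4)
    for i, (power, mult) in enumerate(zip(base_powers, multipliers)):
        moves_base_power[i] = power / 100  # same rescaling as embed_battle
        moves_dmg_multiplier[i] = mult
    return np.concatenate(
        [moves_base_power, moves_dmg_multiplier, [remaining_team, remaining_opp]]
    )


def fits_space(obs):
    """True if obs has the declared shape and stays within the bounds."""
    return obs.shape == (SIZE,) and bool(np.all((obs >= LOW) & (obs <= HIGH)))


# Two available moves, five of six own pokemon left, full opposing team.
obs = build_observation([90, 40], [2.0, 0.5], 5 / 6, 1.0)
print(fits_space(obs))  # True
```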