TextWorld
TextWorld copied to clipboard
Please add observation and action spaces;
they are a key component of the gym API.
here's a custom space for strings (license: MIT, author: Bion Howard)
class String(gym.Space):
    """A gym space whose elements are strings built from a fixed alphabet.

    Args:
        length: if not None, every sampled string has exactly this length and
            ``contains`` only accepts strings of this length.
        letters: the allowed characters (sampling picks uniformly among them);
            None disables the character check in ``contains``.
        min_length, max_length: inclusive length bounds, used when ``length``
            is None.

    NOTE(review): ``sample`` uses the module-level ``random`` (presumably the
    stdlib module), so ``gym.Space.seed`` does not make sampling reproducible.
    """

    def __init__(
            self,
            length=None,
            letters=LETTERS,
            min_length=1,
            max_length=280):
        # gym.Space.__init__ sets base attributes (shape, dtype, RNG slot) that
        # gym wrappers and env checkers read; skipping it leaves them undefined.
        super().__init__()
        self.min_length = min_length
        self.max_length = max_length
        self.letters = letters
        self.length = length

    def sample(self):
        """Return a random string of ``length`` chars, or of a random length in [min_length, max_length]."""
        # `is not None` so that length=0 means "always empty", not "random length"
        if self.length is not None:
            length = self.length
        else:
            length = random.randint(self.min_length, self.max_length)
        return "".join(random.choice(self.letters) for _ in range(length))

    def contains(self, x):
        """Return True if ``x`` is a str of an allowed length made only of allowed letters."""
        # type check first: len() on a non-sized object would raise instead of returning False
        if not isinstance(x, str):
            return False
        if self.length is not None:
            # a fixed-length space must accept exactly what sample() produces
            if len(x) != self.length:
                return False
        elif not self.min_length <= len(x) <= self.max_length:
            return False
        if self.letters is not None:
            allowed = set(self.letters)
            return all(ch in allowed for ch in x)
        return True

    def __repr__(self):
        return f"String(min_length={self.min_length},length={self.length},max_length={self.max_length},letters={self.letters})"
Hi @bionicles, I'd be happy to integrate it into TextWorld. Can you make a PR to add it to https://github.com/microsoft/TextWorld/blob/master/textworld/gym/spaces/text_spaces.py ? Also, note the existing textworld.gym.spaces.Char
.
What space would be a reasonable default? It might be better to use an existing one if it's built for this (less code)
(just curious because spaces help with random actions and normalization)
Sorry for the delay in getting back to you, I just got back from paternity leave.
What default are you referring to? If you are talking about TextworldGymEnv, I've set it to None to force the user to think of what makes sense in their case. The main reason being I wasn't sure how to pick good values for LETTERS
(or WORDS/VOCAB
) and
max_length
. If you have any ideas, I'm all ears.
Congrats on being a new dad! Here's what I wound up doing so far: self.observation_space = String()
and posted the updated string space below
However, that passes raw strings to the agent, so the agent needs a string sensor to handle string observations.
Another option, which plays better with frameworks, would be to convert the text into a numpy array of UTF-8 bytes (uint8), then cast to float32 and normalize. This could go in a wrapper, and the observation could then just be a float32 gym.spaces.Box.
# within nature/sense.py
def sense_str(mystring):
    """Encode ``mystring`` as UTF-8 and return its bytes as float32 values in [0, 1]."""
    encoded = bytes(mystring, "utf-8")
    byte_values = jnp.array(list(encoded), dtype=jnp.float32)
    # 255 is the largest byte value, so this maps every byte into [0, 1]
    return byte_values / 255
Here's a UTF-8 actuator (this could also go in a wrapper). It stops writing at the first NUL or invalid UTF-8 byte listed in `non_unicode_bytes`.
# within nature/actuate.py
# Byte values that terminate the decoded string: NUL (used as an explicit
# "end of text" marker) plus every byte that can never occur in well-formed
# UTF-8 (0xC0, 0xC1 and 0xF5-0xFF).
non_unicode_bytes = jnp.array(
    [0, 192, 193, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
)


def _actuate_string(space, values, xmin=-1.0, xmax=1.0):
    """Turn a numeric action vector into a string.

    ``values`` is flattened, rescaled from [xmin, xmax] to byte range [0, 255]
    (truncated to int), cut at the first byte in ``non_unicode_bytes``, and
    decoded as UTF-8 with undecodable sequences dropped. If ``space.letters``
    is truthy, characters outside it are removed.

    Args:
        space: a String-like space; only its ``letters`` attribute is read.
        values: array of action values (any shape).
        xmin, xmax: the expected range of ``values``.

    Returns:
        The decoded string (possibly empty).
    """
    decimal = rescale(xmin, xmax, 0, 255, values.flatten()).astype("int")
    # work on plain Python ints: cheap membership tests and no array-valued slice bounds
    codes = decimal.tolist()
    terminators = set(non_unicode_bytes.tolist())
    cut = next((idx for idx, code in enumerate(codes) if code in terminators), len(codes))
    result = bytes(codes[:cut]).decode("utf-8", errors="ignore")
    if space.letters:
        # build the allowed set once; works for any space exposing `letters`
        allowed = set(space.letters)
        result = "".join(c for c in result if c in allowed)
    return result
then the rescale function is this:
# tricks/rescale.py
"min max scale function"
import jax.numpy as jnp
from jax import jit
@jit
def maybe_replace(z):
    """Return ``z`` with NaN -> 0, +inf -> 1000, -inf -> -1000 if it is floating/complex.

    Integer and boolean inputs are returned unchanged. Uses ``jnp.issubdtype``
    on the array dtype (``jnp.issubsctype`` is deprecated/removed in JAX).
    """
    z = jnp.asarray(z)
    if jnp.issubdtype(z.dtype, jnp.inexact):
        return jnp.nan_to_num(z, neginf=-1000.0, posinf=1000.0)
    return z
@jit
def rescale(xmin, xmax, ymin, ymax, inputs):
    """Linearly map ``inputs`` from [xmin, xmax] onto [ymin, ymax].

    The bounds are first passed through ``maybe_replace`` so infinite/NaN
    bounds become finite. The result is always clipped into [ymin, ymax]:
    NaN results (e.g. xmin == xmax with inputs == xmin) become 0.0 and
    infinite ones a huge finite value *before* clipping, so they too land
    inside the target range.
    """
    xmin, xmax, ymin, ymax = [maybe_replace(z) for z in [xmin, xmax, ymin, ymax]]
    scaled = (((inputs - xmin) * (ymax - ymin)) / (xmax - xmin)) + ymin
    # nan_to_num first, clip last: the previous order let NaN escape as 0.0
    # even when 0 lies outside [ymin, ymax]
    return jnp.clip(jnp.nan_to_num(scaled), ymin, ymax)
here's some tinkering with the string space
# nurture/spaces/string.py
import re
from jax import random
from gym import Space
from tricks import RNG # just an iterator over jax.random.PRNGKey
# Default alphabet for String spaces: ASCII letters, digits, most punctuation
# and the space character (no double quote, backslash or whitespace controls).
LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$%&()*+,-./:;<=>?@[]^_`{|}~' "
# Matches any single character NOT in LETTERS. Built from LETTERS (escaped) so
# the character class can never drift out of sync with the alphabet.
regex = re.compile("[^" + re.escape(LETTERS) + "]")
def sanitize_string(string, letters=None):
    """Return ``string`` keeping only characters that appear in ``letters``.

    With ``letters=None`` no filtering happens and ``string`` is returned as is.
    """
    if letters is not None:
        return "".join(filter(lambda ch: ch in letters, string))
    return string
class String(Space):
    "a space of potential strings from min to max length with certain set of letters"

    def __init__(self, length=None, letters=LETTERS, min_len=0, max_len=4096):
        """
        Args:
            length: if not None, sampled strings have exactly this length and
                ``contains`` only accepts strings of this length.
            letters: allowed characters; falsy (None or "") disables the
                character check, and sampling then falls back to LETTERS.
            min_len, max_len: inclusive length bounds checked by ``contains``
                when ``length`` is None.
        """
        # let gym.Space set up its own attributes (shape, dtype, ...) that gym
        # wrappers and env checkers read
        super().__init__()
        self.min_len = min_len
        self.max_len = max_len
        self.letters = letters
        self.len = length
        self.rng = None
        self.seed()

    def seed(self, initial=None):
        """(Re)create the key iterator (default seed 420); returns ``[seed]`` like gym spaces."""
        _seed = 420 if initial is None else initial
        self.rng = RNG(_seed)
        return [_seed]

    def sample(self):
        """Return ``length`` characters (``max_len`` if ``length`` is None) drawn uniformly from the alphabet."""
        letters = self.letters if self.letters else LETTERS
        # this draw is unused, but it advances the key stream; kept so that
        # samples for a given seed stay the same as before
        next(self.rng)
        # `is not None` so that length=0 yields "" instead of max_len characters
        length = self.len if self.len is not None else self.max_len
        # NOTE(review): one JAX call per character; slow for long strings
        return "".join(
            letters[random.choice(next(self.rng), len(letters))] for _ in range(length)
        )

    def contains(self, x):
        """Return True if ``x`` is a str with an allowed length and only allowed letters."""
        if not isinstance(x, str):
            return False
        if self.len is not None:
            # a fixed-length space must accept exactly what sample() produces
            if len(x) != self.len:
                return False
        elif not self.min_len <= len(x) <= self.max_len:
            return False
        if self.letters:
            allowed = self.letterset
            return all(ch in allowed for ch in x)
        return True

    def __repr__(self):
        letters = "DEFAULT" if self.letters == LETTERS else self.letters
        return f"String(min_len={self.min_len},len={self.len},max_len={self.max_len},letters={letters})"

    @property
    def letterset(self):
        """The allowed characters as a set (empty set when ``letters`` is None)."""
        if self.letters is None:
            return set()
        return set(self.letters)
Here's a wrapper that provides several difficulty levels: easy (and hard) give more information in the observation and use a Discrete action space over the admissible commands, while expert uses a String action space.
#nurture/textworld/env.py
"wrap the microsoft textworld MUD-style game"
from random import randint
import shutil
import os
import gym
import textworld
import textworld.gym
from nurture import String, sanitize_string
def _get_easy_options():
    """Build GameOptions for a small game: objects, quest length and rooms each drawn from 1-8."""
    opts = textworld.GameOptions()
    # same attribute order as before, so the random draws happen in the same sequence
    for attr in ("nb_objects", "quest_length", "nb_rooms"):
        setattr(opts, attr, randint(1, 8))
    return opts
def _get_hard_options():
    """Build GameOptions for a larger game with longer, possibly parallel quests."""
    opts = textworld.GameOptions()
    # (attribute, low, high) — inclusive randint bounds, drawn in this order
    bounds = (
        ("nb_rooms", 8, 10),
        ("nb_objects", 8, 10),
        ("nb_parallel_quests", 1, 2),
        ("quest_length", 8, 10),
        ("quest_breadth", 1, 2),
        ("quest_depth", 2, 8),
    )
    for attr, low, high in bounds:
        setattr(opts, attr, randint(low, high))
    return opts
# Difficulty name -> factory returning freshly randomised textworld.GameOptions.
# "expert" reuses the hard generator; it differs from "hard" only in MAX_STEPS
# and in TextWorldWrapper using the terse stringifier / String action space.
GET_OPTIONS = dict(
    easy=_get_easy_options, hard=_get_hard_options, expert=_get_hard_options
)
# Extra information requested from the TextWorld env on reset/step; the
# stringify helpers read these keys from the returned info dict.
INFOS = textworld.EnvInfos(
    objective=True,
    inventory=True,
    description=True,
    admissible_commands=True,
    feedback=True,
)
# Per-difficulty episode step limit, passed to register_game as max_episode_steps.
MAX_STEPS = dict(easy=50, hard=200, expert=420)
# Directory of this module; generated games are written under fp/tw_games/<difficulty>/.
fp = os.path.dirname(__file__)
def _make_textworld(difficulty):
    """Generate a fresh TextWorld game for ``difficulty`` and register it with gym.

    The previous game directory for this difficulty is deleted first.

    Returns:
        The env id returned by ``textworld.gym.register_game``.

    Raises:
        KeyError: if ``difficulty`` is not a key of GET_OPTIONS (raised before
            anything is deleted).
    """
    # look up the options factory first so an unknown (or path-like)
    # difficulty fails before any directory is removed
    make_options = GET_OPTIONS[difficulty]
    game_dir = os.path.join(fp, "tw_games", difficulty)
    # best-effort cleanup; the directory does not exist on the first call
    shutil.rmtree(game_dir, ignore_errors=True)
    options = make_options()
    options.path = os.path.join(game_dir, "game.ulx")
    game_file, _ = textworld.make(options)
    return textworld.gym.register_game(
        game_file, INFOS, max_episode_steps=MAX_STEPS[difficulty]
    )
def _parse_inventory(inv):
return inv.replace("You are carrying", "You carry").replace(":", "") + "."
class TextWorldWrapper(gym.Env):
    "a custom env to adjust the textworld env observation space and difficulty levels"

    def __init__(self, difficulty="easy"):
        self.observation_space = String()
        # on "easy"/"hard", _stringify_easy replaces this with a Discrete space
        # over the current admissible commands after every reset/step
        self.action_space = String()
        self.difficulty = difficulty
        if difficulty in ["easy", "hard"]:
            self.stringify = self._stringify_easy
        else:
            self.stringify = _stringify_hard
        self.commands = None
        self._env = None

    def reset(self):
        """Generate and start a brand-new game; return the first text observation."""
        # close the previous game's env first, otherwise every reset leaks one
        self.close()
        env_id = _make_textworld(self.difficulty)
        self._env = gym.make(env_id)
        _, i = self._env.reset()
        return self.stringify(i, stepping=False)

    def step(self, action):
        """Apply ``action`` (a command string, or an index into the current commands)."""
        if not isinstance(action, str):
            # accept Python, numpy and jax integers alike (Discrete samples / agent outputs)
            action = self.commands[int(action)]
        _, r, d, i = self._env.step(action)
        return self.stringify(i), r, d, {}

    def render(self):
        return self._env.render()

    def close(self):
        """Close the underlying TextWorld env, if one is open."""
        env = getattr(self, "_env", None)
        if env is not None:
            env.close()
            self._env = None

    def _stringify_easy(self, i, stepping=True):
        """Build a verbose observation and refresh the Discrete command action space.

        Combines objective, room description, inventory, the admissible
        commands (minus "inventory" and "look") and, when ``stepping``, the
        latest feedback into one line of text.
        """
        inventory = _parse_inventory(i["inventory"])
        self.commands = [
            c for c in i["admissible_commands"] if c not in ["inventory", "look"]
        ]
        self.action_space = gym.spaces.Discrete(len(self.commands))
        obs = (
            f'{i["objective"]} {i["description"]} {inventory} Commands: {self.commands}'
        )
        if stepping:
            obs += f" Feedback: {i['feedback']}"
        # drop backslashes, then collapse every whitespace run (newlines
        # included) into one space; split() with no args also trims the ends
        return sanitize_string(" ".join(obs.replace("\\", "").split()))
def _stringify_hard(i, stepping=True):
    """Build a terse observation: the feedback when stepping, else the objective.

    Backslashes are removed and all whitespace runs (including newlines) are
    collapsed to single spaces. Inventory feedback is rephrased via
    ``_parse_inventory``.
    """
    obs = i["feedback"] if stepping else i["objective"]
    # split() with no args trims the ends and collapses newlines/space runs
    obs = " ".join(obs.replace("\\", "").split())
    if stepping and "You are carrying:" in obs:
        obs = _parse_inventory(obs)
    return sanitize_string(obs)
just for completeness, here is the rng class
# tricks/rng.py
import jax
@jax.tree_util.register_pytree_node_class
class RNG:
    "PRNGKey iterator"

    def __init__(self, seed):
        self.seed = seed
        self.key = jax.random.PRNGKey(seed)

    def __iter__(self):
        return self

    def __next__(self):
        # split the carried key: keep one half as the new state, hand out the other
        self.key, output = jax.random.split(self.key)
        return output

    def tree_flatten(self):
        return ((self.seed, self.key), None)

    @classmethod
    def tree_unflatten(cls, _, rng_state):
        # Bypass __init__: JAX may unflatten with placeholder leaves (tracers,
        # None, sentinel objects), on which PRNGKey(seed) would fail; the key
        # is restored from the leaves anyway.
        seed, key = rng_state
        new = object.__new__(cls)
        new.seed = seed
        new.key = key
        return new

    def __eq__(self, other):
        if not isinstance(other, RNG):
            return False
        same_key = jax.numpy.all(other.key == self.key)
        same_seed = other.seed == self.seed
        return bool(same_seed and same_key)

    # defining __eq__ already disables hashing; make that explicit
    __hash__ = None
Thanks for sharing your code. I like the style and it is very insightful.
I never thought of changing the env.action_space at every step (i.e. choice-based setting) but that might not play well with some existing algorithms, e.g. in the OpenAI's baselines repo: PolicyWithValue where .n
will change throughout the episode.
I'd be happy to integrate your String
space to textworld.gym.spaces. Or, maybe, it could be added into the Gym codebase directly?
@MarcCote @bionicles, any updates on the above? It seems there are a few warnings from gym these days
.
Maybe it is now related to: https://github.com/microsoft/TextWorld/issues/324
This was not integrated in TextWorld yet. I'd be happy to review any PR though.
Note that gym
is no longer under development, it has been replaced with gymnasium
which seems to have a Text space.
https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.Text
Dependency on gym
has been dropped. No need for those spaces anymore.
See #341