pogema
pogema copied to clipboard
Added integration with native `jax` tooling
Usage example
import jax
from pogema import GridConfig, pogema_v0
# Environment configuration; `integration="jax"` requests the jax integration.
grid_config = GridConfig(
    size=8,
    num_agents=5,
    obs_radius=2,
    seed=9,
    on_target="finish",
    max_episode_steps=128,
    integration="jax",
)
env = pogema_v0(grid_config=grid_config)

# A JAX PRNG key must not be consumed more than once. Split the root key so
# that `env.reset` gets its own subkey and `key` stays fresh for the rollout.
key = jax.random.key(0)
key, reset_key = jax.random.split(key)

# resetting
state, env_state = env.reset(reset_key)


def policy(rng):
    """Sample one random action per agent using the PRNG key *rng*."""
    return jax.random.randint(
        rng, (env.num_agents,), minval=0, maxval=env.action_space().n
    )
# iteration
def step_fn(carry, _):
    """Advance the environment by one step with a randomly sampled action.

    Written as a `jax.lax.scan` body: takes the loop carry and an unused
    per-step input, and returns ``(new_carry, per_step_output)``.

    Args:
        carry: Tuple ``(state, env_state, step_key)`` from the previous step.
        _: Unused scan input.

    Returns:
        A pair whose first element is the next carry
        ``(next_state, new_env_state, key)`` and whose second element is the
        transition record
        ``(state, next_state, action, reward, terminated, truncated, info)``.
    """
    prev_state, prev_env_state, rng = carry

    # First subkey drives the policy; the second is carried to the next step.
    policy_rng, carried_rng = jax.random.split(rng)
    chosen_action = policy(policy_rng)  # random agent

    step_result = env.step(chosen_action, prev_env_state)
    next_state, new_env_state, reward, terminated, truncated, info = step_result

    next_carry = (next_state, new_env_state, carried_rng)
    transition = (
        prev_state,
        next_state,
        chosen_action,
        reward,
        terminated,
        truncated,
        info,
    )
    return next_carry, transition
# Run `step_fn` for 70 steps; the final carry is discarded and the stacked
# per-step outputs are kept as `rollout`.
_, rollout = jax.lax.scan(
    step_fn,
    init=(state, env_state, key),
    xs=None,
    length=70,
)