pogema icon indicating copy to clipboard operation
pogema copied to clipboard

Added integration with native `jax` tools

Open alexunderch opened this issue 9 months ago • 0 comments

Usage example

import jax

from pogema import GridConfig, pogema_v0

# Settings for the example environment, kept in one mapping so they are easy
# to tweak; integration="jax" requests the JAX integration shown in this example.
env_settings = {
    "size": 8,
    "num_agents": 5,
    "obs_radius": 2,
    "seed": 9,
    "on_target": "finish",
    "max_episode_steps": 128,
    "integration": "jax",
}
grid_config = GridConfig(**env_settings)

# Build the environment from the configuration above.
env = pogema_v0(grid_config=grid_config)


key = jax.random.key(0)

# resetting
# Split off a dedicated key for the reset: the remaining `key` is handed to
# the rollout afterwards, and a JAX PRNG key should never be consumed twice.
key, reset_key = jax.random.split(key)
state, env_state = env.reset(reset_key)


def policy(rng):
    """Sample one uniformly random action per agent.

    Args:
        rng: JAX PRNG key consumed by this single sampling call.

    Returns:
        Integer array of shape ``(env.num_agents,)`` with values drawn from
        ``[0, env.action_space().n)``.
    """
    return jax.random.randint(
        rng, (env.num_agents,), minval=0, maxval=env.action_space().n
    )

# iteration


def step_fn(carry, _):
    """Advance the environment by one step using the random policy.

    Written in the ``(carry, x) -> (new_carry, y)`` form expected by
    ``jax.lax.scan``; the per-step input is unused.

    Args:
        carry: Tuple ``(state, env_state, rng)`` from the previous step.
        _: Ignored scan input.

    Returns:
        A pair ``(new_carry, transition)`` where ``new_carry`` is
        ``(next_state, new_env_state, rng)`` and ``transition`` is
        ``(state, next_state, action, reward, terminated, truncated, info)``.
    """
    prev_state, prev_env_state, rng = carry

    # One half of the split samples the action; the other is carried forward
    # so every step draws from a fresh key.
    action_rng, carried_rng = jax.random.split(rng)
    action = policy(action_rng)  # random agent

    step_result = env.step(action, prev_env_state)
    next_state, new_env_state, reward, terminated, truncated, info = step_result

    new_carry = (next_state, new_env_state, carried_rng)
    transition = (
        prev_state,
        next_state,
        action,
        reward,
        terminated,
        truncated,
        info,
    )
    return new_carry, transition


# Roll the environment forward for 70 steps; the final carry is discarded and
# the stacked per-step transitions are kept in `rollout`.
initial_carry = (state, env_state, key)
_, rollout = jax.lax.scan(step_fn, initial_carry, xs=None, length=70)

alexunderch avatar May 14 '24 13:05 alexunderch