
Optimize the code

harnvo opened this issue 2 years ago

Hi, it's me again. I have found some redundancy in your code. It does not cause any errors or bugs, but it is unnecessary and may hurt run-time performance. I modified it so that the code looks neater.

Where did I make modifications?

  • common.replay_buffer.ReplayBuffer.store_episode()
  • RolloutWorker and CommRolloutWorker in common.rollout

Details

1. common.replay_buffer.ReplayBuffer.store_episode()

Your code looks something like this:

def store_episode(self, episode_batch: dict):
    batch_size = episode_batch['o'].shape[0]  # episode_number
    with self.lock:
        idxs = self._get_storage_idx(inc=batch_size)
        # store the information
        self.buffers['o'][idxs] = episode_batch['o']
        self.buffers['u'][idxs] = episode_batch['u']
        self.buffers['s'][idxs] = episode_batch['s']
        self.buffers['r'][idxs] = episode_batch['r']
        self.buffers['o_next'][idxs] = episode_batch['o_next']
        self.buffers['s_next'][idxs] = episode_batch['s_next']
        self.buffers['avail_u'][idxs] = episode_batch['avail_u']
        self.buffers['avail_u_next'][idxs] = episode_batch['avail_u_next']
        self.buffers['u_onehot'][idxs] = episode_batch['u_onehot']
        self.buffers['padded'][idxs] = episode_batch['padded']
        self.buffers['terminated'][idxs] = episode_batch['terminated']
        if self.args.alg == 'maven':
            self.buffers['z'][idxs] = episode_batch['z']

Such code looks awkward: we have to manually write out every single key, and we have to handle exceptions like 'maven' separately. The better option is to iterate over every key in episode_batch.

def store_episode(self, episode_batch: dict):
    batch_size = episode_batch['o'].shape[0]  # episode_number
    with self.lock:
        idxs = self._get_storage_idx(inc=batch_size)
        for key in episode_batch.keys():
            self.buffers[key][idxs] = episode_batch[key]
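
As a side note, if you ever worry about an episode batch carrying a key the buffer does not know, a slightly more defensive variant is possible. This is only a sketch from my side; the explicit KeyError check is my own addition, not something the current code needs:

def store_episode(self, episode_batch: dict):
    batch_size = episode_batch['o'].shape[0]  # episode_number
    with self.lock:
        idxs = self._get_storage_idx(inc=batch_size)
        for key, value in episode_batch.items():
            # fail loudly with an explicit message if the batch carries an unknown key
            if key not in self.buffers:
                raise KeyError(f"unexpected key in episode batch: {key}")
            self.buffers[key][idxs] = value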

2. How to choose action for each agent in RolloutWorker

Here is how your code implements this:

        obs = self.env.get_obs()
        state = self.env.get_state()
        actions, avail_actions, actions_onehot = [], [], []
        for agent_id in range(self.n_agents):
            avail_action = self.env.get_avail_agent_actions(agent_id)
            if self.args.alg == 'maven':
                action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                                   avail_action, epsilon, maven_z, evaluate)
            else:
                action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                                   avail_action, epsilon, evaluate)
            # generate onehot vector of the action
            action_onehot = np.zeros(self.args.n_actions)
            action_onehot[action] = 1
            actions.append(np.int(action))
            actions_onehot.append(action_onehot)
            avail_actions.append(avail_action)
            last_action[agent_id] = action_onehot

There are, however, four problems in your code:

- redundant checking for the 'maven' algorithm

To remove the redundant check for 'maven', I set maven_z = None when the algorithm is not maven. By doing this, it is safe to pass maven_z to self.agents.choose_action in all cases (because that function checks whether the algorithm is maven again anyway!).

    # sample z for maven
    if self.args.alg == 'maven':
        state = self.env.get_state()
        state = torch.tensor(state, dtype=torch.float32)
        if self.args.cuda:
            state = state.cuda()
        z_prob = self.agents.policy.z_policy(state)
        maven_z = one_hot_categorical.OneHotCategorical(z_prob).sample()
        maven_z = list(maven_z.cpu())
    else:
        maven_z = None
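
For clarity, this relies on choose_action taking maven_z before evaluate (which is how it is called above) and only using it when the algorithm is maven. A minimal standalone illustration of that idea (DummyAgents is a made-up class for demonstration, not the repository's Agents):

class DummyAgents:
    def __init__(self, alg):
        self.alg = alg

    def choose_action(self, obs, last_action, agent_id,
                      avail_actions, epsilon, maven_z=None, evaluate=False):
        # maven_z is only touched for maven, so passing None otherwise is harmless
        if self.alg == 'maven':
            assert maven_z is not None, "maven needs a sampled z"
        return 0  # placeholder action


# no maven_z needed when the algorithm is not maven
DummyAgents('qmix').choose_action(None, None, 0, [1, 1], 0.1, None, False)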

- this should be written as a RolloutWorker method for neater code.

This is just to make your code better organized. Actually, I plan to write a BaseRolloutWorker as a common framework: there are so many similarities between RolloutWorker and CommRolloutWorker that I believe writing a base class is a good idea.

- actions_onehot is a Python list while u_onehot in the episode buffer is a numpy array.

The third problem does not raise any error because numpy converts the list into a numpy array for you. I fix it only because it can be fixed together with the fourth problem. (Note that avail_actions is still a Python list...)
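
A tiny illustration of why this never causes trouble (the shapes and values here are made up for the example):

import numpy as np

# actions_onehot collected as a Python list of numpy rows ...
actions_onehot = [np.eye(4)[1], np.eye(4)[3], np.eye(4)[0]]
# ... still ends up as a proper numpy array once it is stacked / stored
u_onehot = np.array(actions_onehot)
print(u_onehot.shape)  # (3, 4)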

- There are better ways to generate one-hot vectors with numpy

For generating one-hot representations of actions, Stack Overflow offers a better solution:

import numpy as np
nb_classes = 6
targets = np.array([[2, 3, 4, 0]]).reshape(-1)
one_hot_targets = np.eye(nb_classes)[targets]

>>> one_hot_targets
array([[ 0.,  0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  0.,  1.,  0.],
       [ 1.,  0.,  0.,  0.,  0.,  0.]])
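
To convince yourself that indexing the identity matrix is exactly the same as the "allocate zeros, set one index" loop, here is a quick standalone check (the values are arbitrary):

import numpy as np

n_actions = 6
actions = [2, 3, 4, 0]

# the original loop: allocate zeros, set one index per action
loop_onehot = []
for a in actions:
    v = np.zeros(n_actions)
    v[a] = 1
    loop_onehot.append(v)

# the identity-matrix trick: row a of np.eye(n_actions) is the one-hot vector for a
eye_onehot = np.eye(n_actions)[actions]

assert np.array_equal(np.array(loop_onehot), eye_onehot)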

So now the code looks like this:

def _agents_choose_action(self, obs, last_action,
                          epsilon, evaluate, maven_z=None):
    """
    returns:
        actions: the actions chosen by the agents
        actions_onehot: the one-hot representation of the chosen actions
        avail_actions: the available actions for all agents
    """

    actions, avail_actions = [], []

    for agent_id in range(self.n_agents):
        avail_action = self.env.get_avail_agent_actions(agent_id)

        # maven_z is None when algorithm is not maven.
        action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                           avail_action, epsilon, maven_z, evaluate)

        actions.append(np.int(action))
        avail_actions.append(avail_action)

    # generate one-hot vector of the action
    actions_onehot = np.eye(self.n_actions)[actions]

    return actions, actions_onehot, avail_actions

This method would be called in generate_episode here:

    while not terminated and step < self.episode_limit:
        # time.sleep(0.2)
        obs = self.env.get_obs()
        state = self.env.get_state()

        # choose action for each agent
        # Note that maven_z is None if algorithm is not maven.

        actions, actions_onehot, avail_actions = self._agents_choose_action(obs, last_action,
                                                                            epsilon, evaluate,
                                                                            maven_z)
        last_action = np.copy(actions_onehot)

Note that I use a deep copy for last_action merely because I am not sure whether I really should; I am doing it to play it safe. Feel free to change it to a shallow copy if you know that is fine.

3. Documentation

A few docstrings are added to the replay buffer and rollout worker. Do tell me if my documentation is wrong.

Doc for ReplayBuffer:

"""
The buffer is stored in self.buffers.

Keys:
    o: observation of all agents
    u: actions the agents chose
    s: state of the environment
    r: reward
    o_next: next observation
    s_next: next state of the environment
    avail_u: actions that were available to all agents
    avail_u_next: actions available to all agents at the next step
    u_onehot: one-hot representation of the actions the agents chose
    padded: whether this step is padded data
    terminated: whether the game terminates at this step

*Key for maven:
    z: hidden state
"""

Doc for RolloutWorker.generate_episode():

@torch.no_grad()
def generate_episode(self, episode_num=None, evaluate=False):
    """
    returns:
        episode: the experience of one episode, to be stored in the replay buffer
        episode_reward: the total reward for the entire episode
        win_tag: whether the policy won this episode
        step: the number of steps in the episode (note that an episode might terminate before the episode limit, e.g. when the AI wins)
    """

Other minor changes

RolloutWorker.generate_episode()

We do not need the no-op else branch that just assigns epsilon to itself.

    if self.args.epsilon_anneal_scale == 'episode':
        epsilon = epsilon - self.anneal_epsilon if epsilon > self.min_epsilon else epsilon

So the code now looks like:

    if self.args.epsilon_anneal_scale == 'episode' and epsilon > self.min_epsilon:
        epsilon = epsilon - self.anneal_epsilon
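
Both forms behave identically; a quick standalone check (arbitrary values, the helper names are just for this snippet):

def anneal_old(epsilon, anneal_epsilon, min_epsilon):
    return epsilon - anneal_epsilon if epsilon > min_epsilon else epsilon

def anneal_new(epsilon, anneal_epsilon, min_epsilon):
    if epsilon > min_epsilon:
        epsilon = epsilon - anneal_epsilon
    return epsilon

for eps in (1.0, 0.5, 0.06, 0.05, 0.01):
    assert anneal_old(eps, 0.05, 0.05) == anneal_new(eps, 0.05, 0.05)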

Clearer argument name

self.agents.policy.init_hidden(1) is changed to self.agents.policy.init_hidden(episode_num=1).

Type hint in ReplayBuffer.store_episode()

Now episode_batch has the type hint dict.

def store_episode(self, episode_batch: dict):

Testing

I have tested my new code in RolloutWorker via the following code.

def _choose_action_old(self, obs, last_action,
                       epsilon, evaluate, maven_z=None):
    """
    This function exists just to make sure that the new code is the equivalent of the old implementation.

    returns:
        actions: the actions chosen by the agents
        actions_onehot: the one-hot representation of the chosen actions
        avail_actions: the available actions for all agents
    """
    actions, avail_actions, actions_onehot = [], [], []

    for agent_id in range(self.n_agents):
        avail_action = self.env.get_avail_agent_actions(agent_id)
        if self.args.alg == 'maven':
            action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                               avail_action, epsilon, maven_z, evaluate)
        else:
            action = self.agents.choose_action(obs[agent_id], last_action[agent_id], agent_id,
                                               avail_action, epsilon, evaluate)
        # generate onehot vector of the action
        actions.append(np.int(action))
        avail_actions.append(avail_action)

        # unnecessary
        action_onehot = np.zeros(self.args.n_actions)
        action_onehot[action] = 1
        actions_onehot.append(action_onehot)
        last_action[agent_id] = action_onehot

    actions_onehot = np.array(actions_onehot)
    # generate the one-hot vectors with the new approach
    actions_onehot_new = np.eye(self.n_actions)[actions]
    # raise if any element differs between the old and the new arrays
    if (actions_onehot_new != actions_onehot).any():
        raise Exception(f"actions_onehot and actions_onehot_new are not the same:"
                        f"{actions_onehot}!={actions_onehot_new}")

    return actions, actions_onehot, avail_actions

No exceptions were raised when running main.py, meaning my new code in _agents_choose_action is equivalent to the old implementation.

I have also tested whether agents are able to learn properly with my modifications, using args.alg='maven' and args.map='3m'. So far so good. (Let me know if I should upload the picture for you to verify.)

harnvo · Aug 25, 2022