agents icon indicating copy to clipboard operation
agents copied to clipboard

GaussianPolicy applies the same noise to all actions in the batch

Open zhezherun opened this issue 1 year ago • 0 comments

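If GaussianPolicy receives a batched TimeStep, it applies the same noise to all actions returned by the wrapped policy. Instead, it should sample a different noise term per batch element. The following script reproduces the issue: it prints the actions of a wrapped policy for a batch of three time steps, first without and then with GaussianPolicy noise.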

import tensorflow as tf
from tf_agents.networks import network
from tf_agents.policies import actor_policy, gaussian_policy
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step as ts
from tf_agents.utils import nest_utils


class SimpleActor(network.Network):
    """Stub actor network that always outputs zero-valued actions.

    Having a constant output makes any noise added by a wrapping policy
    directly visible in the resulting actions.
    """

    def call(self, observations, step_type=(), network_state=()):
        """Return a zero-filled action tensor and the unchanged network state.

        Args:
            observations: Observations fed to the network; only used to
                determine the outer shape of the output.
            step_type: Unused.
            network_state: Returned as-is.

        Returns:
            A ``(actions, network_state)`` tuple, where ``actions`` is filled
            with ``0.0`` and has the observations' outer shape followed by a
            trailing dimension of size 1.
        """
        # Outer dimensions of the observations relative to the input spec
        # (presumably the batch dimensions -- see nest_utils.get_outer_shape).
        leading_dims = nest_utils.get_outer_shape(
            observations, self.input_tensor_spec
        )
        # Append the single action dimension to the outer dimensions.
        action_shape = tf.concat([leading_dims, [1]], axis=0)
        zero_actions = tf.fill(action_shape, 0.0)
        return zero_actions, network_state


def main():
    """Print the actions of a policy with and without GaussianPolicy noise.

    Builds an ActorPolicy around SimpleActor, wraps it in a GaussianPolicy
    with ``scale=0.1``, and queries both policies with the same batch of three
    FIRST time steps so the per-element actions can be compared.
    """
    obs_spec = tensor_spec.BoundedTensorSpec(
        shape=[1], dtype=tf.float32, minimum=0.0, maximum=1.0
    )
    act_spec = tensor_spec.BoundedTensorSpec(
        shape=[1], dtype=tf.float32, minimum=-1, maximum=1
    )

    base_policy = actor_policy.ActorPolicy(
        time_step_spec=ts.time_step_spec(obs_spec),
        action_spec=act_spec,
        actor_network=SimpleActor(input_tensor_spec=obs_spec),
    )
    noisy_policy = gaussian_policy.GaussianPolicy(base_policy, scale=0.1)

    # A batch of three identical initial time steps.
    batch_size = 3
    batched_step = ts.TimeStep(
        step_type=tf.constant([ts.StepType.FIRST] * batch_size),
        reward=tf.constant([0.0] * batch_size),
        discount=tf.constant([1.0] * batch_size),
        observation=tf.constant([[0.0]] * batch_size),
    )

    # Query the plain policy first, then the noisy wrapper, printing each
    # result right after it is computed.
    labelled_policies = (
        ("Original actions:", base_policy),
        ("Noisy actions:", noisy_policy),
    )
    for label, current_policy in labelled_policies:
        step_result = current_policy.action(batched_step)
        print(label, step_result.action)


# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    main()

zhezherun avatar Aug 19 '24 19:08 zhezherun