Incompatibility between PPO collect policy and SquashToSpecNormal Distribution

Open CHEyy-85 opened this issue 1 year ago • 0 comments

Hello, I am a student research assistant at the Creative Machines Lab at Columbia, contributing to the Smart Building project. While training a PPO agent, I ran into a series of TypeErrors when using tf_agents.agents.ppo.ppo_policy.PPOPolicy with an actor network that outputs a tf_agents.distributions.utils.SquashToSpecNormal distribution.

The actor is based on tf_agents.agents.ppo.ppo_actor_network, with one modification: the MultivariateNormalDiag it builds is converted to a SquashToSpecNormal to suit our bounded action spec. For context, the value network is defined with tf_agents.networks.value_network.ValueNetwork, as recommended by the PPO Agent documentation.

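        def create_dist(loc_and_scale):
            # Map the mean into the range of the action spec.
            loc = loc_and_scale['loc']
            loc = tanh_and_scale_to_spec(loc, action_tensor_spec)

            # Softplus keeps the scale strictly positive.
            scale = loc_and_scale['scale']
            scale = tf.nn.softplus(scale)

            dist = output_spec.build_distribution(loc=loc, scale=scale)
            # Change here: wrap the distribution so that it respects the
            # bounded action spec (distribution_utils is tf_agents.distributions.utils).
            return distribution_utils.scale_distribution_to_spec(
                    dist, action_tensor_spec
            )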
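
For readers unfamiliar with the idea behind this change, the squashing can be sketched in plain Python. This is a simplified illustration, not the TF-Agents implementation: the helper name `squash_to_spec` is hypothetical, and the sketch assumes the transform is a tanh followed by an affine rescale into the spec's [minimum, maximum] range.

```python
import math


def squash_to_spec(x, minimum, maximum):
    """Map an unbounded value into [minimum, maximum] (hypothetical helper)."""
    mean = (maximum + minimum) / 2.0       # midpoint of the action range
    magnitude = (maximum - minimum) / 2.0  # half-width of the action range
    # tanh maps x into (-1, 1); the affine step rescales to the spec bounds.
    return mean + magnitude * math.tanh(x)


# Unbounded samples, e.g. draws from a normal distribution.
samples = [-10.0, -1.0, 0.0, 1.0, 10.0]
# Bounds taken from the action spec in this issue: minimum=-1, maximum=1.
squashed = [squash_to_spec(x, -1.0, 1.0) for x in samples]
```

Every squashed value stays inside the bounds, which is the property the bounded action spec needs.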

Running the collect policy with this actor network then raises errors such as:

  1. TypeError: Expected binary or unicode string, got BoundedTensorSpec(shape=(2,), dtype=tf.float32, name='action', minimum=array(-1., dtype=float32), maximum=array(1., dtype=float32))
  2. TypeError: Failed to convert elements of BoundedTensorSpec(shape=(2,), dtype=tf.float32, name='action', minimum=array(-1., dtype=float32), maximum=array(1., dtype=float32)) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.
  3. TypeError: To be compatible with tf.function, Python functions must return zero or more Tensors or ExtensionTypes or None values; in compilation of <function TFPolicy.action at 0x701338615120>, found return value of type BoundedTensorSpec, which is not a Tensor or ExtensionType.

All three errors suggest that a BoundedTensorSpec object is being passed where a tensor is expected, but I haven't been able to identify the exact cause. The issue is easy to reproduce: run the collect actor / driver of a PPO agent that uses this actor network.
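
One way to narrow this down might be to walk a nested structure, such as the parameters a distribution exposes or the info a policy would emit, and list every leaf that is a spec object rather than a tensor or number. The sketch below is plain Python with stand-in classes. `FakeBoundedSpec`, `find_spec_leaves`, and the layout of `params` are all hypothetical and only mimic the shape of the problem. It assumes, without confirmation from the TF-Agents source, that the spec reaches the policy output through the squashed distribution's parameters.

```python
class FakeBoundedSpec:
    """Stand-in for BoundedTensorSpec (hypothetical, illustration only)."""

    def __init__(self, name):
        self.name = name


def find_spec_leaves(structure, path="root"):
    """Return the paths of all leaves whose type name ends with 'Spec'."""
    if isinstance(structure, dict):
        found = []
        for key, value in structure.items():
            found.extend(find_spec_leaves(value, f"{path}.{key}"))
        return found
    if isinstance(structure, (list, tuple)):
        found = []
        for index, value in enumerate(structure):
            found.extend(find_spec_leaves(value, f"{path}[{index}]"))
        return found
    # Leaf: flag it if it looks like a spec object instead of data.
    if type(structure).__name__.endswith("Spec"):
        return [path]
    return []


# Assumed layout: a wrapped distribution whose parameters carry the spec.
params = {
    "dist_params": {
        "distribution": {"loc": [0.1, -0.2], "scale": [0.5, 0.5]},
        "spec": FakeBoundedSpec("action"),
    }
}
offending = find_spec_leaves(params)
```

If a check like this flagged a spec inside the real collect policy's output, that would be consistent with all three errors above, but it would still need to be verified against the actual TF-Agents objects.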

Note:

  1. The errors only arise when an actor / driver using PPOPolicy is running. An eval actor / driver (in my case using tf_agents.policies.greedy_policy.GreedyPolicy) runs without any issue.
  2. The SquashToSpecNormal distribution works with the SAC agent, as shown in Google's open-source smart building notebook: https://github.com/google/sbsim/blob/copybara_push/smart_control/notebooks/SAC_Demo.ipynb

CHEyy-85, Oct 31 '24 05:10