
Rigid assumption about Observations in acme.agents.tf.mcts

MJHamar opened this issue 11 months ago · 3 comments

In acme.agents.tf.mcts.types, Observation is type-hinted as a NumPy ndarray:

# Assumption: observations are array-like.
Observation = np.ndarray

Based on this, the MCTSActor._forward() method in acme.agents.tf.mcts.acting hard-codes a tf.expand_dims call on the observation, which makes it impossible to pass nested structures (e.g. dicts of arrays) as observations to MCTS.

To fix it, we can redefine the _forward method as:

import tree
(...)
def _forward(self, observation):
    # Batch every leaf of the (possibly nested) observation, not just a bare array.
    batched = tree.map_structure(lambda o: tf.expand_dims(o, axis=0), observation)
    logits, value = self._network(batched)
    (...)

and solve the problem.
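For readers unfamiliar with dm-tree, tree.map_structure applies a function to every leaf of a nested structure while preserving its shape. The following pure-NumPy sketch (map_structure here is a hand-rolled stand-in covering only dicts/lists/tuples, not the real dm-tree API) illustrates the per-leaf batching the fix relies on:

```python
import numpy as np

def map_structure(fn, structure):
    # Minimal stand-in for tree.map_structure, for illustration only.
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_structure(fn, v) for v in structure)
    return fn(structure)

# A hypothetical nested observation: a dict of sensor arrays.
nested_obs = {"camera": np.zeros((8, 8)), "pose": np.zeros(3)}

# Add a leading batch dimension to each leaf, leaving the dict structure intact.
batched = map_structure(lambda x: np.expand_dims(x, axis=0), nested_obs)
print(batched["camera"].shape)  # (1, 8, 8)
print(batched["pose"].shape)    # (1, 3)
```

A single hard-coded expand_dims on the whole dict could not produce this result; mapping over the leaves is what makes nested observations work.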

Is there anything I am missing?

MJHamar · May 10 '25

Hey @MJHamar

While exploring the repository, I found an alternative solution that works for my dummy-network test cases.

Here’s the implementation I used:

import tree
import tensorflow as tf
import numpy as np
from scipy.special import logsumexp

def _forward(network, observation):
    """Converts a nested observation to tensors, runs the network, and returns policy and value."""
    # Convert observation leaves to float32 tensors (explicit cast for dtype consistency).
    obs_tensor = tree.map_structure(
        lambda x: tf.convert_to_tensor(x, dtype=tf.float32), observation)
    # Add a batch dimension of 1 to each leaf.
    batched_obs = tree.map_structure(lambda x: tf.expand_dims(x, axis=0), obs_tensor)
    # Run the network; here it returns logits of shape (1, 3) and a value of shape (1,).
    logits, value = network(batched_obs)
    # Convert to NumPy and remove the batch dimension.
    logits_np = logits.numpy().squeeze()
    # Numerically stable softmax via logsumexp.
    policy = np.exp(logits_np - logsumexp(logits_np))
    value_np = float(value.numpy())  # convert to a Python float scalar
    return policy, value_np
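The snippet above computes the policy with logsumexp (presumably scipy.special.logsumexp) rather than a naive softmax. This self-contained NumPy equivalent (stable_softmax is a hand-rolled helper, not an acme or SciPy API) shows why that matters:

```python
import numpy as np

def stable_softmax(logits):
    # logsumexp trick: subtracting the max before exponentiating avoids overflow,
    # since exp(x - log_norm) == exp(x) / sum(exp(x)) mathematically.
    m = np.max(logits)
    log_norm = m + np.log(np.sum(np.exp(logits - m)))
    return np.exp(logits - log_norm)

logits = np.array([1000.0, 1001.0, 1002.0])  # naive np.exp(logits) would overflow
policy = stable_softmax(logits)
print(np.isclose(policy.sum(), 1.0))  # True
print(np.all(np.isfinite(policy)))    # True
```

A direct np.exp(logits) / np.exp(logits).sum() on these values produces inf / inf = nan; the shifted form stays finite for arbitrarily large logits.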

Happy to hear your thoughts!

yuvraajnarula · May 11 '25

This approach lets your agent's _forward method handle complex observation structures (such as dicts or lists of sensor data) gracefully by ensuring every leaf is converted to a tensor and batched.

Here's the leaner code snippet:

import tree
import tensorflow as tf
import numpy as np

# Assuming this is a method within your Actor class
# e.g., MyCustomActor(object):
# def __init__(self, neural_network):
#     self.network = neural_network

def _forward(self, observation):
    # Helper to process each individual data array in the observation
    def process_leaf_data(obs_item):
        tensor_item = tf.convert_to_tensor(obs_item, dtype=tf.float32)
        return tf.expand_dims(tensor_item, axis=0)  # Add batch dimension

    # Apply the processing to each leaf in the (potentially nested) observation
    batched_observation_tree = tree.map_structure(process_leaf_data, observation)

    # Feed to the network
    raw_logits, raw_value = self.network(batched_observation_tree)

    # Convert outputs to NumPy and remove batch dimension
    logits_np = np.squeeze(raw_logits.numpy(), axis=0)
    value_np = np.squeeze(raw_value.numpy(), axis=0)
    
    # Ensure the value is a Python float scalar; if the value head is not
    # scalar, return the squeezed array as-is rather than failing on float().
    final_value = float(value_np.item()) if value_np.size == 1 else value_np

    return logits_np, final_value
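The batch-then-squeeze flow above can be sanity-checked without TensorFlow. This NumPy sketch uses a hand-rolled map_structure and a hypothetical dummy_network (both are stand-ins for illustration, not acme or dm-tree APIs):

```python
import numpy as np

def map_structure(fn, structure):
    # Minimal stand-in for tree.map_structure (dicts only, for brevity).
    if isinstance(structure, dict):
        return {k: map_structure(fn, v) for k, v in structure.items()}
    return fn(structure)

def dummy_network(batched_obs):
    # Hypothetical network: flattens and concatenates all leaves, then returns
    # the first 3 features as "logits" and the feature mean as a "value".
    flat = np.concatenate(
        [v.reshape(v.shape[0], -1) for v in batched_obs.values()], axis=1)
    return flat[:, :3], flat.mean(axis=1)

observation = {"camera": np.ones((4, 4)), "pose": np.zeros(3)}
# Batch each leaf, run the network, then squeeze the batch dimension back out.
batched = map_structure(lambda x: np.expand_dims(x, axis=0), observation)
logits, value = dummy_network(batched)
print(logits.shape, value.shape)  # (1, 3) (1,)
```

After squeezing axis 0, logits has shape (3,) and value is a scalar, matching what the MCTS search code expects per observation.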

Kenxpx · May 19 '25