Rigid assumption about Observations in acme.agents.tf.mcts
In `acme.agents.tf.mcts.types`, `Observation` is type-hinted as a NumPy ndarray:

```python
# Assumption: observations are array-like.
Observation = np.ndarray
```

Based on this, the `MCTSActor._forward()` method in `acme.agents.tf.mcts.acting` hard-codes a `tf.expand_dims` call on the observation. This makes it impossible to pass nested structures (e.g. dicts of arrays) as observations to MCTS.
To fix it, we can redefine the `_forward` method as:

```python
import tree

(...)

    def _forward(self, observation):
        # Batch each leaf of the (possibly nested) observation structure.
        logits, value = self._network(
            tree.map_structure(lambda o: tf.expand_dims(o, axis=0), observation))

(...)
```

and the problem is solved.
Is there anything I am missing?
Hey @MJHamar
I was exploring the repository and found an alternative solution while writing test cases with a dummy network.
Here’s the implementation I used:
```python
import numpy as np
import tensorflow as tf
import tree
from scipy.special import logsumexp

def _forward(network, observation):
    """Converts a nested observation to tensors, runs the network, and returns policy and value."""
    # Convert observation leaves to tf.Tensors, casting explicitly to float32
    # to ensure dtype consistency.
    obs_tensor = tree.map_structure(
        lambda x: tf.convert_to_tensor(x, dtype=tf.float32), observation)
    # Add a batch dimension of 1 to each leaf.
    batched_obs = tree.map_structure(lambda x: tf.expand_dims(x, axis=0), obs_tensor)
    # Run the network: logits of shape (1, 3) and value of shape (1,).
    logits, value = network(batched_obs)
    # Convert to NumPy and remove the batch dimension.
    logits_np = logits.numpy().squeeze()
    # Numerically stable softmax using logsumexp.
    policy = np.exp(logits_np - logsumexp(logits_np))
    value_np = float(value.numpy())  # convert to a Python float scalar
    return policy, value_np
```
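As a quick check of the softmax step, the logsumexp trick reproduces a plain softmax while staying finite on logits where a naive `np.exp` would overflow:

```python
import numpy as np
from scipy.special import logsumexp

logits = np.array([1000.0, 1001.0, 1002.0])  # naive np.exp(logits) overflows
policy = np.exp(logits - logsumexp(logits))
print(np.isclose(policy.sum(), 1.0))  # True
```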
Happy to hear your thoughts!
This approach helps your agent's `_forward` method deal gracefully with complex observation structures (such as dicts or lists of sensor data) by making sure each piece of data is converted to a tensor and given a batch dimension.
Here's the leaner code snippet:

```python
import tree
import tensorflow as tf
import numpy as np

# Assuming this is a method within your Actor class, e.g.:
# class MyCustomActor:
#     def __init__(self, neural_network):
#         self.network = neural_network

def _forward(self, observation):
    # Helper to process each individual data array in the observation.
    def process_leaf_data(obs_item):
        tensor_item = tf.convert_to_tensor(obs_item, dtype=tf.float32)
        return tf.expand_dims(tensor_item, axis=0)  # add batch dimension

    # Apply the processing to each leaf in the (potentially nested) observation.
    batched_observation_tree = tree.map_structure(process_leaf_data, observation)

    # Feed to the network.
    raw_logits, raw_value = self.network(batched_observation_tree)

    # Convert outputs to NumPy and remove the batch dimension.
    logits_np = np.squeeze(raw_logits.numpy(), axis=0)
    value_np = np.squeeze(raw_value.numpy(), axis=0)

    # Ensure the value is a scalar float.
    final_value = float(value_np.item()) if value_np.size == 1 else float(value_np)
    return logits_np, final_value
```