
Support nested observation structures in MCTS agent

Open natinew77-creator opened this issue 1 month ago • 0 comments

Summary

Fixes #341

The MCTSActor._forward() method previously called tf.expand_dims() directly on the observation, which only works for array-like observations (np.ndarray). As a result, nested structures (dicts, tuples) could not be used as observations.

Problem

As described in #341, passing a nested observation structure to the MCTS agent fails at this line:

# This fails with nested observations
logits, value = self._network(tf.expand_dims(observation, axis=0))

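The tf.expand_dims call assumes the observation is a single array: it converts its argument to one tensor, so a dict observation raises an error, and a tuple of equally shaped arrays is stacked into a single tensor instead of getting a batch dimension per element. Many environments, however, use nested observation spaces (dictionaries, tuples, etc.).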
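The shape problem can be reproduced with NumPy alone, since np.expand_dims also converts a tuple into one array before adding the axis. This is a standalone illustration with a hypothetical two-leaf tuple observation, not code from this PR; in TensorFlow the dict case errors out rather than producing an object array.

```python
import numpy as np

# Hypothetical tuple observation: two feature vectors of the same shape.
observation = (np.zeros(3, dtype=np.float32), np.ones(3, dtype=np.float32))

# Expanding the whole structure first turns the tuple into ONE (2, 3) array,
# so the leaves are stacked together instead of being batched separately.
whole = np.expand_dims(observation, axis=0)
print(whole.shape)  # (1, 2, 3)

# Expanding each leaf keeps the tuple structure and batches every element.
per_leaf = tuple(np.expand_dims(x, axis=0) for x in observation)
print([x.shape for x in per_leaf])  # [(1, 3), (1, 3)]
```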

Solution

Modified acme/agents/tf/mcts/acting.py to use the existing tf2_utils.add_batch_dim() utility, which internally uses tree.map_structure() to apply tf.expand_dims to each leaf of the observation structure:

from acme.tf import utils as tf2_utils
# ...
batched_observation = tf2_utils.add_batch_dim(observation)
logits, value = self._network(batched_observation)
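To make the per-leaf behaviour concrete, here is a minimal sketch that mirrors what add_batch_dim does. It assumes NumPy and a small hand-written recursion in place of TensorFlow and dm-tree's tree.map_structure so that it runs standalone; the helper and the example observation are hypothetical, and unlike dm-tree it only recognises dicts, lists, tuples and namedtuples.

```python
from typing import Any

import numpy as np


def add_batch_dim(nest: Any) -> Any:
  """Hypothetical NumPy stand-in for tf2_utils.add_batch_dim.

  Walks dicts, lists, tuples and namedtuples recursively and adds a leading
  batch axis to every leaf, leaving the container structure unchanged.
  """
  if isinstance(nest, dict):
    return {k: add_batch_dim(v) for k, v in nest.items()}
  if isinstance(nest, tuple) and hasattr(nest, '_fields'):
    # namedtuple constructors take positional fields, not an iterable.
    return type(nest)(*(add_batch_dim(v) for v in nest))
  if isinstance(nest, (tuple, list)):
    return type(nest)(add_batch_dim(v) for v in nest)
  # Leaf: scalar or array gets a new axis 0.
  return np.expand_dims(nest, axis=0)


# Hypothetical nested observation: a dict holding an array and a tuple.
observation = {
    'board': np.zeros((3, 3), dtype=np.float32),
    'extras': (np.float32(0.5), np.arange(4)),
}
batched = add_batch_dim(observation)
print(batched['board'].shape)      # (1, 3, 3)
print(batched['extras'][0].shape)  # (1,)
print(batched['extras'][1].shape)  # (1, 4)
```

Because the structure is preserved, a network whose input spec is the same nest can consume the batched observation directly, which is what the actor relies on when it calls self._network(batched_observation).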

Also updated acme/agents/tf/mcts/types.py to change the Observation type alias from np.ndarray to Any, reflecting that nested structures are now supported.

Changes

  • acme/agents/tf/mcts/acting.py: Use tf2_utils.add_batch_dim() instead of direct tf.expand_dims()
  • acme/agents/tf/mcts/types.py: Update Observation type alias to Any

Testing

  • Verified syntax is valid with python3 -m py_compile
  • The fix reuses the existing helper in acme/tf/utils.py (see the add_batch_dim function, lines 28-30)

natinew77-creator, Dec 08 '25 00:12