Support nested observation structures in MCTS agent
Summary
Fixes #341
The MCTSActor._forward() method previously called tf.expand_dims() directly on the observation, which only works when the observation is a single array (np.ndarray). This prevented nested structures (dicts, tuples) from being used as observations.
Problem
As described in #341, passing a nested observation structure to the MCTS agent fails at this call:
# This fails with nested observations
logits, value = self._network(tf.expand_dims(observation, axis=0))
The tf.expand_dims call assumes the observation is a single array, but many environments use nested observation spaces (dictionaries, tuples, etc.).
Solution
Modified acme/agents/tf/mcts/acting.py to use the existing tf2_utils.add_batch_dim() utility, which internally uses tree.map_structure() to apply tf.expand_dims to each leaf of the observation structure:
from acme.tf import utils as tf2_utils
# ...
batched_observation = tf2_utils.add_batch_dim(observation)
logits, value = self._network(batched_observation)
Also updated acme/agents/tf/mcts/types.py to change the Observation type from np.ndarray to Any to properly reflect that nested structures are now supported.
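To illustrate the leaf-wise batching described above, here is a minimal sketch that does not depend on Acme or TensorFlow. It is not the code in acme/tf/utils.py: the local map_structure helper is a hypothetical stand-in for tree.map_structure, np.expand_dims stands in for tf.expand_dims, and the observation layout ("board", "extras") is invented for the example.

```python
import numpy as np


def map_structure(fn, structure):
    """Apply fn to every leaf of a nested dict/tuple/list.

    Hypothetical stand-in for tree.map_structure, kept local so the
    sketch runs without dm-tree installed.
    """
    if isinstance(structure, dict):
        return {key: map_structure(fn, value) for key, value in structure.items()}
    if isinstance(structure, (tuple, list)):
        return type(structure)(map_structure(fn, item) for item in structure)
    return fn(structure)


def add_batch_dim(nest):
    """Add a leading batch dimension of size 1 to each leaf of nest."""
    # np.expand_dims is used here in place of tf.expand_dims.
    return map_structure(lambda leaf: np.expand_dims(leaf, axis=0), nest)


# An invented nested observation: a dict holding an array and a tuple.
observation = {
    "board": np.zeros((3, 3), dtype=np.float32),
    "extras": (np.zeros(4, dtype=np.float32), np.float32(1.0)),
}

batched = add_batch_dim(observation)

# Collect the shape of every batched leaf, preserving the nesting.
batched_shapes = map_structure(lambda leaf: leaf.shape, batched)
```

A plain array passes through the same function unchanged in behaviour, since a single array is treated as one leaf, so array-only environments should be unaffected by the switch.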
Changes
- acme/agents/tf/mcts/acting.py: Use tf2_utils.add_batch_dim() instead of direct tf.expand_dims()
- acme/agents/tf/mcts/types.py: Update Observation type alias to Any
Testing
- Verified syntax is valid with python3 -m py_compile
- The fix follows the existing pattern used in acme/tf/utils.py (see add_batch_dim function, lines 28-30)