
Rigid assumption about Observations in acme.agents.tf.mcts

Open MJHamar opened this issue 6 months ago • 2 comments

In `acme.agents.tf.mcts.types`, `Observation` is type-hinted as a NumPy `ndarray`:

```python
# Assumption: observations are array-like.
Observation = np.ndarray
```

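Based on this, `MCTSActor._forward()` in `acme.agents.tf.mcts.acting` hard-codes a single `tf.expand_dims` call on the whole observation to add the batch dimension. Because `tf.expand_dims` operates on one tensor, this makes it impossible to pass nested structures (e.g. dicts or tuples of arrays) as observations to MCTS.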
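To make the failure mode concrete without a TensorFlow dependency, here is a NumPy analogue of the hard-coded call (the observation keys are hypothetical). For a plain array, the extra axis lands where the network expects it. For a dict, NumPy does not batch the leaves at all; it wraps the whole dict in a one-element object array. `tf.expand_dims` should instead fail outright when it tries to convert the dict to a single tensor.

```python
import numpy as np

# Plain array observation: one leading batch axis, as the network expects.
array_obs = np.zeros((3, 3), dtype=np.float32)
batched_array = np.expand_dims(array_obs, axis=0)
print(batched_array.shape)  # (1, 3, 3)

# Hypothetical nested observation (keys chosen for illustration only).
nested_obs = {
    "board": np.zeros((3, 3), dtype=np.float32),
    "to_play": np.array(1, dtype=np.int32),
}
# Expanding the structure as a whole does not batch its leaves:
# NumPy wraps the entire dict in a one-element object array instead.
batched_nested = np.expand_dims(nested_obs, axis=0)
print(batched_nested.shape, batched_nested.dtype)  # (1,) object
```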

To fix this, `_forward` could be redefined to add the batch dimension to each leaf of the observation with `tree.map_structure`:

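```python
import tree  # dm-tree
(...)
def _forward(self, observation):
    # Add a leading batch dimension to every leaf of the (possibly nested)
    # observation, instead of treating the whole observation as one tensor.
    batched_observation = tree.map_structure(
        lambda o: tf.expand_dims(o, axis=0), observation)
    logits, value = self._network(batched_observation)
    (...)
```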

which would solve the problem.
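The proposed fix can be sketched end to end in plain NumPy. This is an illustration under stated assumptions, not acme code: `map_leaves` is a hypothetical, simplified stand-in for `tree.map_structure` (it only recurses into dicts, lists and tuples), `toy_network` is a made-up policy-value network over a dict observation, and the softmax and `[0]` un-batching in `forward` are assumptions about what the elided `(...)` part of `_forward` does.

```python
from typing import Any, Callable, Dict, Tuple

import numpy as np


def map_leaves(fn: Callable[[np.ndarray], np.ndarray], structure: Any) -> Any:
    """Hypothetical stand-in for tree.map_structure (dicts, lists, tuples only)."""
    if isinstance(structure, dict):
        return {key: map_leaves(fn, value) for key, value in structure.items()}
    if isinstance(structure, list):
        return [map_leaves(fn, value) for value in structure]
    if isinstance(structure, tuple):
        return tuple(map_leaves(fn, value) for value in structure)
    return fn(structure)  # A leaf: apply fn directly.


def add_batch_dim(observation: Any) -> Any:
    # Per-leaf analogue of tf.expand_dims(o, axis=0) from the proposed fix.
    return map_leaves(lambda o: np.expand_dims(o, axis=0), observation)


def toy_network(batched_obs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical policy-value network over a dict observation with 9 actions."""
    board = batched_obs["board"].reshape(batched_obs["board"].shape[0], -1)  # (B, 9)
    logits = board * batched_obs["to_play"][:, None]  # (B, 9)
    value = board.sum(axis=1)  # (B,)
    return logits, value


def forward(observation: Any) -> Tuple[np.ndarray, float]:
    """Sketch of _forward: batch every leaf, run the network, drop the batch axis."""
    logits, value = toy_network(add_batch_dim(observation))
    logits = logits[0]
    probs = np.exp(logits - logits.max())  # Numerically stable softmax (assumed).
    probs /= probs.sum()
    return probs, float(value[0])


obs = {
    "board": np.arange(9, dtype=np.float32).reshape(3, 3),
    "to_play": np.array(1.0, dtype=np.float32),
}
batched = add_batch_dim(obs)
print(batched["board"].shape, batched["to_play"].shape)  # (1, 3, 3) (1,)
probs, value = forward(obs)
print(value)  # 36.0
```

A plain array observation still comes out with the same shape as the original `expand_dims` call, e.g. `add_batch_dim(np.zeros((3, 3))).shape == (1, 3, 3)`, so the change should be backward compatible for existing array-only environments.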

Is there anything I am missing?

MJHamar, May 10 '25 09:05