Rigid assumption about Observations in acme.agents.tf.mcts
In acme.agents.tf.mcts.types, Observation is type-hinted as a NumPy ndarray:
# Assumption: observations are array-like.
Observation = np.ndarray
Based on this, the MCTSActor._forward() method in acme.agents.tf.mcts.acting hard-codes a tf.expand_dims call on the whole observation. tf.expand_dims operates on a single tensor, so this makes it impossible to pass nested structures (e.g. a dict or tuple of arrays) as observations to MCTS.
To fix this, _forward can be redefined as:
import tensorflow as tf
import tree  # dm-tree
(...)
def _forward(self, observation):
  # Add a batch dimension to every leaf of the (possibly nested) observation,
  # instead of to the observation as a whole.
  logits, value = self._network(
      tree.map_structure(lambda o: tf.expand_dims(o, axis=0), observation))
  (...)
This solves the problem.
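To see the per-leaf batching in isolation, here is a NumPy-only sketch. It is not Acme code: map_structure is a hypothetical minimal stand-in for dm-tree's tree.map_structure (it only handles dicts, lists, tuples and namedtuples), toy_network is a made-up policy-value network that takes a dict observation, and the post-processing (dropping the batch axis, then a softmax over the logits) is an assumption about what the elided part of _forward does.

```python
import numpy as np


def map_structure(fn, nest):
    """Hypothetical minimal stand-in for dm-tree's `tree.map_structure`.

    Applies `fn` to every leaf; dicts, lists, tuples and namedtuples are
    treated as containers, everything else as a leaf.
    """
    if isinstance(nest, dict):
        return {k: map_structure(fn, v) for k, v in nest.items()}
    if isinstance(nest, tuple) and hasattr(nest, "_fields"):  # namedtuple
        return type(nest)(*(map_structure(fn, v) for v in nest))
    if isinstance(nest, (list, tuple)):
        return type(nest)(map_structure(fn, v) for v in nest)
    return fn(nest)


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()


def toy_network(batched_obs):
    """Hypothetical policy-value network standing in for `self._network`.

    Expects a dict with 'board' of shape [B, 3, 3] and 'to_play' of shape [B];
    returns logits of shape [B, 9] and value of shape [B, 1].
    """
    board = batched_obs["board"]
    to_play = batched_obs["to_play"]
    logits = board.reshape(board.shape[0], -1) * to_play[:, None]
    value = board.sum(axis=(1, 2))[:, None]
    return logits, value


def forward(network, observation):
    # Add a leading batch axis to each leaf, not to the structure itself.
    batched = map_structure(lambda o: np.expand_dims(o, axis=0), observation)
    logits, value = network(batched)
    # Assumed post-processing: drop the batch axis, normalise the logits.
    probs = softmax(logits.squeeze(axis=0))
    return probs, value.item()


# Usage: a nested (dict) observation with leaves of different shapes.
obs = {
    "board": np.arange(9, dtype=np.float32).reshape(3, 3),
    "to_play": np.float32(1.0),
}
probs, value = forward(toy_network, obs)
```

Mapping over leaves is what lets the observation carry leaves of different shapes (here a 3x3 board and a scalar); each one simply gains a leading batch axis of size 1.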
Is there anything I am missing?