open_spiel
BestResponsePolicy requires concrete policy probability values, fails with JAX abstract tracer values
`best_response.BestResponsePolicy` (which my code calls via `exploitability.nash_conv`) does not work when policy probabilities are JAX abstract tracer values:
```
Traceback (most recent call last):
[...]
File "/usr/local/lib/python3.10/site-packages/open_spiel/python/algorithms/best_response.py", line 181, in <genexpr>
    return sum(p * self.q_value(state, a)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/1)> with
  val = Traced<ShapedArray(bool[11])>with<DynamicJaxprTrace(level=0/1)>
  batch_dim = 0
The problem arose with the `bool` function.
This Tracer was created on line /usr/local/lib/python3.10/site-packages/open_spiel/python/algorithms/best_response.py:182 (<genexpr>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
This is due to the `if p > self._cut_threshold` clauses, which require `p` to be concrete. Is there a recommended workaround? If not, perhaps a threshold flag could be added to disable the threshold check completely.
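For context, here is a minimal standalone repro of the same failure mode (hypothetical code, not from best_response.py itself): any Python `if` on a traced value forces a `bool()` conversion, which raises the same error under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def thresholded(p):
    # The Python `if` calls bool(p > 0.5) on an abstract tracer -> error.
    if p > 0.5:
        return p * 2.0
    return jnp.zeros_like(p)

try:
    thresholded(jnp.asarray(0.7))
except jax.errors.ConcretizationTypeError as e:
    print("fails under jit:", type(e).__name__)
```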
Hmm, not sure I understand. I don't know what an abstract tracer value is... But don't all the probabilities need to be concrete for the code to compute a best response in the first place?
Does removing the threshold logic locally work for you?
Is it because there is a jitted function that calls BestResponsePolicy? If so, then I think it is normally not solvable, because the execution flow of a jitted JAX function cannot depend on the concrete values of the input (e.g., `if` conditions; see this link).
@lanctot A tracer value is JAX's internal representation of intermediate values during transformations (jit, vmap, etc.). It keeps track of all operations (additions, multiplications, maximums, etc.) that created it from the inputs so as to compile the resulting computation, which can then be run quickly many times, differentiated, etc. See here for more details.
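As a small illustration of the tracing described above (my own toy example, not OpenSpiel code), `jax.make_jaxpr` shows the recorded operations that a tracer accumulates, with only the shape/dtype of the input known:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Each arithmetic op on the tracer is recorded into a jaxpr,
    # which JAX then compiles and can rerun or differentiate.
    return jnp.maximum(x * 2.0, 0.0) + 1.0

print(jax.make_jaxpr(f)(3.0))  # prints the traced computation graph
```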
I tried removing the threshold logic, but it hit another snag here in `BestResponsePolicy.best_response_action`'s call to `max`, which uses Python conditionals internally. I think the algorithm's logic would have to be rewritten to be tracer-friendly (see below).
@rezunli96 Yes. Technically, JAX can handle control flow that depends on tracer values via these functions, but rewriting plain Python control flow to use these requires care and is not always a trivial transformation.
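To illustrate what such a rewrite looks like in the simplest case (a sketch only, not a drop-in patch for best_response.py): the threshold filter can be expressed with `jnp.where`, which computes both branches elementwise and therefore traces fine.

```python
import jax
import jax.numpy as jnp

def masked_expected_value(probs, q_values, cut_threshold=0.0):
    # Instead of `sum(p * q for p, q in ... if p > cut_threshold)`,
    # zero out below-threshold terms; no Python branch on traced values.
    mask = probs > cut_threshold
    return jnp.sum(jnp.where(mask, probs * q_values, 0.0))

# Works under jit: 0.6 * 1.0 + 0.4 * 2.0 = 1.4
v = jax.jit(masked_expected_value)(jnp.array([0.6, 0.4]), jnp.array([1.0, 2.0]))
```

The harder part, as noted above, is the `max` over actions, which needs a similar array-level rewrite (e.g., `argmax` over stacked values).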
I wrote my own implementation below (alongside OpenSpiel's, for comparison):
```python
import jax.numpy as jnp
from types import SimpleNamespace
from open_spiel.python.algorithms.best_response import BestResponsePolicy

# Note: get_cf_reach_prob and get_infoset_states are helpers from my own module.

def get_best_response(root, policy, i):
    reach_prob = dict(get_cf_reach_prob(root, policy, i))
    states = get_infoset_states(root, i)

    value_cache = {}
    def get_value(state):
        state_str = state.serialize()
        if state_str not in value_cache:
            if state.is_terminal():
                value_cache[state_str] = state.returns()[i]
            elif state.is_chance_node():
                value_cache[state_str] = sum(
                    prob * get_value(state.child(action))
                    for action, prob in state.chance_outcomes()
                )
            else:
                if state.current_player() == i:
                    move = get_move(state.information_state_string())
                    # value_cache[state_str] = get_value(state.child(move))
                    value_cache[state_str] = jnp.stack([
                        get_value(state.child(action))
                        for action in state.legal_actions()
                    ])[move]
                else:
                    value_cache[state_str] = sum(
                        prob * get_value(state.child(action))
                        for action, prob in zip(state.legal_actions(), policy(state))
                    )
        return value_cache[state_str]

    move_cache = {}
    def get_move(infoset):
        if infoset not in move_cache:
            move_cache[infoset] = sum(
                jnp.stack([
                    get_value(state.child(action))
                    for action in state.legal_actions()
                ]) * reach_prob[state.serialize()]
                for state in states[infoset]
            ).argmax()
        return move_cache[infoset]

    get_value(root)
    return {
        'root_value': value_cache[root.serialize()],
        'value_cache': value_cache,
        'move_cache': move_cache,
    }

def get_best_response_spiel(root, policy, i):
    def action_probabilities(state):
        return dict(zip(state.legal_actions(), policy(state)))
    policy_obj = SimpleNamespace(action_probabilities=action_probabilities)
    br = BestResponsePolicy(root.get_game(), i, policy_obj)
    return {
        'root_value': br.value(root)
    }
```
Note the commented line, which has to be rewritten as the statement after it.
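The reason for that rewrite, sketched in isolation (toy example, mine): looking up a child with a traced index via ordinary Python dispatch fails, whereas stacking the candidate values and indexing with the traced integer is a pure array gather, so it traces fine.

```python
import jax
import jax.numpy as jnp

def pick_best(values):
    # `values` plays the role of the stacked child values above.
    move = jnp.argmax(values)   # traced integer index
    # Indexing a stacked array with a traced index compiles to a gather.
    return values[move]

best = jax.jit(pick_best)(jnp.array([1.0, 3.0, 2.0]))  # -> 3.0
```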
@lanctot Is this wontfix? Just wondering.
Hi @carlosgmartin, yeah sorry-- this code was not made with your use case in mind. We can update the docs to make that more clear, but I am surprised that you expected it to work, so you would be the best person to add clarification to the docs?
I would be happy if you or anyone else volunteered to fix this, but I don't see any easy solution that would keep this code clean... do you? So my advice is: fork it, make it work for your use case, and maybe contribute it if you are so inclined?
@lanctot No problem. I wrote a self-contained module of JAX-compatible helper functions for OpenSpiel games.