remove tensorflow warning around tf.function on per_field_where
Hello, the tf.function decorator on the local function per_field_where (in tf_agents/utils/nest_utils.py) triggers the following TensorFlow warning:
WARNING:tensorflow:5 out of the last 8 calls to <function where.<locals>.per_field_where at 0x7f07a6484b80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
The following minimal Python example reproduces the issue:
import tensorflow as tf
import tf_agents


def test(condition, true_outputs, false_outputs):
    return tf_agents.utils.nest_utils.where(condition, true_outputs, false_outputs)


if __name__ == "__main__":
    # Repeated calls with identical shapes and dtypes still trigger the retracing warning.
    for _ in range(10):
        condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
        true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
        false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
        test(condition, true_outputs, false_outputs)
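The issue seems to come from the inner/local function definition: every call to where creates a new tf.function object, which then has to be traced again (case (1) in the warning). It could be fixed by either: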
- removing the tf.function decorator
- moving per_field_where out of where as a module-level function (for example _per_field_where). It would no longer close over 'condition' and 'condition_rank', so both need to be passed explicitly, and the return statement could become 'return tf.nest.map_structure(lambda t, f: _per_field_where(t, f, condition, condition_rank), true_outputs, false_outputs)'
Thanks