
Remove TensorFlow retracing warning caused by tf.function on per_field_where

Open cmarlin opened this issue 9 months ago • 0 comments

Hello, the `tf.function` decorator on the inner function `per_field_where` (in `tf_agents/utils/nest_utils.py`) triggers the following TensorFlow warning:

```
WARNING:tensorflow:5 out of the last 8 calls to <function where.<locals>.per_field_where at 0x7f07a6484b80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
```

The following minimal Python script reproduces the issue:

```python
import tensorflow as tf
import tf_agents


def test(condition, true_outputs, false_outputs):
    # Thin wrapper around the tf_agents helper under test.
    return tf_agents.utils.nest_utils.where(condition, true_outputs, false_outputs)


if __name__ == "__main__":
    # Every iteration passes tensors of identical shape and dtype, so no
    # retracing would be expected; the warning appears anyway.
    for _ in range(10):
        condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
        true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
        false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
        test(condition, true_outputs, false_outputs)
```
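To check the mechanism without tf_agents, the sketch below mimics the pattern with a `tf.function` defined inside an outer function, next to one defined once at module level. The names `outer`, `inner`, `module_level` and `trace_counts` are hypothetical. Python code inside a `tf.function` body runs only while tracing, so the counters count traces rather than calls.

```python
import tensorflow as tf

# Hypothetical counters: Python side effects inside a tf.function run only
# while the function is being traced, so these count traces, not calls.
trace_counts = {"local": 0, "module": 0}


def outer(x):
    # Same shape of code as a local per_field_where inside where(): each call
    # to outer() builds a new tf.function object whose trace cache is empty.
    @tf.function
    def inner(y):
        trace_counts["local"] += 1
        return y * 2

    return inner(x)


# Defined once at module level: traced on the first call, then reused for
# later calls whose inputs have the same shape and dtype.
@tf.function
def module_level(y):
    trace_counts["module"] += 1
    return y * 2


for _ in range(10):
    x = tf.convert_to_tensor([0, 1], dtype=tf.int32)
    outer(x)
    module_level(x)

print(trace_counts)
```

The `local` counter should grow by one on every iteration even though the inputs never change, while the `module` counter stays at one.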

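The issue appears to stem from `per_field_where` being defined locally inside `where`: each call to `where` creates a new `tf.function` object with an empty trace cache, so every call traces again (case (1) in the warning). It could be fixed by either: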

  • removing the `tf.function` decorator, or
  • moving `per_field_where` out of `where` to module level (e.g. as `_per_field_where`). Because the local function currently closes over `condition` and `condition_rank`, these would have to be passed explicitly, so the return statement would become `return tf.nest.map_structure(lambda t, f: _per_field_where(t, f, condition, condition_rank), true_outputs, false_outputs)`.
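A rough sketch of the second option is below. Since the real body of `per_field_where` is not reproduced here, `_per_field_where` is a hypothetical, simplified stand-in (it pads the condition's shape with trailing 1s so it broadcasts against each field, then calls `tf.where`), and `where` is a hypothetical stand-in for `nest_utils.where`. Neither is the tf_agents implementation. `trace_count` is a hypothetical counter added only to show how many traces occur.

```python
import tensorflow as tf

# Hypothetical trace counter (incremented only while tracing).
trace_count = [0]


@tf.function
def _per_field_where(t, f, condition, condition_rank):
    # Hypothetical simplified body: reshape condition from [B] to [B, 1, ...]
    # so it broadcasts against t and f, then select element-wise.
    trace_count[0] += 1
    rank_difference = tf.rank(t) - condition_rank
    padding = tf.ones(tf.reshape(rank_difference, [1]), dtype=tf.int32)
    condition_shape = tf.concat([tf.shape(condition), padding], axis=0)
    return tf.where(tf.reshape(condition, condition_shape), t, f)


def where(condition, true_outputs, false_outputs):
    # Hypothetical stand-in for nest_utils.where: condition and condition_rank
    # are passed explicitly to one module-level tf.function instead of being
    # captured by a tf.function re-created on every call.
    condition = tf.convert_to_tensor(condition)
    condition_rank = tf.rank(condition)
    return tf.nest.map_structure(
        lambda t, f: _per_field_where(t, f, condition, condition_rank),
        true_outputs,
        false_outputs,
    )


for _ in range(10):
    condition = tf.convert_to_tensor([True, True], dtype=tf.bool)
    true_outputs = tf.convert_to_tensor([0, 1], dtype=tf.int32)
    false_outputs = tf.convert_to_tensor([2, 3], dtype=tf.int32)
    result = where(condition, true_outputs, false_outputs)

print(result.numpy(), trace_count[0])
```

With same-shaped inputs the helper should be traced only once across the loop. Passing `condition_rank` as a tensor rather than a Python int is a deliberate choice: `tf.function` retraces for each distinct Python scalar value, but only for new shapes or dtypes of tensor arguments.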

Thanks

cmarlin · Dec 06 '23 12:12