first-order-model-tf
Performance issue with tf.function
Hello! Our static bug checker has found a performance issue in blob/main/run.py and blob/main/animate.py: `animate` is called repeatedly in a for loop, but the `tf.function`-decorated functions `slice_driving` and `next_batch` are defined and called inside `animate`.
In that case, every time `animate` is called in the loop, `slice_driving` and `next_batch` are re-created as new `tf.function` objects and each one traces a new graph, which can trigger the tf.function retracing warning.
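The re-creation problem can be illustrated without TensorFlow. The sketch below uses a hypothetical `toy_function` decorator as a pure-Python stand-in for `tf.function` (it is not TensorFlow's implementation): each wrapper it returns keeps its own cache of simulated "traces", so a wrapper defined inside `outer_bad` starts with an empty cache on every call, while the module-level `inner_shared` traces only once. `toy_function`, `TRACE_COUNTS`, `outer_good` and `outer_bad` are illustrative names only.

```python
from collections import Counter

# Tally of simulated "traces" per function name (toy bookkeeping only).
TRACE_COUNTS = Counter()


def toy_function(fn):
    """Hypothetical stand-in for tf.function, NOT TensorFlow's implementation.

    Each call to toy_function creates a new wrapper with its own empty cache,
    mirroring how each new tf.function object starts with no traced graphs.
    The real tf.function keys its cache on dtypes, shapes and Python values;
    this toy keys only on the Python types of the arguments.
    """
    cache = {}

    def wrapper(*args):
        key = tuple(type(a) for a in args)
        if key not in cache:
            TRACE_COUNTS[fn.__name__] += 1  # simulated (re)trace
            cache[key] = fn
        return cache[key](*args)

    return wrapper


@toy_function
def inner_shared(x):
    return x + 1


def outer_good(x):
    # Reuses the single wrapper created once at import time.
    return inner_shared(x)


def outer_bad(x):
    # Builds a fresh wrapper, and therefore a fresh empty cache, on every call.
    @toy_function
    def inner_local(x):
        return x + 1

    return inner_local(x)


for step in range(5):
    outer_good(step)
    outer_bad(step)

print(TRACE_COUNTS["inner_shared"], TRACE_COUNTS["inner_local"])  # 1 5
```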
Here is the TensorFlow documentation that supports this.
In short, for better efficiency, prefer:
```python
import tensorflow as tf

@tf.function
def inner():
    pass

def outer():
    inner()
```
rather than:
```python
import tensorflow as tf

def outer():
    @tf.function
    def inner():
        pass
    inner()
```
Looking forward to your reply.
We are investigating this kind of issue, and your answer would be a great help to our work. Could you take a look? Thank you in advance! @lshug
Hello @DLPerf. Apologies for the very late reply. I no longer actively work on this repo, and haven't done so in over three years.
The logic outlined is correct. Kudos to your static checker. The reason for creating those tf.functions inside the outer function is that they depend on variables defined within the outer function's scope. The alternative would be to build them outside and pass the needed variables as arguments, but I remember that I needed both of those tf.functions to have an input signature consisting of a single TensorSpec for performance reasons (although I don't remember the details, as I made that decision years ago). In practice, animate.py is not meant to be run in a loop and is only run once by run.py, so retracing never happens anyway.
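If `animate` ever did need to run in a loop, one possible restructuring (a sketch under stated assumptions, not the repository's code) keeps the single-TensorSpec input signature the maintainer describes: move the captured state onto an object created once, and make the `tf.function` a method of that object. `DrivingSlicer`, `slice_frames`, `frames`, `batch_size` and `run_many` are hypothetical names; the real bodies of `slice_driving` and `next_batch` are not shown in this thread. The trace counter relies on the fact that Python side effects inside a `tf.function` run only while tracing.

```python
import tensorflow as tf


class DrivingSlicer:
    """Hypothetical holder for state a tf.function would otherwise capture.

    Values that animate() reads from its local scope live on the instance, so
    the compiled method can keep an input signature of a single TensorSpec.
    """

    def __init__(self, frames, batch_size):
        self.frames = tf.constant(frames, dtype=tf.float32)  # captured tensor
        self.batch_size = batch_size  # Python int, fixed per instance
        self.trace_count = 0  # incremented only while TensorFlow traces

    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int32)])
    def slice_frames(self, start):
        # Python side effect: executes during tracing, not on every call.
        self.trace_count += 1
        return self.frames[start:start + self.batch_size]


def run_many(slicer, num_calls):
    # Stand-in for calling animate() repeatedly: the same slicer, and hence
    # the same traced graph, is reused on every iteration.
    return [slicer.slice_frames(tf.constant(step, dtype=tf.int32))
            for step in range(num_calls)]


slicer = DrivingSlicer(frames=[[0.0], [1.0], [2.0], [3.0], [4.0]], batch_size=2)
outs = run_many(slicer, 3)
print(slicer.trace_count)  # 1
print(outs[1].numpy().tolist())  # [[1.0], [2.0]]
```

The instance itself must be created once, outside the loop: constructing a new `DrivingSlicer` on every call would bring back one trace per instance, for the same reason as the original pattern.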