first-order-model-tf

Performance issues about tf.function

Open DLPerf opened this issue 1 year ago • 1 comments

Hello! Our static bug checker has found a performance issue in blob/main/run.py and blob/main/animate.py: animate is called repeatedly in a for loop, but the tf.function-decorated functions slice_driving and next_batch are defined and called inside animate.

As a result, every call to animate redefines slice_driving and next_batch, so each of them traces a new graph every time, which can trigger the tf.function retracing warning.

Here is the TensorFlow documentation that supports this.

Briefly, it is more efficient to use:

@tf.function
def inner():
    pass

def outer():
    inner()  

than:

def outer():
    @tf.function
    def inner():
        pass
    inner()
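
The difference between the two patterns can be observed by counting traces. The following is a minimal sketch, not code from this repository: the names trace_log, inner_top, outer_top, inner_nested, and outer_nested are hypothetical, and it relies on the fact that Python side effects inside a tf.function body run only while the function is being traced.

```python
import tensorflow as tf

# Python side effects inside a tf.function run only during tracing,
# so appending to this list records one entry per trace.
trace_log = []

@tf.function
def inner_top(x):
    trace_log.append("top")
    return x + 1

def outer_top(x):
    # Reuses the same tf.function object on every call.
    return inner_top(x)

def outer_nested(x):
    # A brand-new tf.function object is created on every call,
    # so nothing traced earlier can be reused.
    @tf.function
    def inner_nested(y):
        trace_log.append("nested")
        return y + 1
    return inner_nested(x)

x = tf.constant(1)
for _ in range(3):
    outer_top(x)
    outer_nested(x)

top_traces = trace_log.count("top")
nested_traces = trace_log.count("nested")
print("top-level traces:", top_traces)
print("nested traces:", nested_traces)
```

With the same input on every iteration, the top-level function should be traced once, while the nested one should be traced once per call to outer_nested.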

Looking forward to your reply.

DLPerf, Mar 03 '23 02:03

We are investigating this kind of issue, and your answer will be of great help to our work. Can you take a look? Thank you in advance! @lshug

DLPerf, Mar 06 '23 02:03

Hello @DLPerf. Apologies for the very late reply. I no longer actively work on this repo, and haven't done so in over three years.

The logic outlined is correct. Kudos to your static checker. The reason for creating those tf.functions inside the outer function is that they depend on variables defined within the outer function's scope. The alternative would be to build them outside and pass the needed variables as arguments, but I remember that I needed both of those tf.functions to have an input signature consisting of a single TensorSpec for performance reasons (although I don't remember the details, as I made that decision years ago). In practice, animate.py is not meant to be run in a loop and is only run once by run.py, so retracing never happens anyway.
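
For readers who do want to call such code in a loop, the alternative mentioned above (building the function outside and passing the needed values as arguments) might look like the sketch below. This is an illustration under assumptions, not the repository's code: slice_frames, animate_sketch, the tensor shapes, and the two-entry input signature are all hypothetical, and it deliberately does not keep the single-TensorSpec signature described above.

```python
import tensorflow as tf

# Records one entry per trace (Python side effects run only while tracing).
trace_log = []

# Hypothetical module-level function: values that would otherwise be captured
# from the enclosing scope are passed in as explicit tensor arguments.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32),  # assumed frame layout
    tf.TensorSpec(shape=[], dtype=tf.int32),           # assumed start index
])
def slice_frames(frames, start):
    trace_log.append("slice_frames")
    return frames[start:start + 2]

def animate_sketch(frames):
    # Calls the shared tf.function; no new tf.function is created here.
    return [slice_frames(frames, tf.constant(i)) for i in range(0, 4, 2)]

# Repeated calls with different numbers of frames reuse the single trace,
# because the signature leaves the first dimension unspecified.
for n in (4, 6, 8):
    out = animate_sketch(tf.zeros([n, 4]))

print("traces:", len(trace_log))
print("slice shape:", out[0].shape.as_list())
```

Because the input signature fixes the accepted dtypes and ranks up front, the function is traced once regardless of how many times animate_sketch runs. Whether this is as fast as the single-TensorSpec version would need to be measured.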

lshug, May 15 '24 15:05