Mava
[BUG] Remove nested tf.function
Describe the bug
Nested tf.function decorators cause TF to retrace constantly, which could cause significant performance and memory issues.
Additional context
I think this bug crept in when we refactored our code to separate the forward and backward passes.
Possible Solution
Remove the tf.function decorator from the backward pass.
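A minimal sketch of what the proposed fix could look like, assuming a trainer whose step is split into `_forward` and `_backward` methods. The class, method names, loss, and learning rate are all hypothetical and are not Mava's actual trainer code. Only the outer `step` carries `tf.function`, and a Python-side counter (which only runs while TF is tracing) is used to check how often tracing happens.

```python
import tensorflow as tf

# Python side effects inside a tf.function run only during tracing,
# so this counter records how many times `step` was traced.
TRACE_COUNTS = {"step": 0}


class Trainer:
    """Hypothetical trainer with separate forward and backward passes."""

    def __init__(self):
        self.w = tf.Variable(1.0)

    def _forward(self, x, y):
        # Plain Python method: no tf.function decorator here.
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean((self.w * x - y) ** 2)
        return loss, tape

    def _backward(self, loss, tape):
        # Plain Python method: no tf.function decorator here either.
        grad = tape.gradient(loss, self.w)
        self.w.assign_sub(0.1 * grad)  # simple SGD update, lr assumed 0.1

    @tf.function
    def step(self, x, y):
        # Single outer decorator wrapping both passes.
        TRACE_COUNTS["step"] += 1
        loss, tape = self._forward(x, y)
        self._backward(loss, tape)
        return loss


trainer = Trainer()
x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x
losses = [float(trainer.step(x, y)) for _ in range(5)]
print("traces:", TRACE_COUNTS["step"], "losses:", losses)
```

Because the inputs keep the same shape and dtype across calls, `step` should be traced once and then reused.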
This bug is also related to #77 and #346
Already fixed in #362.
Having tf.function over the _policy function in the executor causes retracing as well, because _policy is called inside a for loop in the select_actions function. It is better to create a new function called _select_functions, put the for loop inside it, and put the tf.function decorator on that function. See the MADQN executor in #362.
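A rough sketch of the difference, not the actual MADQN executor from #362. It assumes the policy takes an agent key (a Python string) and an observation tensor; the greedy `argmax` policy, the agent names, and the `policy_per_agent` and `select_actions_loop_outside` helpers are made up for illustration. Passing a different Python value to a decorated function produces a separate trace, so decorating the per-agent call traces once per agent key, while decorating the function that contains the loop traces once for the whole dictionary.

```python
import tensorflow as tf

# Python side effects inside a tf.function run only during tracing,
# so these counters record how many traces each variant needed.
trace_counts = {"per_agent": 0, "loop_inside": 0}


@tf.function
def policy_per_agent(agent, observation):
    # Decorated per-agent policy: `agent` is a Python string, so each
    # distinct agent key gets its own trace.
    trace_counts["per_agent"] += 1
    return tf.argmax(observation, axis=-1)


def select_actions_loop_outside(observations):
    # The for loop lives outside the decorated function.
    return {a: policy_per_agent(a, obs) for a, obs in observations.items()}


def _policy(agent, observation):
    # Undecorated policy, called from inside the decorated loop below.
    return tf.argmax(observation, axis=-1)


@tf.function
def _select_functions(observations):
    # The for loop lives inside the decorated function; it is unrolled
    # once at trace time over the dictionary keys.
    trace_counts["loop_inside"] += 1
    actions = {}
    for agent, obs in observations.items():
        actions[agent] = _policy(agent, obs)
    return actions


observations = {f"agent_{i}": tf.random.uniform([4]) for i in range(3)}
for _ in range(5):
    select_actions_loop_outside(observations)
    actions = _select_functions(observations)
print(trace_counts)
```

The gap between the two counters grows with the number of agents, and any Python-valued argument that changes between calls would add further traces to the per-agent variant.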
Thanks @jcformanek! Can you please make an issue for this with the above advice, so that we can look at fixing it for other systems as well?
Reopening this issue until we fix the problem in all other systems.
https://github.com/instadeepai/Mava/pull/353 handles this for ppo.
Closing all TF issues as we are deprecating our TF systems.