`recompute_grad` does not save memory and is incompatible with graph mode
**System information**
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No.
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04 and Windows 10.
- TensorFlow installed from (source or binary): from binary (pip install)
- TensorFlow version (use command below): 2.1.0
- Python version: 3.7
- CUDA/cuDNN version: CUDA 10.2 + cuDNN 7.6.5 (Windows); CUDA 10.1 + cuDNN 7.6.5 + TensorRT 6 (Ubuntu)
- GPU model and memory: GeForce GTX 1060 with Max-Q Design, 6 GB (Windows) and GeForce GTX 1080 Ti, 12 GB (Ubuntu)
**Describe the current behavior**
Wrapping Keras layers with `tf.recompute_grad` has no effect. I built a DenseNet model and wrapped each BN-ReLU-Conv1x1-BN-ReLU-Conv3x3 block with the function, but I saw no GPU memory reduction on either the Windows or the Ubuntu platform. When eager mode is disabled, it throws `ValueError: Variable <tf.Variable 'batch_normalization/gamma:0' shape=(32,) dtype=float32> has None for gradient.`, indicating that `tf.recompute_grad` blocks gradient backpropagation in graph mode.
**Describe the expected behavior**
The function appears to originate from OpenAI's gradient checkpointing (https://github.com/cybertronai/gradient-checkpointing) and is expected to save GPU memory during training. Recently, a TensorFlow implementation of efficient DenseNets (https://github.com/joeyearsley/efficient_densenet_tensorflow) also used this function to perform gradient checkpointing (it used tf.contrib.layers.recompute_grad in TF1 graph mode, so not exactly the same environment as here).
Please fix the incompatibility so that the function also works in graph mode. If the function is designed to perform gradient checkpointing, please verify its effectiveness. If it is not meant to support efficient DenseNets, please provide a correct and effective implementation.
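Conceptually, what I expect `tf.recompute_grad` to do can be sketched with `tf.custom_gradient` (a rough sketch of my own, not TF's actual implementation; the helper name `manual_recompute_grad` is made up):

```python
import tensorflow as tf

def manual_recompute_grad(f):
    """Sketch of gradient checkpointing: drop f's intermediate activations
    after the forward pass and recompute them when backprop needs them."""
    @tf.custom_gradient
    def wrapper(x):
        y = f(x)  # forward pass; intermediates inside f can be freed

        def grad(dy, variables=None):
            # Re-run the forward pass under a tape so the intermediates
            # exist again, then differentiate through the recomputation.
            with tf.GradientTape() as tape:
                tape.watch(x)
                y_recomputed = f(x)
            sources = [x] + list(variables or [])
            grads = tape.gradient(y_recomputed, sources, output_gradients=dy)
            return grads[0], grads[1:]  # (input gradient, variable gradients)

        return y, grad
    return wrapper
```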
**Standalone code to reproduce the issue**
```python
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app, flags
from absl.flags import FLAGS
from tensorflow import keras

flags.DEFINE_list("gpu", default=None, help="index of GPU")
flags.DEFINE_bool("recompute_grad", default=False,
                  help="whether to recompute gradients to save GPU RAM")
flags.DEFINE_integer("batch_size", default=1024, help="batch size")
flags.DEFINE_bool("graph", default=False,
                  help="use graph mode instead of eager mode")


def dense_lenet(inputs):
    """A small DenseNet-style network whose composite blocks can be wrapped
    with tf.recompute_grad via the --recompute_grad flag."""

    def make_block(bottleneck_filters, growth_filters):
        # BN-ReLU-Conv1x1-BN-ReLU-Conv3x3 composite block.
        def _block(x):
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(bottleneck_filters, 1, use_bias=False,
                                    padding="SAME")(x)
            x = keras.layers.BatchNormalization()(x)
            x = keras.layers.ReLU()(x)
            x = keras.layers.Conv2D(growth_filters, 3, use_bias=False,
                                    padding="SAME")(x)
            return x
        if FLAGS.recompute_grad:
            _block = tf.recompute_grad(_block)
        return _block

    net = keras.layers.Conv2D(32, 5, strides=2, use_bias=False,
                              padding="SAME")(inputs)
    # Dense block 1 + transition.
    for _ in range(5):
        net = keras.layers.concatenate([net, make_block(16, 4)(net)])
    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.Conv2D(64, 1, use_bias=False, padding="SAME")(net)
    net = keras.layers.AveragePooling2D()(net)
    # Dense block 2 + transition.
    for _ in range(10):
        net = keras.layers.concatenate([net, make_block(32, 8)(net)])
    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.Conv2D(128, 1, use_bias=False, padding="SAME")(net)
    net = keras.layers.AveragePooling2D()(net)
    # Dense block 3 + classifier head.
    for _ in range(10):
        net = keras.layers.concatenate([net, make_block(32, 8)(net)])
    net = keras.layers.BatchNormalization()(net)
    net = keras.layers.ReLU()(net)
    net = keras.layers.GlobalAveragePooling2D()(net)
    net = keras.layers.Dense(10)(net)
    net = keras.layers.Softmax()(net)
    return net


def main(_):
    if FLAGS.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, FLAGS.gpu))
    if FLAGS.graph:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.keras.backend.set_session(
            session=tf.compat.v1.Session(
                config=tf.compat.v1.ConfigProto(
                    gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))))
    else:
        for gpu in tf.config.experimental.list_physical_devices("GPU"):
            tf.config.experimental.set_memory_growth(gpu, True)
    tfds.core.constants.DATA_DIR = "data"  # store datasets under ./data
    dataset_builder = tfds.image.FashionMNIST(version="3.*.*")
    dataset_builder.download_and_prepare()
    dataset = dataset_builder.as_dataset(
        split="train",
        shuffle_files=True,
        as_supervised=True,
    ).repeat().batch(FLAGS.batch_size)
    inputs = keras.layers.Input((28, 28, 1), batch_size=FLAGS.batch_size)
    model = keras.Model(inputs, dense_lenet(inputs))
    model.compile(
        optimizer="adam",
        # The model ends in a Softmax layer, so the loss receives
        # probabilities rather than logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )
    model.summary()
    model.fit(
        x=dataset,
        epochs=3,
        steps_per_epoch=60000 // FLAGS.batch_size,
    )


if __name__ == "__main__":
    app.run(main)
```
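As a quick eager-mode sanity check that gradients flow through the wrapper at all, I use something like the following (a minimal standalone snippet with made-up shapes, independent of the script above):

```python
import tensorflow as tf

# Build the layer once so the recomputation reuses the same weights.
dense = tf.keras.layers.Dense(4)
block = tf.recompute_grad(lambda x: dense(x))

x = tf.random.normal([8, 4])
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(block(x))

# Should print [False, False] in eager mode; the graph-mode equivalent is
# where the "has None for gradient" error above comes from.
print([g is None for g in tape.gradient(loss, dense.trainable_variables)])
```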
@BinyanHu I've got a minimal example showing that it doesn't work for memory reduction over here: https://github.com/tensorflow/tensorflow/issues/30418#issuecomment-589820336
@BinyanHu if you're interested, I've written a simple gradient checkpointing decorator here
@davisyoshida Thanks for doing this! I've been looking at implementing the same thing; will test out yours.
@mathemakitten Happy to help! Do let me know if you run into any issues.
@davisyoshida Thank you for sharing. Will test your implementation as soon as possible!
@davisyoshida Does this work with Keras? If so, can you provide a small example of how to use it with Keras?
@paulter I have a version working with Keras, but for Sequential models only. I have created a pull request in the TF Addons repo: https://github.com/tensorflow/addons/pull/1600. You can find an example notebook here: https://github.com/pidajay/addons/blob/grad_checkpointing_eager/docs/tutorials/training_gradient_checkpointing.ipynb
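The core idea is just to split the model into segments and wrap each segment's forward pass, roughly like this (a simplified sketch, not the exact PR code; `checkpoint_sequential` is a made-up name):

```python
import tensorflow as tf

def checkpoint_sequential(layers, num_segments):
    """Group `layers` into `num_segments` chunks whose internal activations
    are recomputed during backprop instead of being stored."""
    size = max(1, len(layers) // num_segments)
    segments = [layers[i:i + size] for i in range(0, len(layers), size)]

    def make_segment_fn(segment):
        def forward(x):
            for layer in segment:
                x = layer(x)
            return x
        return tf.recompute_grad(forward)

    segment_fns = [make_segment_fn(s) for s in segments]

    def call(x):
        # Only segment boundaries are kept; everything inside a segment
        # is recomputed during the backward pass.
        for fn in segment_fns:
            x = fn(x)
        return x

    return call

# usage (hypothetical): forward = checkpoint_sequential(model.layers, 4)
```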
Thank you @pidajay
@pidajay thanks for the work on this.
You say this only works with Sequential models? Unfortunately, my Keras model is too complex to be a Sequential model, so I can't use your code. Is there a way I can use what you've written?
I don't mind manually checkpointing; in fact, it is probably preferable. I'm writing custom Keras layers for research, and it would be nice to have something that specifies that the gradient of a particular layer should be recomputed.
Unfortunately, I don't really understand the documentation for recompute_grad here https://www.tensorflow.org/api_docs/python/tf/recompute_grad
The documentation there seems to imply that you go:

```python
my_layer = tf.recompute_grad(keras.layers.Conv2D(...))
```

but this gives no memory improvements.
Any chance that I can use what you've written, even if it's in a manual way?
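For what it's worth, my current reading is that the savings should come from wrapping a multi-layer block rather than a single layer, since the wrapped function's input is kept alive either way and only its internal activations can be dropped. Something like this untested sketch:

```python
import tensorflow as tf
from tensorflow import keras

# Build the layers once so a recomputation would reuse the same weights.
conv1 = keras.layers.Conv2D(64, 3, padding="same", activation="relu")
conv2 = keras.layers.Conv2D(64, 3, padding="same", activation="relu")
conv3 = keras.layers.Conv2D(64, 3, padding="same")

def my_block(x):
    # The activations of conv1/conv2 are what checkpointing could free.
    return conv3(conv2(conv1(x)))

checkpointed_block = tf.recompute_grad(my_block)
y = checkpointed_block(tf.random.normal([1, 32, 32, 64]))
```

But I'm not sure this is right, hence the question.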
@paulter I have posted a small tutorial here: https://github.com/pidajay/tf2_gradient_checkpointing/blob/master/tf_recompute_grad_tutorial.ipynb. For this to work, you need to replace (or just copy the delta into) the custom_gradient.py file with this version in my TF fork: https://github.com/pidajay/tensorflow/blob/fix_gradient_checkpointing/tensorflow/python/ops/custom_gradient.py. I plan to submit this fix as a PR soon, but I am not sure whether the TF folks would be interested. Unfortunately, my example only demonstrates how to do this for a Keras Sequential model in eager mode, but splitting a functional or custom model and invoking recompute_grad should work the same way. I still need to check whether the graph-mode decorator has the same bug as the eager-mode decorator (the conversation at the top of this thread says it has been fixed). I will dig into it this week and let you know. Hope this helps.
Any news on the graph-mode models? I tried the code from @pidajay, but as soon as I passed any keyword arguments such as `variables` to the recomputed function, TF raised the error "The custom_gradient decorator currently supports keywords arguments only when eager execution is enabled".
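For reference, the limitation seems to come from `tf.custom_gradient` itself: any keyword argument triggers it outside eager execution. A minimal repro of the error (my own hypothetical snippet, not from @pidajay's code):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

@tf.custom_gradient
def f(x, scale=1.0):
    def grad(dy):
        return dy * scale
    return x * scale, grad

x = tf.compat.v1.placeholder(tf.float32, [None])
# Passing any kwarg in graph mode raises:
# "The custom_gradient decorator currently supports keywords arguments
#  only when eager execution is enabled."
y = f(x, scale=2.0)
```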
If you're looking to do gradient checkpointing in graph mode, I suggest the tf-slim implementation here, which I've extracted and successfully tested on tf-nightly in graph mode on TPU: https://github.com/google-research/tf-slim/blob/a62dc893de5e46e6f2e9ec24a74b2abce026307a/tf_slim/layers/rev_block_lib.py
Thanks for your advice. I tried the code extracted from tf-slim. It did work to some degree, but in my case it reduced memory usage by only 5%. In the end, I just copied the Graph Editor from TensorFlow v1.15's contrib library. With OpenAI's gradient checkpointing, I got a 40% memory reduction at the cost of 48% longer training time.
This still doesn't seem to work with a custom Keras model.
@BinyanHu Did you find any workaround for gradient-checkpointing that indeed works?
Hi, thank you for opening this issue. Since this issue has been open for a long time, the code/debug information for this issue may no longer be relevant to the current state of the code base.
The TensorFlow team is constantly improving the framework by fixing bugs and adding new features. We suggest you try the latest TensorFlow version with the latest compatible hardware configuration, which could potentially resolve the issue. If you are still facing the issue, please create a new GitHub issue with your latest findings, with all the debugging information which could help us investigate.
Please follow the release notes to stay up to date with the latest developments happening in the TensorFlow space.
This issue is stale because it has been open for 7 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 7 days since being marked as stale. Please reopen if you'd like to work on this further.