learning-to-drive-in-5-minutes
OOM error when training VAE
Describe the bug
I'm running `python -m vae.train --n-epochs 50 --verbose 0 --z-size 64 -f logs/images_generated_road_single_colour/` and getting the error `Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.17GiB (rounded to 1251090432)` after a number of training iterations.
I've added checkpoint saving using the `save_checkpoint()` function in `vae/model.py`. The training run was crashing after a number of iterations when it tried to create the `.meta` file for the checkpoint, but I worked around that by passing `write_meta_graph=False` to the `saver.save()` call.
Since I was still getting an OOM error, I added `self.sess.graph.finalize()` to the `_init_session()` function in `vae/model.py` to make the graph read-only and catch anything still adding nodes to it. An exception was then raised from the line `vae_controller.set_target_params()` in `vae/train.py`, which in turn calls `assign_ops.append(param.assign(loaded_p))` from within `set_params()` in `vae/model.py`.
I was reading this article (https://riptutorial.com/tensorflow/example/13426/use-graph-finalize---to-catch-nodes-being-added-to-the-graph) and the memory leak I'm seeing sounds most like its third example: "subtle (e.g. a call to an overloaded operator on a `tf.Tensor` and a NumPy array, which implicitly calls `tf.convert_to_tensor()` and adds a new `tf.constant()` to the graph)."
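To make that failure mode concrete, here is a minimal plain-Python analogue (not real TensorFlow; `FakeGraph` and `FakeTensor` are illustrative stand-ins) of an overloaded operator that silently registers a new graph node on every call, the way an implicit `tf.convert_to_tensor()` creates a fresh `tf.constant()`:

```python
class FakeGraph:
    """Stand-in for tf.Graph: just records the nodes added to it."""
    def __init__(self):
        self.nodes = []

    def add_node(self, description):
        self.nodes.append(description)


GRAPH = FakeGraph()


class FakeTensor:
    """Stand-in for tf.Tensor whose overloaded `+` implicitly adds a
    node to the graph, mimicking tf.convert_to_tensor() creating a
    new tf.constant() each time."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # The implicit conversion: every call quietly grows the graph.
        GRAPH.add_node("constant({})".format(other))
        return FakeTensor(self.value + other)


t = FakeTensor(1.0)
for _ in range(3):      # e.g. once per training iteration
    t = t + 2.0         # looks harmless, but leaks one node per call

print(len(GRAPH.nodes))  # 3 nodes leaked after 3 iterations
```

In a long training run this is exactly the pattern that turns into an OOM: the leaked nodes accumulate across thousands of iterations, and `graph.finalize()` is what turns the silent growth into a loud `RuntimeError`.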
Did you run into any OOM errors from graph growth when you were running these scripts, or do you have any insights? Cheers, Antonin
Code example This is the training-loop section from `vae/train.py` (validation has been added). The last line is the problem line:
```python
for epoch in range(args.n_epochs):
    print("Training ...")
    pbar = tqdm(total=len(train_minibatchlist))
    for obs in train_data_loader:
        feed = {vae.input_tensor: obs}
        (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
            vae.loss,
            vae.r_loss,
            vae.kl_loss,
            vae.global_step,
            vae.train_op
        ], feed)
        pbar.update(1)
    pbar.close()

    print("Evaluating ...")
    pbar = tqdm(total=len(val_minibatchlist))
    for obs in val_data_loader:
        feed = {vae.input_tensor: obs}
        (val_loss, val_r_loss, val_kl_loss) = vae.sess.run([
            vae.loss,
            vae.r_loss,
            vae.kl_loss
        ], feed)
        pbar.update(1)
    pbar.close()

    print("Epoch {:3}/{}".format(epoch + 1, args.n_epochs))
    print("Optimization Step: ", (train_step + 1), ", Loss: ", train_loss,
          " Validation Loss: ", val_loss)
    # Update params
    vae_controller.set_target_params()
```
This is the edited `_init_session()` from `vae/model.py`:
```python
def _init_session(self):
    """Launch tensorflow session and initialize variables"""
    self.sess = tf.Session(graph=self.graph)
    self.sess.run(self.init)
    self.sess.graph.finalize()
```
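The value of that `finalize()` call is that it converts silent graph growth into an immediate error. A minimal plain-Python analogue of its semantics (the `FakeGraph` class here is an illustrative stand-in, not the real `tf.Graph`):

```python
class FakeGraph:
    """Stand-in for tf.Graph with finalize() semantics: once frozen,
    any attempt to add a node raises, just like
    tf.Graph._check_not_finalized() does in the traceback below."""
    def __init__(self):
        self.nodes = []
        self._finalized = False

    def finalize(self):
        self._finalized = True

    def add_node(self, description):
        if self._finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self.nodes.append(description)


g = FakeGraph()
g.add_node("assign")  # fine before finalize()
g.finalize()
try:
    g.add_node("assign")  # raises, exposing the hidden graph growth
except RuntimeError as exc:
    print(exc)  # Graph is finalized and cannot be modified.
```

This is why the error surfaces at `param.assign(loaded_p)`: the assign op is a new node, and after `finalize()` the graph refuses it.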
And this is the suspected source of the memory leak, within `vae/model.py`:
```python
def set_params(self, params):
    assign_ops = []
    for param, loaded_p in zip(self.params, params):
        assign_ops.append(param.assign(loaded_p))
    self.sess.run(assign_ops)
```
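Because `param.assign(loaded_p)` creates a new graph node on every call, calling `set_params()` once per epoch grows the graph indefinitely. The usual TF1-era remedy is to build the assign ops (fed by placeholders) once, before `finalize()`, and reuse them afterwards. A minimal plain-Python sketch of that caching pattern (all names here are illustrative stand-ins, not real TensorFlow API):

```python
class ParamSetter:
    """Build the 'assign' ops exactly once, then reuse them on every
    call -- the standard fix for per-call graph growth in TF1-style
    code. Ops are represented as strings for illustration only."""
    def __init__(self, param_names):
        self.param_names = param_names
        self._assign_ops = None  # built lazily, exactly once

    def _build_assign_ops(self):
        # In real TF1 this would create tf.placeholder +
        # param.assign(placeholder) pairs before graph.finalize().
        return ["assign_op_for_" + name for name in self.param_names]

    def set_params(self, values):
        if self._assign_ops is None:
            self._assign_ops = self._build_assign_ops()
        # In real TF1: sess.run(self._assign_ops, feed_dict={ph: v, ...})
        return dict(zip(self.param_names, values))


setter = ParamSetter(["w", "b"])
setter.set_params([1.0, 0.5])
ops_first = setter._assign_ops
setter.set_params([2.0, 0.7])
assert setter._assign_ops is ops_first  # no new ops on later calls
```

Applied to the model, that would mean creating the placeholders and `param.assign(placeholder)` ops in the constructor (before `self.sess.graph.finalize()`), so `set_params()` only ever calls `sess.run` with a `feed_dict` and never touches the graph.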
Let me know if you want the full scripts.
System Info Describe the characteristics of your environment:
- Commit of your version of the repo -
- GPU models and configuration - this is printed to the terminal when I start the training run:
  `Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 15210 MB memory) -> physical GPU (device: 0, name: NVIDIA Tesla P100-PCIE-16GB, pci bus id: 0001:00:00.0, compute capability: 6.0)`
- Python version - 3.6.9
- TensorFlow version - 1.14.0 (using tensorflow-gpu)
- Versions of any other relevant libraries
Additional context This is the full error message:
```
File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/train.py", line 157, in <module>
  vae_controller.set_target_params()
File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/controller.py", line 105, in set_target_params
  self.target_vae.set_params(params)
File "/home/b3024896/dolphinsstorage/dolphin-tracking/DonkeyTrack/vae/model.py", line 208, in set_params
  assign_ops.append(param.assign(loaded_p))
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1952, in assign
  name=name)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 227, in assign
  validate_shape=validate_shape)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py", line 66, in assign
  use_locking=use_locking, name=name)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 527, in _apply_op_helper
  preferred_dtype=default_dtype)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1224, in internal_convert_to_tensor
  ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 305, in _constant_tensor_conversion_function
  return constant(v, dtype=dtype, name=name)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 246, in constant
  allow_broadcast=True)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py", line 290, in _constant_impl
  name=name).outputs[0]
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
  return func(*args, **kwargs)
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3588, in create_op
  self._check_not_finalized()
File "/home/b3024896/.local/share/virtualenvs/b3024896-7tinQHxi/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3225, in _check_not_finalized
  raise RuntimeError("Graph is finalized and cannot be modified.")
RuntimeError: Graph is finalized and cannot be modified.
```
Hello,
> Did you run into any OOM errors from graph growth when you were running these scripts or do you have any insights?
I've never experienced any OOM (I was using Google Colab to train the VAE) and I haven't used that code for a while now (I've switched to PyTorch).
The only thing I can offer is the code I'm now using to train the AE (just made it public; I should open source the rest too): https://github.com/araffin/aae-train-donkeycar
I've been intending to switch over to PyTorch for months. OK, thank you, I will check it out!