multinerf
OOM with only ~12MB memory allocated / requested on GPU
Hi, I tried to run Ref-NeRF, but no matter how small the batch size is, an OOM error occurs. I am not sure whether this is a bug in jax, multinerf, or tf. I've tried jax v0.3.24/25 and got the same problem.
I used:
python 3.9
ubuntu 20.04.5
RTX3090
CUDA 11.6 CUDNN 8.6
NV driver 510
jax v0.3.24/25
flax v0.6.1/2
2022-11-29 21:04:38.780955: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.00MiB (rounded to 6291456)requested by op
2022-11-29 21:04:39.275116: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:492] ****************************************************************************************************
2022-11-29 21:04:39.275246: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2153] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6291456 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 6.00MiB
constant allocation: 0B
maybe_live_out allocation: 6.00MiB
preallocated temp allocation: 0B
total allocation: 12.00MiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 6.00MiB
Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=0]" source_file="/media/gccrcv/Data/Opensources/multinerf/internal/models.py" source_line=689
XLA Label: concatenate
Shape: f32[4096,128,3]
==========================
Buffer 2:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 3:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 4:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 5:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 6:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 7:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 8:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 9:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 10:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 11:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 12:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 13:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 14:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Buffer 15:
Size: 384.0KiB
Entry Parameter Subshape: f32[256,128,3]
==========================
Traceback (most recent call last):
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/media/gccrcv/Data/Opensources/multinerf/train.py", line 288, in <module>
app.run(main)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/media/gccrcv/Data/Opensources/multinerf/train.py", line 229, in main
rendering = models.render_image(
File "/media/gccrcv/Data/Opensources/multinerf/internal/models.py", line 689, in render_image
jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/tree_util.py", line 207, in tree_map
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File "/media/gccrcv/Data/Opensources/multinerf/internal/models.py", line 689, in <lambda>
jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1791, in concatenate
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 1791, in <listcomp>
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 648, in concatenate
return concatenate_p.bind(*operands, dimension=dimension)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 329, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 332, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/core.py", line 712, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 115, in apply_primitive
return compiled_fun(*args)
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 200, in <lambda>
return lambda *args, **kw: compiled(*args, **kw)[0]
File "/home/gccrcv/anaconda3/envs/multinerf/lib/python3.9/site-packages/jax/_src/dispatch.py", line 895, in _execute_compiled
out_flat = compiled.execute(in_flat)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6291456 bytes.
Even when I limit the GPU memory allocation as suggested in some other issues, the script still uses ~90% of the GPU memory and crashes when rendering an image. I can't understand why jax leaves this problem unaddressed.
export XLA_PYTHON_CLIENT_MEM_FRACTION="0.5"
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_FLAGS="--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_force_compilation_parallelism=1"
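These variables only take effect if they are set before jax initializes its GPU backend, so when setting them from Python they have to come before the first jax import. A minimal sketch (the 0.5 fraction is arbitrary):

import os

# Must run before jax (or anything that imports jax) touches the GPU.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"   # cap JAX at ~50% of GPU memory

import jax  # imported after the environment variables on purpose

print(jax.devices())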
What does nvidia-smi look like?
Hi @jonbarron, this is what I am trying now:

if __name__ == '__main__':
    import os
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="0.5"
    # print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.5")
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    import tensorflow as tf
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Restrict TensorFlow to only allocate 2GB of memory on the first GPU
        try:
            tf.config.set_logical_device_configuration(
                gpus[0],
                [tf.config.LogicalDeviceConfiguration(memory_limit=2048)])
            logical_gpus = tf.config.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Virtual devices must be set before GPUs have been initialized
            print(e)
    with gin.config_scope('train'):
        app.run(main)
2/250000: loss=0.03680, psnr=14.793, lr=3.21e-05 | data=0.03646, orie=1.2e-05, pred=0.00033, 167 r/s
100/250000: loss=0.08784, psnr=11.866, lr=6.17e-04 | data=0.07865, orie=0.00886, pred=0.00033, 2898 r/s
Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | N/A |
| 54% 64C P2 220W / 350W | 11815MiB / 24576MiB | 54% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1181 G /usr/lib/xorg/Xorg 53MiB |
| 0 N/A N/A 1757 G /usr/lib/xorg/Xorg 307MiB |
| 0 N/A N/A 1893 G /usr/bin/gnome-shell 50MiB |
| 0 N/A N/A 5143 G ...nlogin/bin/sunloginclient 12MiB |
| 0 N/A N/A 32718 G /usr/lib/firefox/firefox 117MiB |
| 0 N/A N/A 1223976 G ...RendererForSitePerProcess 171MiB |
| 0 N/A N/A 1277408 C python 11081MiB |
+-----------------------------------------------------------------------------+
OOM still arose.
Rendering chunk 0/129599
Rendering chunk 12960/129599
Rendering chunk 25920/129599
Rendering chunk 38880/129599
Rendering chunk 51840/129599
Rendering chunk 64800/129599
Rendering chunk 77760/129599
Rendering chunk 90720/129599
Rendering chunk 103680/129599
Rendering chunk 116640/129599
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02 Driver Version: 510.85.02 CUDA Version: 11.6 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | N/A |
| 52% 57C P2 120W / 350W | 23919MiB / 24576MiB | 7% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1181 G /usr/lib/xorg/Xorg 53MiB |
| 0 N/A N/A 1757 G /usr/lib/xorg/Xorg 307MiB |
| 0 N/A N/A 1893 G /usr/bin/gnome-shell 50MiB |
| 0 N/A N/A 5143 G ...nlogin/bin/sunloginclient 12MiB |
| 0 N/A N/A 32718 G /usr/lib/firefox/firefox 117MiB |
| 0 N/A N/A 1223976 G ...RendererForSitePerProcess 161MiB |
| 0 N/A N/A 1277408 C python 23195MiB |
+-----------------------------------------------------------------------------+
I've tried different configurations of memory limiting. It seems the OOM has nothing to do with the memory limit, because the script eventually crashes regardless.
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"
print("XLA_PYTHON_CLIENT_MEM_FRACTION=0.7")
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#     # Restrict TensorFlow to only allocate 4GB of memory on the first GPU
#     try:
#         tf.config.set_logical_device_configuration(
#             gpus[0],
#             [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
#         logical_gpus = tf.config.list_logical_devices('GPU')
#         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#     except RuntimeError as e:
#         # Virtual devices must be set before GPUs have been initialized
#         print(e)
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
with gin.config_scope('train'):
    app.run(main)
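For what it's worth, since the training loop itself runs in JAX, another option (a sketch, assuming TensorFlow is only needed here for the input pipeline) is to hide the GPU from TensorFlow entirely, so it cannot compete with JAX for memory:

import tensorflow as tf

# Make TensorFlow CPU-only; all GPU memory is left for JAX.
# This must run before TensorFlow initializes the GPU.
tf.config.set_visible_devices([], 'GPU')
print(tf.config.get_visible_devices('GPU'))  # expected: []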
I bypassed the OOM by simply transferring the chunks in render_image to CPU.

Hi @GCChen97, I have encountered the same problem as you. Could you share your code for this step? Thanks!
Hi @Riga27527, here is my solution:

def move_chunks_to_cpu(chunks):
    chunks_new = []
    device_cpu = jax.devices("cpu")[0]
    for chunk in chunks:
        chunk_new = {}
        chunk_new["acc"] = jax.device_put(chunk["acc"], device_cpu)
        chunk_new["distance_mean"] = \
            jax.device_put(chunk["distance_mean"], device_cpu)
        chunk_new["distance_median"] = \
            jax.device_put(chunk["distance_median"], device_cpu)
        chunk_new["distance_percentile_5"] = \
            jax.device_put(chunk["distance_percentile_5"], device_cpu)
        chunk_new["distance_percentile_95"] = \
            jax.device_put(chunk["distance_percentile_95"], device_cpu)
        chunk_new["normals"] = \
            jax.device_put(chunk["normals"], device_cpu)
        chunk_new["normals_pred"] = \
            jax.device_put(chunk["normals_pred"], device_cpu)
        chunk_new["ray_rgbs"] = [jax.device_put(data, device_cpu)
                                 for data in chunk["ray_rgbs"]]
        chunk_new["ray_sdist"] = [jax.device_put(data, device_cpu)
                                  for data in chunk["ray_sdist"]]
        chunk_new["ray_weights"] = [jax.device_put(data, device_cpu)
                                    for data in chunk["ray_weights"]]
        chunk_new["rgb"] = jax.device_put(chunk["rgb"], device_cpu)
        chunk_new["roughness"] = jax.device_put(chunk["roughness"], device_cpu)
        chunks_new.append(chunk_new)
    return chunks_new


def render_image(...
    ...
    chunks = move_chunks_to_cpu(chunks)
    # Concatenate all chunks within each leaf of a single pytree.
    rendering = (
        jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks))
    ...
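A more compact variant of the same idea (a sketch only, not taken from the repo) is to let jax.tree_util.tree_map move every leaf of each chunk pytree to the CPU instead of listing the keys by hand:

import jax

def move_chunks_to_cpu(chunks):
    # Move every array leaf of every chunk pytree to host memory,
    # so the final concatenation no longer has to fit on the GPU.
    device_cpu = jax.devices("cpu")[0]
    return [jax.tree_util.tree_map(lambda x: jax.device_put(x, device_cpu), chunk)
            for chunk in chunks]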
Which file(s) did you modify? train.py? It looks like render_image is defined in internal/models.py, so I tried to add your code there, but I ended up getting a 'tree' parameter related error:
"line 686, in render_image jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks)) TypeError: tree_map() missing 1 required positional argument: 'tree'"
I have been stuck trying to get this to work for several days. I originally tried Windows, then WSL Ubuntu, and now finally dual-boot Ubuntu, with several different versions of CUDA and cuDNN in each system. Error after error... I am finally able to start training, but now I get the OOM. I have an RTX 4090 and reduced the batch size to 4096, but at 5000 iterations it starts rendering chunks and at the end of the chunks I get the same error.
Hi, @Riga27527 , here is my solution:
def move_chunks_to_cpu(chunks): chunks_new = [] device_cpu = jax.devices("cpu")[0] for chunk in chunks: chunk_new = {} chunk_new["acc"] = jax.device_put(chunk["acc"], device_cpu) chunk_new["distance_mean"] = \ jax.device_put(chunk["distance_mean"], device_cpu) chunk_new["distance_median"] = \ jax.device_put(chunk["distance_median"], device_cpu) chunk_new["distance_percentile_5"] = \ jax.device_put(chunk["distance_percentile_5"], device_cpu) chunk_new["distance_percentile_95"] = \ jax.device_put(chunk["distance_percentile_95"], device_cpu) chunk_new["normals"] = \ jax.device_put(chunk["normals"], device_cpu) chunk_new["normals_pred"] = \ jax.device_put(chunk["normals_pred"], device_cpu) chunk_new["ray_rgbs"] = [ jax.device_put(data, device_cpu) for data in chunk["ray_rgbs"] ] chunk_new["ray_sdist"] = [ jax.device_put(data, device_cpu) for data in chunk["ray_sdist"] ] chunk_new["ray_weights"] = [ jax.device_put(data, device_cpu) for data in chunk["ray_weights"] ] chunk_new["rgb"] = jax.device_put(chunk["rgb"], device_cpu) chunk_new["roughness"] = jax.device_put(chunk["roughness"], device_cpu) chunks_new.append(chunk_new) return chunks_new def render_image(... ... chunks = move_chunks_to_cpu(chunks) # Concatenate all chunks within each leaf of a single pytree. rendering = ( jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks)) ...
Which file(s) did you modify? Train.py? Looks like render_image is referenced from internal/model.py so I tried to add your code there but I ended up getting a 'tree' parameter related error.
"line 686, in render_image jax.tree_util.tree_map(lambda *args: jnp.concatenate(args), *chunks)) TypeError: tree_map() missing 1 required positional argument: 'tree'"
I have been stuck trying to get this to work for several days. Originally I tried windows.. then wsl ubuntu.. and now finally dual boot ubuntu with several different versions of cuda, cudnn in each system. Error after error... I finally am able to start training and getting now getting OOM. I have 4090rtx, I reduced batch size to 4096, but at 5000 iter it starts to chunk and at end of chunk same error.
I didn't use the above code, I just reduced the 'render_chunk_size' in internal/config.py to 2048.
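If you would rather not edit the source, the same value can usually be overridden from the command line with a gin binding (a sketch; it assumes render_chunk_size is a field of the gin-configurable Config class, and <your_config>.gin stands for whichever config you already train with):

python -m train \
  --gin_configs=configs/<your_config>.gin \
  --gin_bindings="Config.render_chunk_size = 2048" \
  --logtostderr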
Thank you. I lowered the render_chunk_size and tested... now I am getting this error:
Profiling failure on cuDNN engine eng28{}: UNKNOWN: CUDNN_STATUS_ALLOC_FAILED in external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc(4640): 'status'
Another thread mentions that metric_harness is the problem and suggests commenting it out, but doesn't say where. I am testing commenting out the following in train.py, lines 242-248:
# metric = metric_harness(
#     postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
# print(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s')
# for name, val in metric.items():
#     if not np.isnan(val):
#         print(f'{name} = {val:.4f}')
#         summary_writer.scalar('train_metrics/' + name, val, step)
Also line 80 in train.py: # metric_harness = image.MetricHarness()
So far this is the first time I have trained past 5000 iterations. I will provide an update later.
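As an alternative to deleting the metric code outright, a lighter-touch option (a sketch based only on the lines quoted above, not a tested patch) is to wrap the metric computation so that an allocation failure skips the metrics for that step instead of killing the run:

try:
    metric = metric_harness(
        postprocess_fn(rendering['rgb']), postprocess_fn(test_case.rgb))
    for name, val in metric.items():
        if not np.isnan(val):
            print(f'{name} = {val:.4f}')
            summary_writer.scalar('train_metrics/' + name, val, step)
except Exception as e:  # e.g. RESOURCE_EXHAUSTED raised by the metric computation
    print(f'Skipping train metrics this step: {e}')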
Yes, the code goes in internal/models.py. The key is that the chunks need to be moved to the CPU, as the modified code does.