DLIO hangs during the first epoch when running with certain accelerator counts
I'm seeing DLIO hang during the first epoch when running with certain accelerator counts. Running with 6x a100 or 7x a100 causes the test to hang after printing the summary of epoch 1, and Ctrl+C is needed to stop the test. The strange thing is that running with 5x a100 or 8x a100 passes. I see the same behavior with other clients and when running multiple clients. The output below was collected with OpenMPI 4.1.2 and Python 3.10.12, but I've also tried OpenMPI 5.0.8 and get the same behavior. Is there any additional logging that can be enabled to help narrow down and fix the issue?
5x a100
[OUTPUT] 2025-06-21T15:32:56.403469 Saved outputs in /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_151706 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:204]
[OUTPUT] Averaged metric over all steps/epochs
[METRIC] ==========================================================
[METRIC] Number of Simulated Accelerators: 5
[METRIC] Training Accelerator Utilization [AU] (%): 97.7676 (0.1091)
[METRIC] Training Throughput (samples/second): 53.5577 (0.0599)
[METRIC] Training I/O Throughput (MB/second): 7487.8664 (8.3689)
[METRIC] train_au_meet_expectation: success
[METRIC] ==========================================================
[/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:226]
[OUTPUT] 2025-06-21T15:32:56.410701 outputs saved in RANKID_output.json [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:456]
2025-06-21 15:33:27|STATUS: Writing metadata for benchmark to: /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_151706/training_20250621_151706_metadata.json
Setting attr from num_accelerators to 5
Hosts is: ['192.168.20.21']
Hosts is: ['192.168.20.21']
6x a100
[INFO] 2025-06-21T14:39:05.657766 Maximum number of steps reached [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/main.py:350]
[DEBUG] 2025-06-21T14:39:05.667202 Rank 5 returned after 238 steps. [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/main.py:407]
[DEBUG] my_rank: 5, start_sample: 8335, end_sample: 9999 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/config.py:441]
[OUTPUT] 2025-06-21T14:39:05.658318 Ending block 1 - 238 steps completed in 158.11 s [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:338]
[OUTPUT] 2025-06-21T14:39:05.668577 Epoch 1 - Block 1 [Training] Accelerator Utilization [AU] (%): 98.2719 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:340]
[OUTPUT] 2025-06-21T14:39:05.668684 Epoch 1 - Block 1 [Training] Throughput (samples/second): 64.5567 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:341]
[OUTPUT] 2025-06-21T14:39:05.668760 Epoch 1 - Block 1 [Training] Computation time per step (second): 0.6366+/-0.0001 (set value: {'mean': 0.636}) [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:342]
2025-06-21 14:52:48|WARNING: Received signal SIGINT (2)
2025-06-21 14:52:48|INFO: Exiting immediately due to SIGTERM
2025-06-21 14:52:48|ERROR: Error occurred while executing command: INTERRUPTED
2025-06-21 14:52:48|STATUS: Writing metadata for benchmark to: /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_143620/training_20250621_143620_metadata.json
2025-06-21 14:52:48|ERROR: Benchmark failed after 1 iterations
Setting attr from num_accelerators to 6
Hosts is: ['192.168.20.21']
Hosts is: ['192.168.20.21']
7x a100
[INFO] 2025-06-21T15:04:00.500743 Maximum number of steps reached [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/main.py:350]
[DEBUG] 2025-06-21T15:04:00.510045 Rank 6 returned after 204 steps. [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/main.py:407]
[DEBUG] my_rank: 6, start_sample: 8574, end_sample: 9999 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/config.py:441]
[OUTPUT] 2025-06-21T15:04:00.501313 Ending block 1 - 204 steps completed in 136.95 s [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:338]
[OUTPUT] 2025-06-21T15:04:00.511633 Epoch 1 - Block 1 [Training] Accelerator Utilization [AU] (%): 97.9379 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:340]
[OUTPUT] 2025-06-21T15:04:00.511746 Epoch 1 - Block 1 [Training] Throughput (samples/second): 75.0075 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:341]
[OUTPUT] 2025-06-21T15:04:00.511821 Epoch 1 - Block 1 [Training] Computation time per step (second): 0.6366+/-0.0001 (set value: {'mean': 0.636}) [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:342]
2025-06-21 15:16:42|WARNING: Received signal SIGINT (2)
2025-06-21 15:16:42|INFO: Exiting immediately due to SIGTERM
2025-06-21 15:16:42|ERROR: Error occurred while executing command: INTERRUPTED
2025-06-21 15:16:42|STATUS: Writing metadata for benchmark to: /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_150135/training_20250621_150135_metadata.json
2025-06-21 15:16:42|ERROR: Benchmark failed after 1 iterations
Setting attr from num_accelerators to 7
Hosts is: ['192.168.20.21']
Hosts is: ['192.168.20.21']
8x a100
[OUTPUT] 2025-06-21T15:53:27.458607 Saved outputs in /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_154318 [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:204]
[OUTPUT] Averaged metric over all steps/epochs
[METRIC] ==========================================================
[METRIC] Number of Simulated Accelerators: 8
[METRIC] Training Accelerator Utilization [AU] (%): 97.6416 (0.1626)
[METRIC] Training Throughput (samples/second): 85.3993 (0.1422)
[METRIC] Training I/O Throughput (MB/second): 11939.6082 (19.8859)
[METRIC] train_au_meet_expectation: success
[METRIC] ==========================================================
[/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:226]
[OUTPUT] 2025-06-21T15:53:27.469267 outputs saved in RANKID_output.json [/root/.venvs/myenv/lib/python3.10/site-packages/dlio_benchmark/utils/statscounter.py:456]
2025-06-21 15:53:46|STATUS: Writing metadata for benchmark to: /local/tmp/mlperf_storage_results/training/unet3d/run/20250621_154318/training_20250621_154318_metadata.json
Setting attr from num_accelerators to 8
Hosts is: ['192.168.20.21']
Hosts is: ['192.168.20.21']
cmd line I'm using is:
mlpstorage training run --model unet3d --data-dir /mnt/mlperf_train/ --oversubscribe --allow-run-as-root --param dataset.num_files_train=10000 --param reader.odirect=True --param reader.read_threads=12 --client-host-memory-in-gb 256 --exec-type mpi --accelerator-type a100 --hosts 192.168.20.21 --num-accelerators 8
I have the exact same issue
I'm hitting a deadlock using main.
1 process in AllReduce
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007f8791658967 in MPIDI_CH3I_Progress () from /usr/lib64/mpich/lib/libmpi.so.12
#0 0x00007f8791658967 in MPIDI_CH3I_Progress () from /usr/lib64/mpich/lib/libmpi.so.12
#1 0x00007f87916ff592 in MPIR_Wait_impl.isra.0 () from /usr/lib64/mpich/lib/libmpi.so.12
#2 0x00007f87915e4ce3 in MPIC_Wait () from /usr/lib64/mpich/lib/libmpi.so.12
#3 0x00007f87915e516a in MPIC_Recv () from /usr/lib64/mpich/lib/libmpi.so.12
#4 0x00007f879157703 in MPIR_Bcast_intra_binomial () from /usr/lib64/mpich/lib/libmpi.so.12
#5 0x00007f87915ce99c in MPIR_Bcast_allcomm_auto () from /usr/lib64/mpich/lib/libmpi.so.12
#6 0x00007f87915cea9 in MPIR_Bcast_impl () from /usr/lib64/mpich/lib/libmpi.so.12
#7 0x00007f8791454a5b in PMPI_Bcast_c () from /usr/lib64/mpich/lib/libmpi.so.12
#8 0x00007f8791b8978b in __pyx_f_6mpi4py_3MPI_PyMPI_bcast_p2p.constprop.0 (__pyx_v_obj=0x7f88140370 <_Py_NoneStruct>, __pyx_v_comm=-2080374782, __pyx_v_root=0) at src/mpi4py/MPI.c:107777
#9 0x00007f8791b20a80 in __pyx_f_6mpi4py_3MPI_PyMPI_allreduce_intra (__pyx_v_comm=<optimized out>, __pyx_v_op=<optimized out>, __pyx_v_sendobj=0x7f88140370 <_Py_NoneStruct>) at src/mpi4py/MPI.c:110447
#10 __pyx_f_6mpi4py_3MPI_PyMPI_allreduce (__pyx_v_comm=<optimized out>, __pyx_v_op=<optimized out>, __pyx_v_sendobj=<optimized out>) at src/mpi4py/MPI.c:111025
#11 __pyx_pf_6mpi4py_3MPI_4Comm_288allreduce (__pyx_v_self=<optimized out>, __pyx_v_self=<optimized out>, __pyx_v_op=<optimized out>, __pyx_v_sendobj=<optimized out>) at src/mpi4py/MPI.c:190298
#12 __pyx_pw_6mpi4py_3MPI_4Comm_289allreduce (__pyx_v_self=<optimized out>, __pyx_args=<optimized out>, __pyx_nargs=<optimized out>, __pyx_kwds=<optimized out>) at src/mpi4py/MPI.c:190259
#13 0x00007f88132b9d5 in PyObject_Vectorcall () from /lib64/libpython3.12.so.1.0
#14 0x00007f8813104ad in _PyEval_EvalFrameDefault () from /lib64/libpython3.12.so.1.0
#15 0x00007f8813cd6fd0 in PyEval_EvalCode () from /lib64/libpython3.12.so.1.0
#16 0x00007f8813cf868 in run_eval_code_obj () from /lib64/libpython3.12.so.1.0
#17 0x00007f8813cf793b in run_mod () from /lib64/libpython3.12.so.1.0
#18 0x00007f8813d1181 in pyrun_file () from /lib64/libpython3.12.so.1.0
#19 0x00007f8813d11559 in _PyRun_SimpleFileObject () from /lib64/libpython3.12.so.1.0
#20 0x00007f8813d1067 in _PyRun_AnyFileObject () from /lib64/libpython3.12.so.1.0
#21 0x00007f8813d09369 in Py_RunMain () from /lib64/libpython3.12.so.1.0
#22 0x00007f8813c0e0d in Py_BytesMain () from /lib64/libpython3.12.so.1.0
#23 0x00007f88136295d0 in __libc_start_call_main () from /lib64/libc.so.6
#24 0x00007f8813629680 in __libc_start_main_impl () from /lib64/libc.so.6
#25 0x00005604a072b095 in _start ()
[Inferior 1 (process 184291) detached]
All other processes in Barrier
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
0x00007feecba909b0 in pthread_self () from /lib64/libc.so.6
#0 0x00007feecba909b0 in pthread_self () from /lib64/libc.so.6
#1 0x00007fee4558d40 in MPIDI_CH3I_Progress () from /usr/lib64/mpich/lib/libmpi.so.12
#2 0x00007fee45cff592 in MPIR_Wait_impl.isra.0 () from /usr/lib64/mpich/lib/libmpi.so.12
#3 0x00007fee45be4ce3 in MPIC_Wait () from /usr/lib64/mpich/lib/libmpi.so.12
#4 0x00007fee45be5577 in MPIC_Sendrecv () from /usr/lib64/mpich/lib/libmpi.so.12
#5 0x00007fee45b7634f in MPIR_Barrier_intra_k_dissemination () from /usr/lib64/mpich/lib/libmpi.so.12
#6 0x00007fee45bcdcce in MPIR_Barrier_allcomm_auto () from /usr/lib64/mpich/lib/libmpi.so.12
#7 0x00007fee45bcddbb in MPIR_Barrier_impl () from /usr/lib64/mpich/lib/libmpi.so.12
#8 0x00007fee45b766b4 in MPIR_Barrier_intra_smp () from /usr/lib64/mpich/lib/libmpi.so.12
#9 0x00007fee45bcdcbb in MPIR_Barrier_allcomm_auto () from /usr/lib64/mpich/lib/libmpi.so.12
#10 0x00007fee45bcddbb in MPIR_Barrier_impl () from /usr/lib64/mpich/lib/libmpi.so.12
#11 0x00007fee45a52ce0 in PMPI_Barrier () from /usr/lib64/mpich/lib/libmpi.so.12
#12 0x00007fee46008e9f in __pyx_f_6mpi4py_3MPI_PyMPI_barrier (__pyx_v_comm=1140850688) at src/mpi4py/MPI.c:102280
#13 __pyx_pf_6mpi4py_3MPI_4Comm_274barrier (__pyx_v_self=<optimized out>, __pyx_v_self=<optimized out>) at src/mpi4py/MPI.c:189167
#14 __pyx_pw_6mpi4py_3MPI_4Comm_275barrier (__pyx_v_self=<optimized out>, __pyx_args=<optimized out>, __pyx_nargs=<optimized out>, __pyx_kwds=<optimized out>) at src/mpi4py/MPI.c:189131
#15 0x00007feec02b9d5 in PyObject_Vectorcall () from /lib64/libpython3.12.so.1.0
#16 0x00007feec0104ad in _PyEval_EvalFrameDefault () from /lib64/libpython3.12.so.1.0
#17 0x00007feec0d6fd0 in PyEval_EvalCode () from /lib64/libpython3.12.so.1.0
#18 0x00007feec0f868 in run_eval_code_obj () from /lib64/libpython3.12.so.1.0
#19 0x00007feec0f793b in run_mod () from /lib64/libpython3.12.so.1.0
#20 0x00007feec11181 in pyrun_file () from /lib64/libpython3.12.so.1.0
#21 0x00007feec111559 in _PyRun_SimpleFileObject () from /lib64/libpython3.12.so.1.0
#22 0x00007feec11067 in _PyRun_AnyFileObject () from /lib64/libpython3.12.so.1.0
#23 0x00007feec109369 in Py_RunMain () from /lib64/libpython3.12.so.1.0
#24 0x00007feec00e0d in Py_BytesMain () from /lib64/libpython3.12.so.1.0
#25 0x00007feecba295d0 in __libc_start_call_main () from /lib64/libc.so.6
#26 0x00007feecba29680 in __libc_start_main_impl () from /lib64/libc.so.6
#27 0x000055938fdf095 in _start ()
[Inferior 1 (process 344158) detached]
Can you provide the steps you're using to get the stack trace above so I can verify that I'm also seeing the same hang?
Thanks!
To capture the backtrace, I attach gdb to the hung DLIO process:
gdb -batch -ex "bt" -p $dliopid
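To grab a backtrace from every rank on the host in one go, a loop like the one below works (a sketch; it assumes the hung ranks are visible to pgrep under a command line containing dlio_benchmark):

# Sketch: dump a backtrace for each DLIO rank on this host.
# Assumes `pgrep -f dlio_benchmark` matches the hung ranks.
for dliopid in $(pgrep -f dlio_benchmark); do
  echo "=== backtrace for PID ${dliopid} ==="
  gdb -batch -ex "bt" -p "${dliopid}"
done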
I think the issue is that the number of batches is not divisible by the number of processes. This appears to happen in _train in DLIO: one process is still in the compute phase while all the others have already exited, either because they ran out of samples in the for loop or because of the if condition. I actually hit both scenarios.
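As a rough illustration (the split below is an assumption inferred from the start_sample/end_sample DEBUG lines above, not DLIO's actual partitioning code), dividing 10000 samples across the ranks gives:

# Illustration only: ceil-split of the dataset across ranks, matching the
# start_sample values logged above (8335 for rank 5 of 6, 8574 for rank 6 of 7).
total_samples=10000
for nprocs in 5 6 7 8; do
  per_rank=$(( (total_samples + nprocs - 1) / nprocs ))
  last_rank=$(( total_samples - per_rank * (nprocs - 1) ))
  echo "nprocs=$nprocs: ranks 0..$((nprocs - 2)) get $per_rank samples, last rank gets $last_rank"
done

With 5 or 8 processes the dataset divides evenly, while with 6 or 7 the last rank gets fewer samples (1665 vs 1667, and 1426 vs 1429), so the ranks run a different number of steps and stop issuing collectives at different points.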
I wrote a workaround that simply finds a new number of files that is divisible by the number of processes and is greater than or equal to the one returned by mlpstorage training datasize.
workaround_for_mlperfv2_issue172() {
    if [ "$#" -ne 3 ]; then
        echo "Usage: workaround_for_mlperfv2_issue172 <num_files> <model_name> <num_processes>" >&2
        exit 1
    fi
    local num_files=$1
    local model_name=$2
    local num_processes=$3
    local samples_per_file=0
    case "$model_name" in
        "unet3d")
            # From unet3d_*.yaml
            samples_per_file=1
            ;;
        "cosmoflow")
            # From cosmoflow_*.yaml
            samples_per_file=1
            ;;
        "resnet50")
            # From resnet50_*.yaml
            samples_per_file=1251
            ;;
        *)
            echo "Error: Invalid model name '$model_name'. Supported models are: unet3d, cosmoflow, resnet50." >&2
            exit 1
            ;;
    esac
    # Round the total sample count up to a multiple of num_processes,
    # then convert back to a file count.
    local x_total_samples=$(( num_files * samples_per_file ))
    local y_divisible_samples=$(( (x_total_samples + num_processes - 1) / num_processes * num_processes ))
    local new_num_files=$(( (y_divisible_samples + samples_per_file - 1) / samples_per_file ))
    echo "$new_num_files"
}
So it can be used like this before data generation:
# beforehand, get num_files from mlpstorage training datasize
export num_files=$(workaround_for_mlperfv2_issue172 ${num_files} ${model} ${numgpus})
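After regenerating the dataset with the adjusted count, the same value also needs to go to the run command, e.g. (a sketch reusing the command line from earlier in this issue; model, paths, and host are taken from that example):

mlpstorage training run --model ${model} --data-dir /mnt/mlperf_train/ \
  --oversubscribe --allow-run-as-root \
  --param dataset.num_files_train=${num_files} \
  --param reader.odirect=True --param reader.read_threads=12 \
  --client-host-memory-in-gb 256 --exec-type mpi --accelerator-type a100 \
  --hosts 192.168.20.21 --num-accelerators ${numgpus}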
It does fix the problem for me.
Thanks @LouisDDN! Using your workaround, I was able to get 5x a100 and 6x a100 to pass now.