🐛[BUG]: MeshGraphNet multiGPU test failure
Version
main
On which installation method(s) does this occur?
Docker
Describe the issue
Reported by @azrael417 here; pasting the failure log below.
Minimum reproducible example
No response
Relevant log output
I can run some of the tests, but the meshgraphnet one fails:
```
models/meshgraphnet/test_meshgraphnet_snmg.py FFF                        [100%]

=================================== FAILURES ===================================
____________________ test_distributed_meshgraphnet[dtype0] _____________________

dtype = torch.float32

    @pytest.mark.multigpu
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_distributed_meshgraphnet(dtype):
        num_gpus = torch.cuda.device_count()
        assert num_gpus >= 2, "Not enough GPUs available for test"
        world_size = num_gpus

        torch.multiprocessing.spawn(
            run_test_distributed_meshgraphnet,
            args=(world_size, dtype),
            nprocs=world_size,
            start_method="spawn",
        )

models/meshgraphnet/test_meshgraphnet_snmg.py:193:
../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:246: in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:202: in start_processes
    while not context.join():

self = <torch.multiprocessing.spawn.ProcessContext object at 0x7f7f41a258a0>
timeout = None

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.

        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
                process.join()
```
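Note that the pasted output ends inside `ProcessContext.join()`, so the child-side traceback with the root cause is cut off; the parent only reports that a spawned rank exited non-zero. For reference, here is a minimal sketch of the spawn pattern the test follows, assuming the usual one-process-per-GPU NCCL setup. The worker body and rendezvous address are hypothetical stand-ins, not the real `run_test_distributed_meshgraphnet`:

```python
# Minimal sketch of the torch.multiprocessing.spawn pattern seen in the
# traceback above. The worker body is a hypothetical placeholder; it only
# illustrates how a failure in any rank surfaces in the parent via
# ProcessContext.join(), which terminates the remaining ranks and re-raises
# the first child's error.
import torch
import torch.distributed as dist


def worker(rank: int, world_size: int, dtype: torch.dtype):
    # Each spawned process sets up its own process group.
    # The TCP address/port here is an assumption for the sketch.
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)

    # ... per-rank model construction and distributed checks would go here;
    # if anything raises, this process exits non-zero and the parent's
    # ProcessContext.join() kills the other ranks and re-raises the cause.

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(
        worker,
        args=(world_size, torch.float32),
        nprocs=world_size,
        start_method="spawn",
    )
```

On a machine with at least two GPUs, something like `pytest -m multigpu models/meshgraphnet/test_meshgraphnet_snmg.py` should exercise the same spawn path, assuming the repo's `multigpu` marker is registered as the decorator suggests.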
Environment details
No response