
🐛[BUG]: MeshGraphNet multiGPU test failure

Open • akshaysubr opened this issue 1 year ago • 2 comments

Version

main

On which installation method(s) does this occur?

Docker

Describe the issue

Reported by @azrael417 here; pasting the failure log below.

Minimum reproducible example

No response

Relevant log output

I can run some of the tests, but the MeshGraphNet one fails:

```
models/meshgraphnet/test_meshgraphnet_snmg.py FFF [100%]

=================================== FAILURES ===================================
____________________ test_distributed_meshgraphnet[dtype0] _____________________

dtype = torch.float32

@pytest.mark.multigpu
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_distributed_meshgraphnet(dtype):
    num_gpus = torch.cuda.device_count()
    assert num_gpus >= 2, "Not enough GPUs available for test"
    world_size = num_gpus
    torch.multiprocessing.spawn(
        run_test_distributed_meshgraphnet,
        args=(world_size, dtype),
        nprocs=world_size,
        start_method="spawn",
    )
models/meshgraphnet/test_meshgraphnet_snmg.py:193:

../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:246: in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
../../.conda/envs/modulus/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:202: in start_processes
while not context.join():

self = <torch.multiprocessing.spawn.ProcessContext object at 0x7f7f41a258a0>
timeout = None

def join(self, timeout=None):
    r"""
    Tries to join one or more processes in this spawn context.
    If one of them exited with a non-zero exit status, this function
    kills the remaining processes and raises an exception with the cause
    of the first process exiting.

    Returns ``True`` if all processes have been joined successfully,
    ``False`` if there are more processes that need to be joined.

    Args:
        timeout (float): Wait this long before giving up on waiting.
    """
    # Ensure this function can be called even when we're done.
    if len(self.sentinels) == 0:
        return True

    # Wait for any process to fail or all of them to succeed.
    ready = multiprocessing.connection.wait(
        self.sentinels.keys(),
        timeout=timeout,
    )

    error_index = None
    for sentinel in ready:
        index = self.sentinels.pop(sentinel)
        process = self.processes[index]
        process.join()
        if process.exitcode != 0:
            error_index = index
            break

    # Return if there was no error.
    if error_index is None:
        # Return whether or not all processes have been joined.
        return len(self.sentinels) == 0

    # Assume failure. Terminate processes that are still alive.
    for process in self.processes:
        if process.is_alive():
            process.terminate()
        process.join()
```
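The `join()` frames above belong to the parent process: it waits on every child's sentinel and, on the first non-zero exit code, terminates the remaining ranks and raises. So the root cause of the `FFF` lies in whichever rank exited first, not in `spawn.py`. The sketch below mimics that join pattern with the standard-library `multiprocessing` module so it runs without GPUs or PyTorch. It is a simplified stand-in, not PyTorch's implementation: `worker`, `spawn_and_join`, and `fail_rank` are hypothetical names, and it uses the `fork` start method (Linux/macOS only) so it runs as a plain script, whereas the test above uses `start_method="spawn"`.

```python
import multiprocessing
import multiprocessing.connection
import sys
from typing import List, Optional


def worker(rank: int, fail_rank: Optional[int]) -> None:
    """Hypothetical per-rank body: the rank equal to fail_rank exits with status 1."""
    if rank == fail_rank:
        sys.exit(1)  # simulate a crash on this rank


def spawn_and_join(nprocs: int, fail_rank: Optional[int] = None) -> List[int]:
    """Start nprocs workers and join them, following the join() pattern above.

    Returns every exit code on success; otherwise terminates the surviving
    workers and raises RuntimeError naming the first failing rank observed.
    """
    # Assumption: "fork" keeps the sketch runnable as a plain script.
    ctx = multiprocessing.get_context("fork")
    processes = [
        ctx.Process(target=worker, args=(rank, fail_rank)) for rank in range(nprocs)
    ]
    for proc in processes:
        proc.start()

    # Map each process sentinel back to its rank.
    sentinels = {proc.sentinel: rank for rank, proc in enumerate(processes)}

    while sentinels:
        # Block until at least one worker has exited.
        ready = multiprocessing.connection.wait(list(sentinels.keys()))
        for sentinel in ready:
            rank = sentinels.pop(sentinel)
            proc = processes[rank]
            proc.join()
            if proc.exitcode != 0:
                # Assume failure: terminate the workers still alive, then report.
                for other in processes:
                    if other.is_alive():
                        other.terminate()
                    other.join()
                raise RuntimeError(
                    f"process {rank} exited with code {proc.exitcode}"
                )
    return [proc.exitcode for proc in processes]


if __name__ == "__main__":
    print(spawn_and_join(nprocs=2))
    try:
        spawn_and_join(nprocs=3, fail_rank=1)
    except RuntimeError as err:
        print(err)
```

Here `spawn_and_join(nprocs=2)` returns `[0, 0]`, while `spawn_and_join(nprocs=3, fail_rank=1)` raises `RuntimeError: process 1 exited with code 1`, which is the same shape of failure the parent reports in the log above.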

Environment details

No response

akshaysubr, Dec 13 '23 16:12