
GCSFuse mount + Vertex AI custom training jobs support

Open miguelalba96 opened this issue 1 year ago • 1 comment

🚀 Feature

DDP training on GCP (Vertex AI) with data stored in GCS

Question

Has litdata been tested for training on GCP (Vertex AI), rather than just storing and streaming the data from GCS?

Motivation

I've been trying to set up DDP on GCP (Vertex AI) using Lightning Fabric, with my data stored in GCS. When you use GCS with Vertex AI, you can "mount" a bucket to the instance(s)/containers running in their infrastructure using GCSFuse; in that case, the only thing you need to do is replace gs:// with /gcs/ and the bucket acts as a file system. Has litdata been tested under this setting? Will it work?
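For reference, this is roughly what I mean by the swap (just a sketch; the bucket and dataset names are placeholders, and I'm assuming the data was already written in litdata's optimized format):

from litdata import StreamingDataset, StreamingDataLoader

# Without the mount, you would point at the bucket directly:
#   StreamingDataset(input_dir="gs://my-bucket/optimized_dataset")
# With GCSFuse, Vertex AI exposes the same bucket as a local path:
dataset = StreamingDataset(input_dir="/gcs/my-bucket/optimized_dataset")
dataloader = StreamingDataLoader(dataset, batch_size=32, num_workers=4)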

I already tried with mosaicml-streaming and ran into lots of throughput issues, which led to data starvation in the multi-node setting.

I wrote this cluster environment to configure the cluster on Vertex AI + Fabric:

import os
import json

from lightning.fabric.plugins.environments.lightning import LightningEnvironment


class VertexAICluster(LightningEnvironment):
    """
    Configures distributed training on a vertex ai custom training job,
    ex:
        Consider a cluster with 3 nodes, each composed of 2 gpus

        The "cluster" key in CLUSTER_SPEC will be:
            {
                'workerpool0': ['cmle-training-workerpool0-d604929a6a-0:2222'],
                'workerpool1': [
                                'cmle-training-workerpool1-d604929a6a-0:2222',
                                'cmle-training-workerpool1-d604929a6a-1:2222'
                              ]
            }

        and each process scheduled will be under the "task" key, following the same example
        the three tasks will look like this:
            task0 ("chief" spawn process) -> node 0:
            {'type': 'workerpool0', 'index': 0}
            task 1 (on first node on workerpool1) -> node 1:
            {'type': 'workerpool1', 'index': 0}
            task 2 (on second node on workerpool1) -> node 2:
            {'type': 'workerpool1', 'index': 1}
    """

    def __init__(self):
        super().__init__()
        # Vertex AI injects CLUSTER_SPEC as a JSON string into every
        # container of a custom training job.
        self.cluster_spec = json.loads(os.environ["CLUSTER_SPEC"])

    @property
    def main_address(self) -> str:
        return self.cluster_spec["cluster"]["workerpool0"][0].split(':')[0]

    @property
    def main_port(self) -> int:
        """Set common fixed MASTER_PORT port across processes
        """
        return int(self.cluster_spec["cluster"]["workerpool0"][0].split(':')[1])

    def node_rank(self) -> int:
        # The chief (workerpool0) is always node 0; hosts in
        # workerpool1 are offset by one.
        task = self.cluster_spec["task"]
        if task["type"] == "workerpool0":
            return 0
        return task["index"] + 1

Do I need to set up any other env variables to test litdata? Are the ones defined in this cluster environment used by litdata? In mosaicml you had to configure these.

miguelalba96 avatar Apr 07 '24 22:04 miguelalba96

Hey @miguelalba96,

It is possible, but I strongly recommend not using a FUSE mount if you are trying to get the fastest data loading possible on the cloud.

In my benchmarking, streaming directly from the storage server instead of through a mount can be up to 20x faster.
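For example, you can point litdata straight at the bucket (a rough sketch; the path is a placeholder, and I'm assuming the job has GCS credentials available):

from litdata import StreamingDataset, StreamingDataLoader

# Stream chunks directly from GCS instead of reading through a mount.
dataset = StreamingDataset(input_dir="gs://my-bucket/optimized_dataset")
dataloader = StreamingDataLoader(dataset, batch_size=32, num_workers=4)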

Concerning the env variables question: LitData infers the ranks from torch.distributed directly, so as long as you already have a process group initialized, we can infer the rank from it. Your cluster env looks fine to me. cc @awaelchli @carmocca

Best, T.C

tchaton avatar Apr 08 '24 13:04 tchaton