# [Core feature] Improve flytekitplugins-kfpytorch user experience with default pod template and other reasonable defaults
### Motivation: Why do you think this is important?
Currently, to use the plugin with PyTorch distributed data parallel (DDP) across multiple nodes, you need to manually specify a custom pod template like so:
```python
from flytekit import PodTemplate, Resources, task
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch import Elastic
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount

# Mount a shared memory volume at /dev/shm so the DDP workers can exchange data.
custom_pod_template = PodTemplate(
    primary_container_name="flytesnacks-pytorch-lightning",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="flytesnacks-pytorch-lightning",
                volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")],
            )
        ],
        volumes=[
            V1Volume(
                name="dshm",
                empty_dir=V1EmptyDirVolumeSource(medium="", size_limit="200Gi"),
            )
        ],
    ),
)
# NUM_NODES, NUM_DEVICES, and custom_image are defined elsewhere by the user.
@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
    pod_template=custom_pod_template,
)
def train_model() -> FlyteDirectory:
    ...
```
Having to know about the shared memory volume and the timeouts that nodes need in order to connect to each other at task startup adds a lot of burden to using this plugin.
### Goal: What should the final outcome look like, ideally?
Ideally, the `Elastic` task config would expose some options with reasonable defaults that help the user understand the following:
- timeouts should be set to some reasonable amount of time (15 minutes?)
- DDP requires a shared memory volume
An example might be:
```python
@task(
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        shared_memory="64Gi",  # or an increase_shared_memory flag, which would be set to some hard-coded value
        rdzv_configs={"timeout": 900, "join_timeout": 900},  # default
        max_restarts=3,
    ),
)
def train_model() -> FlyteDirectory:
    ...
```
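A minimal sketch of how those defaults could live on the plugin side, assuming a dataclass-style config like flytekit's; the `shared_memory` field, the default values, and the field subset shown here are illustrative assumptions rather than the plugin's actual API:

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Union


@dataclass
class Elastic:
    """Illustrative subset of the Elastic task config with the proposed defaults."""

    nnodes: Union[int, str] = 1
    nproc_per_node: int = 1
    max_restarts: int = 0
    # Proposed default: 15-minute rendezvous timeouts so cold-starting nodes can still join.
    rdzv_configs: Dict[str, Union[int, str]] = field(
        default_factory=lambda: {"timeout": 900, "join_timeout": 900}
    )
    # Proposed (hypothetical) option: size of the /dev/shm emptyDir volume;
    # None would mean "do not add a shared memory volume".
    shared_memory: Optional[str] = "64Gi"
```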
Internally, the `Elastic` task config would be initialized with some default pod template:
```python
default_pod_template = PodTemplate(
    primary_container_name="pytorch",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="pytorch",
                volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")],
            )
        ],
        volumes=[V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))],
    ),
)
```
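A sketch of how the plugin could fall back to this default when the user does not pass their own `pod_template`; the plugin's elastic task class does subclass `PythonFunctionTask`, but the exact injection point below is an assumption for illustration, and `default_pod_template` refers to the object defined just above:

```python
from typing import Callable

from flytekit.core.python_function_task import PythonFunctionTask
from flytekitplugins.kfpytorch import Elastic


class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
    """Sketch: use the plugin's default pod template unless the user supplies one."""

    def __init__(self, task_config: Elastic, task_function: Callable, **kwargs):
        # Only inject the default when no pod_template was passed to @task.
        if kwargs.get("pod_template") is None:
            kwargs["pod_template"] = default_pod_template
        super().__init__(task_config=task_config, task_function=task_function, **kwargs)
```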
### Describe alternatives you've considered
Another way to solve this problem is with documentation, but that burdens the user with discovering the docs and adding boilerplate to their code.
### Propose: Link/Inline OR Additional context
No response
### Are you sure this issue hasn't been raised already?
- [X] Yes
### Have you read the Code of Conduct?
- [X] Yes
---

Fully agree that this should be simplified.
Questions to discuss:
- Shared memory:
  - Do we need to specify an amount? We've had this volume configured in our default pod template and never had any issues:

    ```yaml
    volumeMounts:
      - mountPath: /dev/shm
        name: dshm
    volumes:
      - name: dshm
        emptyDir:
          medium: Memory
    ```

  - Do we try to merge this into the pod template a user might have provided to the task, or should the shared memory volume only be added if the user doesn't provide a pod template? (A rough sketch of the merge option follows this list.)
- Timeouts:
  - For the join timeout, I feel we should consider the scenario where some workers have a hot start (the node is up and the image is cached) while other workers have a cold start, i.e. the node needs to be scaled up and the image has to be pulled. I feel 15 minutes, as you specified, is a good value here. Are there other opinions?
  - Clarify whether the `timeout` in the rdzv config is the same timeout as in `torch.distributed.init_process_group` and decide on a reasonable default value.
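Regarding the merge question above, a rough sketch of what merging the shared memory volume into a user-provided pod template could look like; this is purely illustrative, and the helper name and merge semantics are assumptions, not part of the plugin:

```python
from typing import Optional

from flytekit import PodTemplate
from kubernetes.client import V1Container, V1EmptyDirVolumeSource, V1PodSpec, V1Volume, V1VolumeMount


def add_shared_memory(pod_template: Optional[PodTemplate], primary_container_name: str) -> PodTemplate:
    """Hypothetical helper: ensure a memory-backed emptyDir is mounted at /dev/shm."""
    if pod_template is None:
        pod_template = PodTemplate(
            primary_container_name=primary_container_name,
            pod_spec=V1PodSpec(containers=[V1Container(name=primary_container_name)]),
        )
    pod_spec = pod_template.pod_spec
    # Add the dshm volume unless the user already defined one with the same name.
    volumes = pod_spec.volumes or []
    if not any(v.name == "dshm" for v in volumes):
        volumes.append(V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="Memory")))
        pod_spec.volumes = volumes
    # Mount it into the primary container unless a dshm mount is already present.
    for container in pod_spec.containers:
        if container.name != pod_template.primary_container_name:
            continue
        mounts = container.volume_mounts or []
        if not any(m.name == "dshm" for m in mounts):
            mounts.append(V1VolumeMount(mount_path="/dev/shm", name="dshm"))
            container.volume_mounts = mounts
    return pod_template
```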
---

Just to circle back to this: we opted to do the following (an example with these defaults in place is at the end of this comment):

- Initialize the `Elastic` task config with a default pod template:
  ```python
  PodTemplate(
      primary_container_name="pytorch",
      pod_spec=V1PodSpec(
          containers=[
              V1Container(
                  name="pytorch",
                  volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")],
              )
          ],
          volumes=[V1Volume(name="dshm", empty_dir=V1EmptyDirVolumeSource(medium="Memory"))],
      ),
  )
  ```
  This would not be exposed to the end user, but they could still override it by specifying `pod_template` in the `@task` decorator.
- Set the default `rdzv_configs` `join_timeout` to `900` (15 minutes). Digging into the pytorch docs/code, it looks like `timeout` and `join_timeout` are the same; I think `timeout` is a legacy argument for the `EtcdRendezvousHandler`:
  - https://pytorch.org/docs/stable/_modules/torch/distributed/elastic/rendezvous/dynamic_rendezvous.html#create_handler
  - https://pytorch.org/docs/stable/elastic/rendezvous.html#etcd-rendezvous-legacy
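With those two defaults in place, the multi-node example from the motivation section should reduce to roughly the following; a user who needs a different `/dev/shm` configuration can still pass their own `pod_template` to `@task` to override the built-in default. This snippet assumes the defaults described above and reuses the names from the earlier examples:

```python
@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        max_restarts=3,
        # rdzv_configs defaults to 15-minute timeouts; the shared memory volume
        # comes from the plugin's default pod template.
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
)
def train_model() -> FlyteDirectory:
    ...
```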