
Add support for multi node/task jobs

n-gao opened this issue 10 months ago · 0 comments

This PR adds support for multi-task/node jobs.

Changes

I made the following changes to support this:

  • The Python script is now executed via srun bash -c rather than directly via eval in our bash template
  • When preparing the experiment, we now check whether the Slurm array and task ID match the database entry. If they do, we accept the attempt, which would otherwise be rejected because the experiment is already running or has already run.
  • The first process logs at the INFO level by default, while all other processes only log at the ERROR level.
  • Observers are only attached to the first process.
  • The random seed is fixed in the experiment preparation and is the same for all processes.
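The rank-dependent behavior in the points above (first process logs at INFO and gets the observers, the rest stay quiet) can be sketched with Slurm's environment variables. This is a minimal illustration, not seml's actual implementation; the helper name `configure_rank_logging` is hypothetical. It assumes srun exposes each task's rank in SLURM_PROCID, defaulting to rank 0 when running outside of Slurm:

```python
import logging
import os


def configure_rank_logging() -> int:
    """Hypothetical sketch: set the log level from the Slurm task rank.

    Rank 0 (the first process) logs at INFO; every other rank only
    logs at ERROR, so multi-task jobs don't duplicate their output.
    """
    # srun sets SLURM_PROCID to the task's rank; assume rank 0 otherwise.
    rank = int(os.environ.get("SLURM_PROCID", "0"))
    level = logging.INFO if rank == 0 else logging.ERROR
    logging.getLogger().setLevel(level)
    return rank
```

The same rank check would gate observer attachment, so only the first process writes to the database.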

Additional updates

  • The collection cache is automatically refreshed when calling add, delete or drop.
  • SSH connections are now handled in a separate process that tracks its health.
  • Fixed SSH connections in a multi user setting.
  • We now support src-layouts by converting them to a flat-layout at runtime.
  • When encountering non-unicode symbols while reading the file output, we replace them with the Unicode replacement character.
  • Fixed multi-user ssh port forwarding.
  • Updated CI to use pre-commit and uv instead of pip.
  • Added seml <col> restore-sources <path> to restore source files from MongoDB.
  • Added seml <col> hold and seml <col> release to hold and release slurm jobs.
  • Added seml queue to print the collection of slurm jobs (only works for jobs submitted with this version or newer)
  • Refresh collection cache for autocompletion directly when adding/removing collections.
  • Cancel experiments by default when deleting them.
  • Autocompletion now also suggests commands that do not need a collection, such as drop, queue, or list.
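The replacement-character behavior mentioned above corresponds to Python's built-in "replace" error handler: invalid byte sequences in the output file are decoded to U+FFFD instead of raising UnicodeDecodeError. A minimal sketch (the function name `read_output` is illustrative, not seml's API):

```python
def read_output(data: bytes) -> str:
    """Decode experiment output, substituting U+FFFD for invalid bytes."""
    return data.decode("utf-8", errors="replace")


# b"\xff\xfe" is not valid UTF-8, so both bytes become U+FFFD.
text = read_output(b"loss=0.5 \xff\xfe done")
```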

Example JAX experiment:

import jax
import socket
from seml import Experiment
from jax._src.clusters.slurm_cluster import SlurmCluster


ex = Experiment()


@ex.automain
def main(seed=0):
    # Each srun task gets its own process id (rank) from Slurm.
    pid = SlurmCluster.get_process_id()
    # The coordinator address points at task 0; resolve the hostname to an IP.
    address = SlurmCluster.get_coordinator_address()
    host, port = address.split(":")
    host = socket.gethostbyname(host)
    jax.distributed.initialize(f"{host}:{port}")
    return jax.process_index()

YAML file

seml:
  executable: test.py
  name: distributed_test
  output_dir: logs
  project_root_dir: .

slurm:
  experiments_per_job: 1
  sbatch_options:
    mem: 16G
    cpus-per-task: 1
    time: 0-01:00
    partition: gpu_all
    gres: gpu:2
    nodes: 1
    ntasks: 2

fixed:
  seed: 0

This example starts two tasks on the same node, each with 1 GPU.

n-gao · Apr 16 '24 13:04