Add support for multi-node/task jobs
This PR adds support for multi-task/node jobs.
Changes
I made the following changes to support this:
- The Python script is now executed via `srun bash -c` rather than directly via `eval` in our bash template.
- When preparing the experiment, we now check whether the Slurm array and task ID match the database entry. If they do, we also accept the attempt (which would otherwise be rejected because the experiment is already running or has already run).
- The first process logs at the `INFO` level by default, while all other processes only log at the `ERROR` level via `logging` (see the sketch after this list).
- Observers are only attached to the first process.
- The random seed is fixed during experiment preparation and is the same for all processes.
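
The per-process log levels roughly follow this pattern. This is a minimal sketch for illustration, not seml's actual implementation; reading the rank from `SLURM_PROCID` is an assumption:

```python
import logging
import os

# Hedged sketch, not seml's actual code: the first process (rank 0)
# logs at INFO, every other rank only at ERROR. Reading the rank from
# SLURM_PROCID is an assumption for illustration.
rank = int(os.environ.get("SLURM_PROCID", "0"))
logging.basicConfig(level=logging.INFO if rank == 0 else logging.ERROR)
```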
Additional updates
- The collection cache is automatically refreshed when calling `add`, `delete`, or `drop`.
- SSH connections are now handled in a separate process that tracks their health.
- Fixed SSH connections in a multi-user setting.
- We now support src-layouts by converting them to a flat layout at runtime.
- When encountering non-Unicode symbols while reading the file output, we replace them with the Unicode replacement character (illustrated after this list).
- Fixed multi-user SSH port forwarding.
- Updated CI to use pre-commit and `uv` instead of `pip`.
- Added `seml <col> restore-sources <path>` to restore source files from MongoDB.
- Added `seml <col> hold` and `seml <col> release` to hold and release Slurm jobs.
- Added `seml queue` to print the collection of Slurm jobs (only works for jobs submitted with this version or newer).
- Refresh the collection cache for autocompletion directly when adding/removing collections.
- Cancel experiments by default when deleting them.
- In autocompletion, seml also suggests commands that do not need a collection, like `drop`, `queue`, or `list`.
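
To illustrate the replacement-character behavior from the list above, here is a generic Python sketch (not seml's exact code): decoding with `errors="replace"` substitutes U+FFFD for undecodable bytes.

```python
# Generic illustration, not seml's exact code: undecodable bytes
# become the Unicode replacement character U+FFFD.
raw = b"job output with invalid bytes: \xff\xfe"
print(raw.decode("utf-8", errors="replace"))
# -> job output with invalid bytes: \ufffd\ufffd
```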
Example JAX experiment:
```python
import socket

import jax
from seml import Experiment
from jax._src.clusters.slurm_cluster import SlurmCluster

ex = Experiment()


@ex.automain
def main(seed=0):
    # Rank of this process within the Slurm job
    # (unused here; jax.distributed also auto-detects it).
    pid = SlurmCluster.get_process_id()
    # The coordinator runs on the first node; resolve its hostname to an IP.
    address = SlurmCluster.get_coordinator_address()
    host, port = address.split(":")
    host = socket.gethostbyname(host)
    jax.distributed.initialize(f"{host}:{port}")
    return jax.process_index()
```
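
Note that only the coordinator address is passed to `jax.distributed.initialize`; JAX detects the number of processes and the process ID from the Slurm environment on its own.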
YAML file:
```yaml
seml:
  executable: test.py
  name: distributed_test
  output_dir: logs
  project_root_dir: .

slurm:
  experiments_per_job: 1
  sbatch_options:
    mem: 16G
    cpus-per-task: 1
    time: 0-01:00
    partition: gpu_all
    gres: gpu:2
    nodes: 1
    ntasks: 2

fixed:
  seed: 0
```
This example starts two tasks on the same node, each with one GPU.