Simplify decide_worker
This came up during review of https://github.com/dask/distributed/pull/6614#discussion_r958287053
Bottom line: this code path exists only as a performance optimization, and it merely approximates the decision made by `decide_worker` (it neglects held memory / `ws.nbytes` and ignores inhomogeneous `nthreads`), i.e. the decision quality is strictly better when using `worker_objective`. In these situations, poor decisions are typically not a big deal, though.
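To make the "inhomogeneous nthreads" point concrete, here is a toy illustration (mock objects only, not scheduler code): two workers with the same total occupancy but different thread counts look identical to an occupancy-only comparison, while a `worker_objective`-style key prefers the one with the lower expected start time.

```python
from collections import namedtuple

# mock workers: same total occupancy, very different parallelism
MockWorker = namedtuple("MockWorker", ["name", "occupancy", "nthreads", "nbytes"])
a = MockWorker("a", occupancy=4.0, nthreads=1, nbytes=0)  # 4s of queued work on 1 thread
b = MockWorker("b", occupancy=4.0, nthreads=8, nbytes=0)  # same queue spread over 8 threads

# occupancy-only comparison: it's a tie, min() simply returns the first candidate ("a")
print(min([a, b], key=lambda w: w.occupancy).name)

# worker_objective-style key (expected start time, then stored bytes): picks "b"
print(min([a, b], key=lambda w: (w.occupancy / w.nthreads, w.nbytes)).name)
```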
In a real-world scenario this code path is actually pretty difficult to hit since we introduced the root task logic above. Most tasks that do not hold dependencies will follow the root task decision path, unless the group is too small to properly utilize the cluster, i.e. #tasks <= 2 * #total_threads.
I performed a couple of micro-benchmarks on my machine (basically I extracted the methods into standalone functions and ran them on a couple of dicts). The table below shows the time it takes to make the worker decision for 1k tasks.
N Workers | main | This PR | This PR + plain dict |
---|---|---|---|
10k | const; see below | 979 ms | (-51%) 484 ms |
1k | 1.3 ms | 112 ms | (-47%) 63.7 ms |
100 | 1.21 ms | 26.4 ms | (-18%) 21.6 ms |
19 | 3.37 ms | 6.48 ms | (-21%) 5.1 ms |
Basically, we'd slow down embarrassingly parallel submissions, e.g. `client.map(inc, range(1000))`, by ~100 ms (112 ms vs. 1.3 ms for 1000 tasks) if scheduled on a 1k-worker cluster.
Is this worth the optimization? As I said, most real-world workloads that resemble this will very likely go down the root task path anyhow.
Code to reproduce
import math
import operator
from functools import partial

from distributed.scheduler import WorkerState
def decide_worker(workers, ts, idle=set(), total_nthreads=10000, n_tasks=0):
"""
Decide on a worker for task *ts*. Return a WorkerState.
If it's a root or root-like task, we place it with its relatives to
reduce future data transfer.
If it has dependencies or restrictions, we use
`decide_worker_from_deps_and_restrictions`.
Otherwise, we pick the least occupied worker, or pick from all workers
in a round-robin fashion.
"""
from distributed.scheduler import decide_worker as decide_worker_scheduler
tg = ts.group
    valid_workers = None  # set(workers.values())
# Group is larger than cluster with few dependencies?
# Minimize future data transfers.
if (
valid_workers is None
and len(tg) > total_nthreads * 2
and len(tg.dependencies) < 5
and sum(map(len, tg.dependencies)) < 5
):
ws = tg.last_worker
if not (ws and tg.last_worker_tasks_left and ws.address in workers):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(
(idle or workers).values(),
key=partial(worker_objective, ts),
)
assert ws
tg.last_worker_tasks_left = math.floor(
(len(tg) / total_nthreads) * ws.nthreads
)
# Record `last_worker`, or clear it on the final task
tg.last_worker = (
ws if tg.states["released"] + tg.states["waiting"] > 1 else None
)
tg.last_worker_tasks_left -= 1
return ws
if ts.dependencies or valid_workers is not None:
ws = decide_worker_scheduler(
ts,
workers.values(),
valid_workers,
partial(worker_objective, ts),
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = idle or workers
wp_vals = worker_pool.values()
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
ws = min(wp_vals, key=operator.attrgetter("occupancy"))
assert ws
if ws.occupancy == 0:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
# land back where we started).
wp_i: WorkerState
start: int = n_tasks % n_workers
i: int
for i in range(n_workers):
wp_i = wp_vals[(i + start) % n_workers]
if wp_i.occupancy == 0:
ws = wp_i
break
else: # dumb but fast in large case
ws = wp_vals[n_tasks % n_workers]
return ws
def worker_objective(ts, ws) -> tuple:
"""
Objective function to determine which worker should get the task
Minimize expected start time. If a tie then break with data storage.
"""
dts: TaskState
comm_bytes: int = 0
for dts in ts.dependencies:
if ws not in dts.who_has:
nbytes = dts.get_nbytes()
comm_bytes += nbytes
stack_time: float = ws.occupancy / ws.nthreads
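    # 100_000_000 bytes/s (~100 MB/s) is the assumed network bandwidth for transfers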
start_time: float = stack_time + comm_bytes / 100_000_000
if ts.actor:
return (len(ws.actors), start_time, ws.nbytes)
else:
return (start_time, ws.nbytes)
def decide_worker_simplified(workers, ts, idle=set(), total_nthreads=100, n_tasks=0):
"""
Decide on a worker for task *ts*. Return a WorkerState.
If it's a root or root-like task, we place it with its relatives to
reduce future data transfer.
If it has dependencies or restrictions, we use
`decide_worker_from_deps_and_restrictions`.
Otherwise, we pick the least occupied worker, or pick from all workers
in a round-robin fashion.
"""
from distributed.scheduler import decide_worker as decide_worker_scheduler
tg = ts.group
valid_workers = set(workers.values())
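    # NOTE: valid_workers is non-None here, so the root-task branch below never
    # triggers and every task goes straight through decide_worker_scheduler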
# Group is larger than cluster with few dependencies?
# Minimize future data transfers.
if (
valid_workers is None
and len(tg) > total_nthreads * 2
and len(tg.dependencies) < 5
and sum(map(len, tg.dependencies)) < 5
):
print("Root task stuff!!")
ws = tg.last_worker
if not (ws and tg.last_worker_tasks_left and ws.address in workers):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(
(idle or workers).values(),
key=partial(worker_objective, ts),
)
assert ws
tg.last_worker_tasks_left = math.floor(
(len(tg) / total_nthreads) * ws.nthreads
)
# Record `last_worker`, or clear it on the final task
tg.last_worker = (
ws if tg.states["released"] + tg.states["waiting"] > 1 else None
)
tg.last_worker_tasks_left -= 1
return ws
ws = decide_worker_scheduler(
ts,
workers.values(),
valid_workers,
partial(worker_objective, ts),
)
return ws
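For reference, a rough sketch of how such a micro-benchmark could be driven (this harness is not part of the PR; it only mocks the worker state needed for the dependency-free case and times the two strategies directly, so absolute numbers will not match the table above):

```python
import time
from dataclasses import dataclass, field


@dataclass
class MockWorker:
    address: str
    occupancy: float = 0.0
    nthreads: int = 1
    nbytes: int = 0
    actors: set = field(default_factory=set)


def bench(n_workers: int, n_tasks: int = 1000) -> tuple[float, float]:
    workers = {f"tcp://worker-{i}": MockWorker(f"tcp://worker-{i}") for i in range(n_workers)}
    wp_vals = list(workers.values())

    # round-robin fast path (what main does for dependency-free tasks)
    t0 = time.perf_counter()
    for i in range(n_tasks):
        ws = wp_vals[i % n_workers]
    t_fast = time.perf_counter() - t0

    # objective-based scan (what this PR does when there are no dependencies)
    t0 = time.perf_counter()
    for _ in range(n_tasks):
        ws = min(wp_vals, key=lambda w: (w.occupancy / w.nthreads, w.nbytes))
    t_objective = time.perf_counter() - t0
    return t_fast, t_objective


for n in (100, 1_000, 10_000):
    fast, objective = bench(n)
    print(f"{n:>6} workers: fast path {fast * 1e3:.2f} ms, objective {objective * 1e3:.2f} ms")
```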
cc @gjoseph92
There is one related test failure: `distributed/tests/test_client_executor.py::test_retries`
Unit Test Results
See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.
15 files ±0, 15 suites ±0, 7h 7m 55s :stopwatch: (+38m 47s)
3 052 tests ±0: 2 966 :heavy_check_mark: -1, 83 :zzz: ±0, 3 :x: +1
22 577 runs ±0: 21 592 :heavy_check_mark: -9, 976 :zzz: +2, 9 :x: +7
For more details on these failures, see this check.
Results for commit 52e0a887. ± Comparison against base commit 6a1b0894.
> Basically, we'd slow down embarrassingly parallel submissions, e.g. `client.map(inc, range(1000))`, by ~100 ms if scheduled on a 1k-worker cluster.
A plot of your data table:
1k workers isn't unreasonable. Even 10k happens sometimes. Another thing to note is that in really large clusters like that, you'd have to have a ton of root tasks for them to exceed the `total_nthreads * 2` limit. So perhaps with larger clusters, it would actually be more likely to end up in this fastpath, which is also the case where the fastpath would make the most difference?
> In a real-world scenario this code path is actually pretty difficult to hit since we introduced the root task logic above.
I think this is the more important point. `n_tasks < total_nthreads * 2` puts an upper bound on how bad performance can be (with tons of workers, that could still be a high bound though, see above). The other way a task could avoid the root task check is if the TaskGroup has >5 dependencies, in which case it wouldn't get to use the zero-deps fastpath anyway.
The main thing that worries me is that TaskGroups are a very brittle way of inferring graph structure. All you have to do is submit a bunch of tasks with UUIDs as keys, and they'll bypass the root task logic since they'll all belong to different TaskGroups.
That again points to the importance of determining root-ish-ness from the graph itself, not task names: https://github.com/dask/distributed/issues/6922.
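For example, a hedged sketch of that failure mode (TaskGroup membership is derived from the key name, so explicit UUID keys put each task into its own group and the `len(tg) > total_nthreads * 2` check never triggers):

```python
import uuid
from distributed import Client

client = Client()  # local cluster, just for illustration

def inc(x):
    return x + 1

# Normal submission: all keys share the "inc" prefix -> one large TaskGroup,
# so a big enough submission takes the root-task placement path.
futures = client.map(inc, range(1000))

# Explicit UUID keys: every task lands in its own TaskGroup of size 1,
# so the root-task logic is bypassed entirely.
futures = client.map(inc, range(1000), key=[str(uuid.uuid4()) for _ in range(1000)])
```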
Basically I'd be a little hesitant to just remove this logic. Instead, I think we should first work on the definition of `is_rootish` so that all tasks with 0 deps fall within it, regardless of TaskGroup/cluster size. Then, we could even add a similar fastpath to `decide_worker_rootish_*`. I'm imagining something like:
if len(self.idle) > 100:
ws = random.choice(self.idle.values())
else:
ws = min(self.idle.values(), key=lambda ws: len(ws.processing) / ws.nthreads)
> A plot of your data table:
I know this plot looks dramatic, but it is not surprising. What we do here, sorting, scales with W·log(W), where W is the number of workers. So, yes, for 10k workers we're already in the ~1 s range for a thousand tasks (I added a data point above).
My point is basically that when we normalize these numbers to overhead per task, this doesn't feel too dramatic anymore: for 10k workers we are at roughly 1 ms/task (979 ms / 1000 tasks), and I am doubtful that real-world workloads would really notice. After all, if you are running on 10k workers, I expect you're processing a lot of data. Would 1 s of overhead for the initial dispatch really matter?
Basically, I think having fewer branches in the `decide_worker` logic would make reasoning about what happens much easier, and that might be worth the overhead.
If we dropped this fast path we could also demote the `workers` dictionary to a plain `dict`; it is currently a `sortedcontainers.SortedDict`. I guess this would yield other, hard-to-estimate improvements.
The measurements improve significantly, specifically in the many-workers range. We still have the non-linear scaling, but as I mentioned above, I don't think we should really care since the absolute values are really small.
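As a rough, hedged sketch of what "demoting to a plain dict" could buy: the snippet below only measures a linear scan over the values view, which is what the `min(..., key=worker_objective)` path does, and does not reproduce the scheduler's full access patterns.

```python
import timeit
from sortedcontainers import SortedDict

n = 10_000
plain = {f"tcp://worker-{i}": i for i in range(n)}
sorted_dict = SortedDict(plain)

# time a linear scan over the values, as min(..., key=...) would do
t_plain = timeit.timeit(lambda: min(plain.values()), number=100)
t_sorted = timeit.timeit(lambda: min(sorted_dict.values()), number=100)
print(f"plain dict: {t_plain * 10:.3f} ms per scan")
print(f"SortedDict: {t_sorted * 10:.3f} ms per scan")
```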