Slow dask task creation for complex functions
Bug summary
I've been seeing very slow task creation (~1/s) when calling .map on a method of a not particularly crazy class. Using plain old Dask client.map() on the same function creates the tasks several hundred times faster. For a reproducible example I created a more monstrous method that is somewhat slow even for Dask's .map but incredibly slow for Prefect's. Obviously this is a fairly pathological case, but hopefully it's useful for diagnosing where the gap between pure Dask and Prefect+Dask performance comes from.
import time
from dataclasses import dataclass
from typing import Any, Dict, List

from prefect import flow, task
from prefect.context import FlowRunContext
from prefect_dask import DaskTaskRunner


@dataclass
class Node:
    """A node in our complex object tree"""
    id: str
    data: Dict[str, Any]
    children: List['Node']
    parent_ref: Any
    closures: List[Any]
class ComplexDataProcessor:
    def __init__(self, depth=5, breadth=5):
        # Create deeply nested tree with circular references
        self.root = self._build_tree(depth, breadth)
        self.root.parent_ref = self

        # Add lots of closures that capture self
        self.closures = []
        for i in range(100):
            captured_data = f"data_{i}" * 100

            def make_closure(n, obj=self):
                def inner(x):
                    return x + n + len(obj.root.id) + len(captured_data)
                return inner

            self.closures.append(make_closure(i))

        # Add lambdas with captured state
        self.lambdas = {
            f"lambda_{i}": lambda x, i=i, obj=self: x + i + len(obj.root.data)
            for i in range(100)
        }

        # Add recursive/circular references
        self.recursive_dict = {'self': self, 'root': self.root}
        for i in range(50):
            self.recursive_dict[f'level_{i}'] = {
                'parent': self.recursive_dict,
                'processor': self,
                'data': {'nested': self.recursive_dict},
                'more_refs': [self.recursive_dict for _ in range(10)]
            }

        # Add regular classes with complex initialization
        self.complex_objects = []
        for i in range(50):
            obj = type('ComplexObj', (), {
                'value': i,
                'get_processor': lambda self, p=self: p,
                'compute': lambda self, x, p=self: x + len(p.root.id)
            })()
            obj.processor_ref = self
            obj.recursive_ref = self.recursive_dict
            self.complex_objects.append(obj)

        # Add custom objects with methods
        class CustomObj:
            def __init__(self, processor, index):
                self.processor = processor
                self.index = index
                self.data = {'proc': processor, 'idx': index}

            def compute(self, x):
                return x + self.index + len(self.processor.root.id)

        self.custom_objects = [CustomObj(self, i) for i in range(50)]

        # Add nested functions
        def outer_func(x):
            def middle_func(y):
                def inner_func(z):
                    return x + y + z + len(self.root.id)
                return inner_func
            return middle_func

        self.nested_funcs = [outer_func(i) for i in range(50)]
    def _build_tree(self, depth, breadth, parent=None):
        """Build a tree with lots of complex references"""
        node_id = f"node_depth{depth}_{'x' * 50}"

        # Create closures specific to this node
        node_closures = []
        for i in range(10):
            def make_node_closure(n, node_id=node_id):
                def closure(x):
                    return x + n + len(node_id)
                return closure
            node_closures.append(make_node_closure(i))

        node = Node(
            id=node_id,
            data={
                'level': depth,
                'static_data': {f'key_{i}': i * 0.1 for i in range(100)},
                'text': 'x' * 5000,
                'nested_dict': {str(i): {str(j): i * j for j in range(50)} for i in range(50)}
            },
            children=[],
            parent_ref=parent,
            closures=node_closures
        )

        if depth > 0:
            node.children = [
                self._build_tree(depth - 1, breadth, parent=node)
                for _ in range(breadth)
            ]
        return node

    def process_item(self, item):
        """Method without @task decorator for direct client.map comparison"""
        return item * 2

    @task
    def process_item_task(self, item):
        """Same method but with @task decorator"""
        return item * 2
@flow(task_runner=DaskTaskRunner())
def test_task_creation_methods():
    # Get the Dask client from the task runner
    context = FlowRunContext.get()
    client = context.task_runner.client

    # Create a complex processor instance
    print("Creating VERY complex processor...")
    processor = ComplexDataProcessor(depth=4, breadth=4)

    # Generate items to process
    items = list(range(10))  # Fewer items because it's so slow

    print("\n" + "=" * 60)
    print("COMPARISON: Complex class method")
    print("=" * 60)

    # Method 1: Direct client.map (FAST)
    print("\n1. Using client.map directly...")
    start_time = time.time()
    futures = client.map(processor.process_item, items)
    client_map_time = time.time() - start_time
    print(f"   Time: {client_map_time:.3f} seconds")

    # Method 2: Using @task decorator (VERY SLOW)
    print("\n2. Using @task decorator...")
    start_time = time.time()
    task_results = processor.process_item_task.map(items)
    task_map_time = time.time() - start_time
    print(f"   Time: {task_map_time:.3f} seconds")
    print(f"   Slowdown: {task_map_time / client_map_time:.1f}x")

    # Wait for results
    print("\nGathering results...")
    results1 = client.gather(futures)
    results2 = task_results.result()
    print(f"Results match: {results1 == results2}")
# Show the serialization overhead directly
def demonstrate_serialization_issue():
    import cloudpickle

    print("Demonstrating why it's slow...\n")
    processor = ComplexDataProcessor(depth=3, breadth=3)

    print("Serializing the complex processor object:")
    start = time.time()
    serialized = cloudpickle.dumps(processor)
    print(f"  Serialization time: {time.time() - start:.3f} seconds")
    print(f"  Serialized size: {len(serialized) / (1024*1024):.1f} MB")


if __name__ == "__main__":
    # First show why it's slow
    demonstrate_serialization_issue()

    print("\n" + "=" * 60 + "\n")

    # Run the comparison
    test_task_creation_methods()
Output:
Demonstrating why it's slow...
Serializing the complex processor object:
Serialization time: 0.017 seconds
Serialized size: 1.1 MB
============================================================
11:37:50.076 | INFO | Flow run 'archetypal-quoll' - Beginning flow run 'archetypal-quoll' for flow 'test-task-creation-methods'
11:37:50.085 | INFO | Flow run 'archetypal-quoll' - View at https://app.prefect.cloud/account/1e4d7e04-0fb7-4aa3-8ef5-746e9f404f4f/workspace/849a6829-4afb-48c3-9cc2-2dc6b262fd9c/runs/flow-run/068372dc-dd1c-7a75-8000-0f83470e14de
11:37:50.086 | INFO | prefect.task_runner.dask - Creating a new Dask cluster with `distributed.deploy.local.LocalCluster`
11:37:50.364 | INFO | distributed.http.proxy - To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy
11:37:50.380 | INFO | distributed.scheduler - State start
11:37:50.383 | INFO | distributed.scheduler - Scheduler at: tcp://127.0.0.1:51524
11:37:50.383 | INFO | distributed.scheduler - dashboard at: http://127.0.0.1:8787/status
...
11:37:51.698 | INFO | distributed.core - Starting established connection to tcp://127.0.0.1:51548
11:37:51.699 | INFO | prefect.task_runner.dask - The Dask dashboard is available at http://127.0.0.1:8787/status
Creating VERY complex processor...
============================================================
COMPARISON: Complex class method
============================================================
1. Using client.map directly...
Time: 3.206 seconds
2. Using @task decorator...
11:38:18.586 | INFO | Task run 'process_item_task-2ba' - Finished in state Completed()
11:38:32.258 | INFO | Task run 'process_item_task-ed5' - Finished in state Completed()
...
Time: 274.437 seconds
Slowdown: 85.6x
Gathering results...
Results match: True
11:42:45.356 | INFO | Flow run 'archetypal-quoll' - Finished in state Completed()
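To try to narrow down where the per-task time goes, here's the next check I'd run (a sketch, meant to be run in the same module as the repro above; whether Dask's tokenization of the task object is actually the hot path inside prefect-dask is an assumption on my part, not something I've profiled):

import time
import dask.base

probe = ComplexDataProcessor(depth=3, breadth=3)

# Roughly what plain client.map has to hash: the bound method, which drags `probe` along.
start = time.time()
dask.base.tokenize(probe.process_item)
print(f"tokenize(bound method): {time.time() - start:.3f} seconds")

# What the Prefect path has to hash per submission: the @task-wrapped attribute.
start = time.time()
dask.base.tokenize(probe.process_item_task)
print(f"tokenize(prefect task): {time.time() - start:.3f} seconds")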
Version info
Version: 3.4.2
API version: 0.8.4
Python version: 3.10.16
Git commit: c3c1c119
Built: Mon, May 19, 2025 04:04 PM
OS/Arch: darwin/arm64
Profile: default
Server type: cloud
Pydantic version: 2.9.2
Integrations:
  prefect-kubernetes: 0.6.1
  prefect-dask: 0.3.5
  prefect-gcp: 0.6.4

dask, version 2025.5.1
Additional context
No response
I tried registering a custom tokenizer function with Dask, and that seemed to speed things up by about 3x (not claiming this is necessarily a legit implementation, just a first attempt):
import dask.tokenize
import prefect.tasks


@dask.tokenize.normalize_token.register(prefect.tasks.Task)
def tokenize_task(task):
    """Custom tokenizer for Prefect tasks. Speeds up creation of high-volume mapped tasks."""
    return (
        dask.tokenize.normalize_token(prefect.tasks.Task),
        task.task_key,
        task.name,
        task.version or "",
        getattr(task.fn, "__code__", None) and task.fn.__code__.co_code.hex(),
        frozenset(task.tags or set()),
    )
Same flow:
Time: 84.910 seconds
Slowdown: 28.6x
(previously 274.437 seconds, 85.6x)
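For anyone wanting to try the same workaround: the only wiring detail, as far as I can tell, is that the registration has to run in the flow-submitting process before the DaskTaskRunner starts creating futures. A minimal sketch of how that could be arranged (module names here are made up for illustration):

# run_repro.py -- illustrative layout, hypothetical module names.
# Importing the module that holds the @normalize_token.register(...) decorator is
# enough: the registration happens as a side effect of the import, before the flow
# (and its DaskTaskRunner) submits any work.
import custom_tokenizer  # noqa: F401  (hypothetical module containing tokenize_task above)

from repro_flow import test_task_creation_methods  # hypothetical module containing the repro script above

if __name__ == "__main__":
    test_task_creation_methods()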
Thank you @bnaul! This is useful and we'll explore it as time allows.