
Slow dask task creation for complex functions

Open bnaul opened this issue 7 months ago • 2 comments

Bug summary

I've been seeing very slow task creation (~1/s) when .map()-ing a method on a not particularly crazy class. Using plain old Dask client.map() on the same function creates the tasks several hundred times faster. For a reproducible example, I created a more monstrous method that is somewhat slow even for Dask's .map but incredibly slow for Prefect's. Obviously this is a fairly pathological case, but hopefully it's useful for diagnosing where the gap between pure Dask and Prefect+Dask performance comes from.

import time
from dataclasses import dataclass
from typing import Any, Dict, List

from prefect import flow, task
from prefect.context import FlowRunContext
from prefect_dask import DaskTaskRunner


@dataclass
class Node:
    """A node in our complex object tree"""

    id: str
    data: Dict[str, Any]
    children: List['Node']
    parent_ref: Any
    closures: List[Any]


class ComplexDataProcessor:
    def __init__(self, depth=5, breadth=5):
        # Create deeply nested tree with circular references
        self.root = self._build_tree(depth, breadth)
        self.root.parent_ref = self

        # Add lots of closures that capture self
        self.closures = []
        for i in range(100):
            captured_data = f"data_{i}" * 100

            def make_closure(n, obj=self):
                def inner(x):
                    return x + n + len(obj.root.id) + len(captured_data)
                return inner
            self.closures.append(make_closure(i))

        # Add lambdas with captured state
        self.lambdas = {
            f"lambda_{i}": lambda x, i=i, obj=self: x + i + len(obj.root.data)
            for i in range(100)
        }

        # Add recursive/circular references
        self.recursive_dict = {'self': self, 'root': self.root}
        for i in range(50):
            self.recursive_dict[f'level_{i}'] = {
                'parent': self.recursive_dict,
                'processor': self,
                'data': {'nested': self.recursive_dict},
                'more_refs': [self.recursive_dict for _ in range(10)]
            }

        # Add regular classes with complex initialization
        self.complex_objects = []
        for i in range(50):
            obj = type('ComplexObj', (), {
                'value': i,
                'get_processor': lambda self, p=self: p,
                'compute': lambda self, x, p=self: x + len(p.root.id)
            })()
            obj.processor_ref = self
            obj.recursive_ref = self.recursive_dict
            self.complex_objects.append(obj)

        # Add custom objects with methods
        class CustomObj:
            def __init__(self, processor, index):
                self.processor = processor
                self.index = index
                self.data = {'proc': processor, 'idx': index}

            def compute(self, x):
                return x + self.index + len(self.processor.root.id)

        self.custom_objects = [CustomObj(self, i) for i in range(50)]

        # Add nested functions
        def outer_func(x):
            def middle_func(y):
                def inner_func(z):
                    return x + y + z + len(self.root.id)
                return inner_func
            return middle_func

        self.nested_funcs = [outer_func(i) for i in range(50)]

    def _build_tree(self, depth, breadth, parent=None):
        """Build a tree with lots of complex references"""
        node_id = f"node_depth{depth}_{'x' * 50}"

        # Create closures specific to this node
        node_closures = []
        for i in range(10):
            def make_node_closure(n, node_id=node_id):
                def closure(x):
                    return x + n + len(node_id)
                return closure
            node_closures.append(make_node_closure(i))

        node = Node(
            id=node_id,
            data={
                'level': depth,
                'static_data': {f'key_{i}': i * 0.1 for i in range(100)},
                'text': 'x' * 5000,
                'nested_dict': {str(i): {str(j): i*j for j in range(50)} for i in range(50)}
            },
            children=[],
            parent_ref=parent,
            closures=node_closures
        )

        if depth > 0:
            node.children = [
                self._build_tree(depth - 1, breadth, parent=node)
                for _ in range(breadth)
            ]

        return node

    def process_item(self, item):
        """Method without @task decorator for direct client.map comparison"""
        return item * 2

    @task
    def process_item_task(self, item):
        """Same method but with @task decorator"""
        return item * 2


@flow(task_runner=DaskTaskRunner())
def test_task_creation_methods():
    # Get the Dask client from the task runner
    context = FlowRunContext.get()
    client = context.task_runner.client

    # Create a complex processor instance
    print("Creating VERY complex processor...")
    processor = ComplexDataProcessor(depth=4, breadth=4)

    # Generate items to process
    items = list(range(10))  # Fewer items because it's so slow

    print("\n" + "="*60)
    print("COMPARISON: Complex class method")
    print("="*60)

    # Method 1: Direct client.map (FAST)
    print("\n1. Using client.map directly...")
    start_time = time.time()

    futures = client.map(processor.process_item, items)

    client_map_time = time.time() - start_time
    print(f"   Time: {client_map_time:.3f} seconds")

    # Method 2: Using @task decorator (VERY SLOW)
    print("\n2. Using @task decorator...")
    start_time = time.time()

    task_results = processor.process_item_task.map(items)

    task_map_time = time.time() - start_time
    print(f"   Time: {task_map_time:.3f} seconds")
    print(f"   Slowdown: {task_map_time / client_map_time:.1f}x")

    # Wait for results
    print("\nGathering results...")
    results1 = client.gather(futures)
    results2 = task_results.result()

    print(f"Results match: {results1 == results2}")


# Show the serialization overhead directly
def demonstrate_serialization_issue():
    import cloudpickle

    print("Demonstrating why it's slow...\n")

    processor = ComplexDataProcessor(depth=3, breadth=3)

    print("Serializing the complex processor object:")
    start = time.time()
    serialized = cloudpickle.dumps(processor)
    print(f"  Serialization time: {time.time() - start:.3f} seconds")
    print(f"  Serialized size: {len(serialized) / (1024*1024):.1f} MB")


if __name__ == "__main__":
    # First show why it's slow
    demonstrate_serialization_issue()

    print("\n" + "="*60 + "\n")

    # Run the comparison
    test_task_creation_methods()

Output:

Demonstrating why it's slow...

Serializing the complex processor object:
  Serialization time: 0.017 seconds
  Serialized size: 1.1 MB

============================================================

11:37:50.076 | INFO    | Flow run 'archetypal-quoll' - Beginning flow run 'archetypal-quoll' for flow 'test-task-creation-methods'
11:37:50.085 | INFO    | Flow run 'archetypal-quoll' - View at https://app.prefect.cloud/account/1e4d7e04-0fb7-4aa3-8ef5-746e9f404f4f/workspace/849a6829-4afb-48c3-9cc2-2dc6b262fd9c/runs/flow-run/068372dc-dd1c-7a75-8000-0f83470e14de
11:37:50.086 | INFO    | prefect.task_runner.dask - Creating a new Dask cluster with `distributed.deploy.local.LocalCluster`
11:37:50.364 | INFO    | distributed.http.proxy - To route to workers diagnostics web server please install jupyter-server-proxy: python -m pip install jupyter-server-proxy
11:37:50.380 | INFO    | distributed.scheduler - State start
11:37:50.383 | INFO    | distributed.scheduler -   Scheduler at:     tcp://127.0.0.1:51524
11:37:50.383 | INFO    | distributed.scheduler -   dashboard at:  http://127.0.0.1:8787/status
...
11:37:51.698 | INFO    | distributed.core - Starting established connection to tcp://127.0.0.1:51548
11:37:51.699 | INFO    | prefect.task_runner.dask - The Dask dashboard is available at http://127.0.0.1:8787/status
Creating VERY complex processor...

============================================================
COMPARISON: Complex class method
============================================================

1. Using client.map directly...
   Time: 3.206 seconds

2. Using @task decorator...
11:38:18.586 | INFO    | Task run 'process_item_task-2ba' - Finished in state Completed()
11:38:32.258 | INFO    | Task run 'process_item_task-ed5' - Finished in state Completed()
...
   Time: 274.437 seconds
   Slowdown: 85.6x

Gathering results...
Results match: True

11:42:45.356 | INFO    | Flow run 'archetypal-quoll' - Finished in state Completed()
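
For scale, the timings above can be turned into a per-item cost. This is only arithmetic on the reported figures, assuming the 10 items from `items = list(range(10))` and that each timing covers just the time spent inside the respective `.map` call:

```python
# Figures copied from the output above; n_items matches `items = list(range(10))`.
n_items = 10
client_map_s = 3.206    # "1. Using client.map directly..."
task_map_s = 274.437    # "2. Using @task decorator..."

per_item_client = client_map_s / n_items
per_item_task = task_map_s / n_items
slowdown = task_map_s / client_map_s

print(f"client.map: {per_item_client:.2f} s per item")
print(f"@task .map: {per_item_task:.2f} s per item")
print(f"slowdown:   {slowdown:.1f}x")
```

So in this pathological repro each mapped item costs roughly 27 seconds with the Prefect task versus about a third of a second with plain client.map.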

Version info

Version:             3.4.2
API version:         0.8.4
Python version:      3.10.16
Git commit:          c3c1c119
Built:               Mon, May 19, 2025 04:04 PM
OS/Arch:             darwin/arm64
Profile:             default
Server type:         cloud
Pydantic version:    2.9.2
Integrations:
  prefect-kubernetes: 0.6.1
  prefect-dask:      0.3.5
  prefect-gcp:       0.6.4



dask, version 2025.5.1

Additional context

No response

bnaul, May 28 '25 15:05

I tried registering a custom tokenizer function with Dask, and that seemed to speed things up by about 3x (not claiming this is necessarily a legit implementation, just a first attempt):

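    import dask.tokenize
    import prefect.tasks

    @dask.tokenize.normalize_token.register(prefect.tasks.Task)
    def tokenize_task(task):
        """Custom tokenizer for Prefect tasks. Speeds up creation of high volume mapped tasks."""
        # Build the token from a few small identifying fields instead of the whole Task object.
        return (
            dask.tokenize.normalize_token(prefect.tasks.Task),
            task.task_key,
            task.name,
            task.version or "",
            # Bytecode of the wrapped function, if it has any (None otherwise).
            getattr(task.fn, "__code__", None) and task.fn.__code__.co_code.hex(),
            frozenset(task.tags or set()),
        )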

Same flow:

   Time: 84.910 seconds
   Slowdown: 28.6x

(previously 274.437 seconds, 85.6x)
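
A stdlib-only sketch of why a cheap token might help. It is not Dask's or Prefect's actual tokenization code, and all names in it (`make_payload`, `content_token`, `identity_token`, `time_tokens`) are hypothetical. It assumes the expensive part is walking a large captured state once per mapped item, and compares that with hashing a small tuple of identifying fields:

```python
import hashlib
import pickle
import time


def make_payload(n_keys=200, text_len=5000):
    """Large plain-data stand-in for the state a task drags along (hypothetical)."""
    return {
        f"key_{i}": {"text": "x" * text_len, "values": list(range(50))}
        for i in range(n_keys)
    }


def content_token(payload):
    """Token from the full serialized payload: cost grows with payload size."""
    return hashlib.sha1(pickle.dumps(payload)).hexdigest()


def identity_token(name, version, tags):
    """Token from a few small fields: cost is roughly constant."""
    key = (name, version, tuple(sorted(tags)))
    return hashlib.sha1(repr(key).encode()).hexdigest()


def time_tokens(n_items=10):
    """Compute both kinds of token n_items times and time each loop."""
    payload = make_payload()

    start = time.perf_counter()
    full = [content_token(payload) for _ in range(n_items)]
    full_time = time.perf_counter() - start

    start = time.perf_counter()
    small = [identity_token("process_item_task", "", {"demo"}) for _ in range(n_items)]
    small_time = time.perf_counter() - start

    return full, small, full_time, small_time


if __name__ == "__main__":
    full, small, full_time, small_time = time_tokens()
    print(f"content tokens:  {full_time:.4f} s")
    print(f"identity tokens: {small_time:.4f} s")
```

The trade-off is that an identity-style token no longer changes when the captured state changes, so it is only safe if those few fields really identify the task.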

bnaul, May 28 '25 15:05

thank you @bnaul ! this is useful and we'll explore this as time allows

zzstoatzz, May 28 '25 16:05