
Dataclasses cannot be passed to a task if run with DaskTaskRunner

Open binste opened this issue 3 years ago • 0 comments

First check

  • [X] I added a descriptive title to this issue.
  • [X] I used the GitHub search to find a similar issue and didn't find it.
  • [X] I searched the Prefect documentation for this issue.
  • [X] I checked that this issue is related to Prefect and not one of its dependencies.

Bug summary

If a Python dataclass instance is passed to a task that is submitted to a DaskTaskRunner, the task fails with a TypeError stating that a required argument was not provided to the __init__ method of the class. Perhaps this is caused by a serialization step in between? The code below works when you pass dictionaries or instances of normal Python classes. To test this, you can use the other get_numbers* functions.

Reproduction

from dataclasses import dataclass

from prefect import flow, task
from prefect_dask.task_runners import DaskTaskRunner


class Number:
    def __init__(self, value):
        self.value = value


@dataclass
class NumberDataClass:
    value: int


@task
def get_numbers():
    return [Number(value=1), Number(value=2)]


@task
def get_numbers_dict():
    return [{"value": 1}, {"value": 2}]


@task
def get_numbers_dataclass():
    return [NumberDataClass(value=1), NumberDataClass(value=2)]


@task
def print_number(x):
    print(x)


@flow(
    name="test_dask_flow",
    task_runner=DaskTaskRunner(
        cluster_kwargs={"n_workers": 2, "threads_per_worker": 1}
    ),
)
def test_dask_flow(new_parameter: int = 4):
    print("Dask flow")

    # These two work:
    # numbers = get_numbers()
    # numbers = get_numbers_dict()

    # This one does not
    numbers = get_numbers_dataclass()
    print_numbers_task = print_number.map(numbers)


if __name__ == "__main__":
    test_dask_flow()
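
Since the dictionary variant works, one possible workaround is to convert the dataclass instances to plain dicts before they cross the task boundary and rebuild them inside the task. The following is only a minimal sketch of that conversion using the standard library. The helper names to_payload and from_payload are hypothetical, and the sketch has not been verified against DaskTaskRunner itself.

```python
from dataclasses import asdict, dataclass


@dataclass
class NumberDataClass:
    value: int


def to_payload(numbers):
    """Convert dataclass instances to plain dicts (hypothetical helper)."""
    return [asdict(n) for n in numbers]


def from_payload(payload):
    """Rebuild a dataclass instance from a plain dict (hypothetical helper)."""
    return NumberDataClass(**payload)


# Producer side: return dicts instead of dataclass instances.
payloads = to_payload([NumberDataClass(value=1), NumberDataClass(value=2)])

# Consumer side: rebuild the dataclass from each dict.
restored = [from_payload(p) for p in payloads]
print(payloads)
print(restored)
```

In the reproduction above, get_numbers_dataclass would return the result of to_payload, and print_number would call from_payload on its argument.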

Error

08:17:54.715 | INFO    | prefect.engine - Created flow run 'cordial-chimpanzee' for flow 'test_dask_flow'
08:17:54.715 | INFO    | prefect.task_runner.dask - Creating a new Dask cluster with `distributed.deploy.local.LocalCluster`
08:17:58.136 | INFO    | prefect.task_runner.dask - The Dask dashboard is available at http://127.0.0.1:8787/status
Dask flow
08:17:58.383 | INFO    | Flow run 'cordial-chimpanzee' - Created task run 'get_numbers_dataclass-fff3d643-0' for task 'get_numbers_dataclass'
08:17:58.383 | INFO    | Flow run 'cordial-chimpanzee' - Executing 'get_numbers_dataclass-fff3d643-0' immediately...
08:17:58.569 | INFO    | Task run 'get_numbers_dataclass-fff3d643-0' - Finished in state Completed()
08:17:58.630 | INFO    | Flow run 'cordial-chimpanzee' - Created task run 'print_number-6168cb60-0' for task 'print_number'
08:17:59.206 | INFO    | Flow run 'cordial-chimpanzee' - Submitted task run 'print_number-6168cb60-0' for execution.
08:17:59.216 | INFO    | Flow run 'cordial-chimpanzee' - Created task run 'print_number-6168cb60-1' for task 'print_number'
08:17:59.224 | INFO    | Flow run 'cordial-chimpanzee' - Submitted task run 'print_number-6168cb60-1' for execution.
2022-09-21 08:17:59,925 - distributed.worker - WARNING - Compute Failed
Key:       85e1b94d-ca4d-424e-a011-66b21f7dbd6f
Function:  begin_task_run
args:      ()
kwargs:    {'task': <prefect.tasks.Task object at 0x10c56f280>, 'task_run': TaskRun(id=UUID('bfb750b7-f690-4fda-8ca5-3f1282cccc6c'), name='print_number-6168cb60-0', flow_run_id=UUID('bb8e0a27-76b6-4000-a010-344c241cc6e2'), task_key='__main__.print_number', dynamic_key='0', cache_key=None, cache_expiration=None, task_version=None, empirical_policy=TaskRunPolicy(max_retries=0, retry_delay_seconds=0.0, retries=0, retry_delay=0), tags=[], state_id=UUID('28a0e521-240d-422c-a1c5-afb5f7859fc5'), task_inputs={'x': [TaskRunResult(input_type='task_run', id=UUID('7c4f5e58-04f0-406d-95a6-4969a622c0c4'))]}, state_type=StateType.PENDING, state_name='Pending', run_count=0, expected_start_time=DateTime(2022, 9, 21, 6, 17, 58, 578649, tzinfo=Timezone('+00:00')), next_scheduled_start_time=None, start_time=None, end_time=None, total_run_time=datetime.timedelta(0), estimated_run_time=datetime.timedelta(0), estimated_start_time_delta=datetime.timedelta(microseconds=45090), state=Pending(message=None, type=PENDING, re
Exception: 'TypeError("NumberDataClass.__init__() missing 1 required positional argument: \'value\'")'

2022-09-21 08:17:59,929 - distributed.worker - WARNING - Compute Failed
Key:       a22a66cf-10f1-4c03-a046-01f2d78c5c13
Function:  begin_task_run
args:      ()
kwargs:    {'task': <prefect.tasks.Task object at 0x10d26aa10>, 'task_run': TaskRun(id=UUID('e54028b0-3695-4cb4-a522-07c60dfb5a0e'), name='print_number-6168cb60-1', flow_run_id=UUID('bb8e0a27-76b6-4000-a010-344c241cc6e2'), task_key='__main__.print_number', dynamic_key='1', cache_key=None, cache_expiration=None, task_version=None, empirical_policy=TaskRunPolicy(max_retries=0, retry_delay_seconds=0.0, retries=0, retry_delay=0), tags=[], state_id=UUID('f52b11bf-5e00-4ee1-953f-5e423e5b5a2d'), task_inputs={'x': [TaskRunResult(input_type='task_run', id=UUID('7c4f5e58-04f0-406d-95a6-4969a622c0c4'))]}, state_type=StateType.PENDING, state_name='Pending', run_count=0, expected_start_time=DateTime(2022, 9, 21, 6, 17, 58, 595735, tzinfo=Timezone('+00:00')), next_scheduled_start_time=None, start_time=None, end_time=None, total_run_time=datetime.timedelta(0), estimated_run_time=datetime.timedelta(0), estimated_start_time_delta=datetime.timedelta(microseconds=45369), state=Pending(message=None, type=PENDING, re
Exception: 'TypeError("NumberDataClass.__init__() missing 1 required positional argument: \'value\'")'

08:17:59.939 | INFO    | Task run 'print_number-6168cb60-0' - Crash detected! Execution was interrupted by an unexpected exception.
08:17:59.979 | INFO    | Task run 'print_number-6168cb60-1' - Crash detected! Execution was interrupted by an unexpected exception.
08:18:00.546 | ERROR   | Flow run 'cordial-chimpanzee' - Finished in state Failed('2/3 states failed.')
Traceback (most recent call last):
  File "/Users/sbinder/Dropbox/Programming/playground/test_prefect/test_dask.py", line 55, in <module>
    test_dask_flow()
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/flows.py", line 384, in __call__
    return enter_flow_run_engine_from_flow_call(
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/engine.py", line 158, in enter_flow_run_engine_from_flow_call
    return anyio.run(begin_run)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/_core/_eventloop.py", line 70, in run
    return asynclib.run(func, *args, **backend_options)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 292, in run
    return native_run(wrapper(), debug=debug)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/asyncio/base_events.py", line 646, in run_until_complete
    return future.result()
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 287, in wrapper
    return await func(*args)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/client.py", line 103, in with_injected_client
    return await fn(*args, **kwargs)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/engine.py", line 238, in create_then_begin_flow_run
    return state.result()
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/orion/schemas/states.py", line 157, in result
    state.result()
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/orion/schemas/states.py", line 143, in result
    raise data
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect_dask/task_runners.py", line 246, in wait
    return await future.result(timeout=timeout)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/distributed/client.py", line 289, in _result
    raise exc.with_traceback(tb)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/engine.py", line 1103, in begin_task_run
    return await orchestrate_task_run(
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/engine.py", line 1168, in orchestrate_task_run
    resolved_parameters = await resolve_inputs(parameters)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/engine.py", line 1402, in resolve_inputs
    return await run_sync_in_worker_thread(
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/utilities/asyncutils.py", line 57, in run_sync_in_worker_thread
    return await anyio.to_thread.run_sync(call, cancellable=True)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 867, in run
    result = context.run(func, *args)
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/utilities/collections.py", line 307, in visit_collection
    items = [(visit_nested(k), visit_nested(v)) for k, v in expr.items()]
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/utilities/collections.py", line 307, in <listcomp>
    items = [(visit_nested(k), visit_nested(v)) for k, v in expr.items()]
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/utilities/collections.py", line 273, in visit_nested
    return visit_collection(
  File "/Users/sbinder/miniforge3/envs/py_prefect/lib/python3.10/site-packages/prefect/utilities/collections.py", line 313, in visit_collection
    result = typ(**items) if return_data else None
TypeError: NumberDataClass.__init__() missing 1 required positional argument: 'value'
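
The last frame of the traceback rebuilds the object with typ(**items). The same TypeError appears whenever that call receives no entry for value. The sketch below reproduces only this failure mode in isolation. It assumes, without confirmation from this report, that the collected items were empty on the worker. The rebuild helper is hypothetical and is not Prefect's implementation.

```python
from dataclasses import dataclass


@dataclass
class NumberDataClass:
    value: int


def rebuild(obj, field_names):
    """Rebuild obj from the named attributes, mimicking a typ(**items) call."""
    items = {name: getattr(obj, name) for name in field_names}
    return type(obj)(**items)


original = NumberDataClass(value=1)

# With the field name known, the round trip succeeds.
assert rebuild(original, ["value"]) == original

# Assumption: field discovery returned nothing, so the constructor is
# called without arguments and raises the TypeError seen in the logs.
try:
    rebuild(original, [])
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

If this assumption holds, the question becomes why the dataclass fields are not found after the object has been sent to the Dask worker.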

Versions

Version:             2.4.0
API version:         0.8.0
Python version:      3.10.6
Git commit:          513639e8
Built:               Tue, Sep 13, 2022 2:15 PM
OS/Arch:             darwin/x86_64
Profile:             default
Server type:         hosted

prefect-dask version is 0.2.0.post1

Additional context

The Dask logs are also not properly propagated to the UI, so this error does not show up in the Prefect UI. This is probably already being tackled in #5850.

binste, Sep 21 '22 06:09