`MultiProcessingExecutor` pickle'ing error when using `extract_fields`

Open • sT0v opened this issue 9 months ago • 2 comments

Current behavior

`MultiProcessingExecutor` raises a pickling error when a node is decorated with `extract_fields`.

Stack Traces

  File "c:\codebase\api\venv\Lib\site-packages\hamilton\execution\executors.py", line 204, in get_state
    self.future.result()
  File "C:\Program Files\Python312\Lib\concurrent\futures\_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python312\Lib\concurrent\futures\_base.py", line 401, in __get_result
    raise self._exception
  File "C:\Program Files\Python312\Lib\multiprocessing\queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Program Files\Python312\Lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'extract_fields.transform_node.<locals>.dict_generator'

Steps to replicate behavior

import pandas as pd
from typing import Dict, List, Tuple

from hamilton.htypes import Parallelizable, Collect
from hamilton.function_modifiers import extract_fields


def _load_full_dataset() -> pd.DataFrame:
    url = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv"
    return pd.read_csv(url)


def kind(kinds: List[str]) -> Parallelizable[str]:
    for kind in kinds:
        yield kind


def raw_kind(kind: str) -> pd.DataFrame:
    return df[df.species == kind]


@extract_fields(dict(raw_df=pd.DataFrame, agg_df=pd.DataFrame))
def all_kinds(raw_kind: Collect[pd.DataFrame]) -> Dict[str, pd.DataFrame]:
    raw = pd.concat(raw_kind)
    return {
        "raw_df": raw,
        "agg_df": raw.groupby("species")[
            ["bill_length_mm", "bill_depth_mm", "flipper_length_mm", "body_mass_g"]
            ].mean().reset_index(),
    }

def all_data(raw_df: pd.DataFrame, agg_df: pd.DataFrame) -> Tuple[pd.DataFrame]:
    return raw_df, agg_df

if __name__ == "__main__":
    from hamilton import driver, settings
    from hamilton.execution import executors

    import __main__
    
    def debug_parallellism(mode: str) -> Tuple:
        if mode == "local":
            executor = executors.SynchronousLocalTaskExecutor()
        elif mode=="multithreading":
            executor = executors.MultiThreadingExecutor(max_tasks=4)
        elif mode=="multiprocessing":
            executor = executors.MultiProcessingExecutor(max_tasks=6)
        
        config = {settings.ENABLE_POWER_USER_MODE: True, "kinds": species}
        dr = (
            driver.Builder()
            .with_modules(__main__)
            .with_config(config)
            .enable_dynamic_execution(allow_experimental_mode=True)
            .with_remote_executor(executor)
            .build()
        )

        dag = dr.visualize_execution(["all_data"], f"parallel-penguins-{mode}-dag.png", bypass_validation=True)
        results = dr.execute(["all_data"])
        return dag, results
    
    
    df = _load_full_dataset()
    species = df.species.unique().tolist()

    modes = ["local", "multithreading", "multiprocessing"]
    for mode in modes:
        try:
            dag, results = debug_parallellism(mode)
        except Exception as e:
            print(f"Error in {mode=}: {e}")

Library & System Information

Microsoft Windows [Version 10.0.22621.4890], Python 3.12, sf-hamilton==1.87.0

Expected behavior

I was hoping this would work, since the SynchronousLocalTaskExecutor works as expected. I tried the MultiThreadingExecutor for my use case, but it takes longer than the sync executor, which is why I am trying the MultiProcessingExecutor.

sT0v, Feb 28 '25 14:02

Thanks @sT0v for the clean repro! An update: if I run all_kinds it works, but only if I remove extract_fields. Hmm.

Task failed
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'extract_fields.transform_node.<locals>.dict_generator'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/elijahbenizzy/.pyenv/versions/burr-3-12/lib/python3.12/site-packages/hamilton/execution/executors.py", line 193, in get_state
    self.future.result()
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/multiprocessing/queues.py", line 244, in _feed
    obj = _ForkingPickler.dumps(obj)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/elijahbenizzy/.pyenv/versions/3.12.0/lib/python3.12/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
AttributeError: Can't pickle local object 'extract_fields.transform_node.<locals>.dict_generator'
-------------------------------------------------------------------
Oh no an error! Need help with Hamilton?
Join our slack and ask for help! https://join.slack.com/t/hamilton-opensource/shared_invite/zt-2niepkra8-DGKGf_tTYhXuJWBTXtIs4g
-------------------------------------------------------------------

Error in mode='multiprocessing': Can't pickle local object 'extract_fields.transform_node.<locals>.dict_generator'
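
For context on the error message itself, here is a minimal sketch in plain Python (no Hamilton involved) of why a function defined inside another function cannot be pickled, which is what a process-based executor needs in order to move objects between processes. The names `make_extractor` and `extractor` are hypothetical stand-ins chosen to mirror the shape of `extract_fields.transform_node.<locals>.dict_generator` in the traceback; this is not Hamilton's actual implementation. `operator.itemgetter` is shown only as one example of a picklable callable that does the same lookup.

```python
import operator
import pickle


def make_extractor(field):
    # Hypothetical stand-in: a function defined inside another function,
    # similar in shape to the local `dict_generator` named in the traceback.
    def extractor(d):
        return d[field]

    return extractor


def try_pickle(obj):
    """Return "ok" if obj can be pickled, else a short failure message."""
    try:
        pickle.dumps(obj)
        return "ok"
    except (AttributeError, pickle.PicklingError) as exc:
        # Python 3.12 raises AttributeError here; newer versions raise PicklingError.
        return f"failed: {exc}"


# Functions are pickled by reference (module + qualified name), so a function
# that only exists inside another function's scope cannot be looked up again.
local_result = try_pickle(make_extractor("raw_df"))

# A callable from the standard library that carries its key as state pickles fine.
getter = operator.itemgetter("raw_df")
getter_result = try_pickle(getter)
roundtrip_value = pickle.loads(pickle.dumps(getter))({"raw_df": 1})

print(local_result)
print(getter_result)
```

Threads share one interpreter and never pickle anything, which would be consistent with the local and multithreading modes not hitting this error.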

elijahbenizzy, Feb 28 '25 17:02

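This diff gets it to work, so the problem is extract_fields (as suggested initially). The `<` lines are the working version: extract_fields is commented out and replaced by two plain functions, raw_df and agg_df, and df is loaded at module level instead of under the `__main__` guard.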

< df = _load_full_dataset()
< species = df.species.unique().tolist()
<
<
26c22
< # @extract_fields(dict(raw_df=pd.DataFrame, agg_df=pd.DataFrame))
---
> @extract_fields(dict(raw_df=pd.DataFrame, agg_df=pd.DataFrame))
36,44d31
<
< def raw_df(all_kinds: dict) -> pd.DataFrame:
<     return all_kinds["raw_df"]
<
<
< def agg_df(all_kinds: dict) -> pd.DataFrame:
<     return all_kinds["agg_df"]
<
<
48d34
<
55d40
<
78a64,66
>     df = _load_full_dataset()
>     species = df.species.unique().tolist()
>

elijahbenizzy, Feb 28 '25 17:02