dask-ml

GridSearch Error: KeyError: 'data'

Open · rileyhun opened this issue 5 years ago · 17 comments

I am getting the following error when running a grid search on a Dask distributed backend. The error does not occur when I run the plain scikit-learn grid search on a single core on my local machine. I don't know where the KeyError is coming from; nothing in my pipeline references the key 'data'.

Here is the full error traceback I am getting:

[Parallel(n_jobs=-1)]: Using backend DaskDistributedBackend with 38 concurrent workers.
distributed.client - ERROR - Error in callback <function DaskDistributedBackend.apply_async.<locals>.callback_wrapper at 0x11c4a8f28> of <Future: finished, type: builtins.list, key: _fit_and_score-batch-7c58d371c94649d0a8ed3a11682660d9>:
Traceback (most recent call last):
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 285, in execute_callback
    fn(fut)
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/joblib/_dask.py", line 260, in callback_wrapper
    result = future.result()
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 217, in result
    result = self.client.sync(self._result, callback_timeout=timeout, raiseit=False)
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 780, in sync
    self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/utils.py", line 348, in sync
    raise exc.with_traceback(tb)
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/utils.py", line 332, in f
    result[0] = yield future
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/tornado/gen.py", line 735, in run
    value = future.result()
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 242, in _result
    result = await self.client._gather([self])
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 1781, in _gather
    response = await future
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/client.py", line 1832, in _gather_remote
    response = await retry_operation(self.scheduler.gather, keys=keys)
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/utils_comm.py", line 391, in retry_operation
    operation=operation,
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/utils_comm.py", line 379, in retry
    return await coro()
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/core.py", line 757, in send_recv_from_rpc
    result = await send_recv(comm=comm, op=key, **kwargs)
  File "/Users/rihun/anaconda3/envs/dask_env/lib/python3.7/site-packages/distributed/core.py", line 556, in send_recv
    raise exc.with_traceback(tb)
  File "/opt/conda/lib/python3.7/site-packages/distributed/core.py", line 412, in handle_comm
  File "/opt/conda/lib/python3.7/site-packages/distributed/scheduler.py", line 2792, in gather
  File "/opt/conda/lib/python3.7/site-packages/distributed/utils_comm.py", line 87, in gather_from_workers
KeyError: 'data'

Sample Dataset

entity_name,classification
great tech,other
xfone communication ltd,other
coventrys,other
pt invensys indonesia,other
massillon cable tv inc,other
city of New York,government
police department,government
ministry of finance,government
US Navy,military
US Army,military
AFB,military
AFB military

Code Example

import re
import string

import joblib
import numpy as np
import pandas as pd
from dask.distributed import Client
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import FeatureUnion, Pipeline

# Collapse each run of punctuation and/or spaces into a single space.
PUNCT_RE = re.compile("[" + re.escape(string.punctuation) + " ]+")


def normalize_name(text):
    return PUNCT_RE.sub(" ", text.lower())


data = pd.read_csv('https://raw.githubusercontent.com/rileyhun/dask/master/training_data_03_23_u.csv')
X_train, X_test, y_train, y_test = train_test_split(
    data['entity_name'], data['classification'], test_size=0.3, random_state=123
)

vec_transformer = FeatureUnion([
    ('word_name', Pipeline([
        ('tfidf', TfidfVectorizer(
            sublinear_tf=False,
            smooth_idf=False,
            use_idf=True,
            min_df=2,
            preprocessor=normalize_name,
            analyzer='char_wb',   # character n-grams; token_pattern would be ignored
            ngram_range=(2, 10),
            dtype=np.float32,
        )),
    ])),
])

pipeline = Pipeline([
    ('vectorizer', vec_transformer),
    ('model', LogisticRegression()),
])

# Replace with the address of the Dask scheduler.
client = Client('<IP Address>:<Port>')

param_grid = {
    "model__C": [1, 3],
    "model__tol": [0.001, 0.01],
}

clf = GridSearchCV(
    pipeline,
    param_grid,
    verbose=8,
    cv=3,
    scoring='f1_weighted',
    refit=True,
)

# Send the joblib tasks that GridSearchCV creates to the Dask cluster.
with joblib.parallel_backend('dask'):
    clf.fit(X_train, y_train)

There are no conflicts between scheduler, client and the dask workers.

rileyhun · Apr 06 '20 23:04

@rileyhun looks like there are some missing imports. Can you fill those out?

And is this a minimal example? Do you need the timing stuff, print statements, etc.? See http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports

It also looks like X_train isn't defined.

TomAugspurger · Apr 07 '20 01:04

@TomAugspurger Added more details to the original post

rileyhun · Apr 07 '20 15:04

Thanks @rileyhun. It seems like data is undefined.

TomAugspurger · Apr 07 '20 15:04

@TomAugspurger, I made one more edit to the original comment; it now defines data.

rileyhun · Apr 07 '20 15:04

@rileyhun see http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports for writing bug reports. I don't have that CSV file. Since the issue isn't with reading a CSV, you could ideally create the dataset in memory.
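Following that suggestion, here is a sketch of the eleven sample rows posted in this thread built directly as a pandas DataFrame, so a reproduction does not depend on the CSV URL. It is far smaller than the full training file, so it may not be enough on its own to trigger the error.

```python
import pandas as pd

# The sample rows posted in this issue, as (entity_name, classification) pairs.
rows = [
    ("great tech", "other"),
    ("xfone communication ltd", "other"),
    ("coventrys", "other"),
    ("pt invensys indonesia", "other"),
    ("massillon cable tv inc", "other"),
    ("city of New York", "government"),
    ("police department", "government"),
    ("ministry of finance", "government"),
    ("US Navy", "military"),
    ("US Army", "military"),
    ("AFB", "military"),
]
data = pd.DataFrame(rows, columns=["entity_name", "classification"])
print(data["classification"].value_counts())
```

The resulting data frame has the same two columns the Code Example reads from the CSV, so it can replace the pd.read_csv line directly.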

TomAugspurger · Apr 07 '20 17:04

As mentioned in the original post, the grid search works without Dask as the backend. Now, when I run it again with Dask, I get this error:

ValueError: X has 205757 features per sample; expecting 206501
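The thread never pins down what causes this ValueError. As a sketch of what the message itself means: a fitted model was handed a matrix with a different number of columns than it was trained on, which happens, for example, when the features come from a vectorizer fitted on different documents. The toy data below is made up, and the exact wording of the message depends on the scikit-learn version.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

train_text = ["us navy", "us army", "police department", "ministry of finance"]
train_labels = ["military", "military", "government", "government"]
other_text = ["great tech", "coventrys"]

# Vectorizer and model fitted together on one set of documents.
vec_a = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))
model = LogisticRegression().fit(vec_a.fit_transform(train_text), train_labels)

# A second vectorizer fitted on other documents learns a different vocabulary,
# so its output has a different number of columns.
vec_b = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))
X_other = vec_b.fit_transform(other_text)

try:
    model.predict(X_other)
    mismatch_raised = False
except ValueError as err:
    mismatch_raised = True
    print(err)
```

Calling vec_a.transform(other_text) instead would keep the columns aligned with what the model was trained on.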

Here is a snippet of the dataset:

entity_name,classification
great tech,other
xfone communication ltd,other
coventrys,other
pt invensys indonesia,other
massillon cable tv inc,other
city of New York,government
police department,government
ministry of finance,government
US Navy,military
US Army,military
AFB,military

rileyhun · Apr 07 '20 19:04

Let me know if you can provide a reproducible example.

TomAugspurger · Apr 07 '20 19:04

Okay, I re-ran it a third time and got the same error:

ValueError: X has 207586 features per sample; expecting 205996

The search space I am using is just one parameter with two values:

param_grid = {
    "model__tol": [0.001, 0.01]
}
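For scale, scikit-learn's ParameterGrid can count the candidates in a grid. With cv=3 (taken from the GridSearchCV call in the Code Example), each candidate is fit once per split, and refit=True adds one more fit on the full training set.

```python
from sklearn.model_selection import ParameterGrid

cv = 3  # value passed to GridSearchCV in the Code Example

grids = {
    "original post": {"model__C": [1, 3], "model__tol": [0.001, 0.01]},
    "this comment": {"model__tol": [0.001, 0.01]},
}
for label, grid in grids.items():
    n_candidates = len(ParameterGrid(grid))
    # One fit per candidate per CV split; refit=True adds a final fit.
    print(f"{label}: {n_candidates} candidates, {n_candidates * cv} CV fits + 1 refit")
# original post: 4 candidates, 12 CV fits + 1 refit
# this comment: 2 candidates, 6 CV fits + 1 refit
```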

I am using Python 3.7.3 and Dask 2.14

Is the Dask backend for grid search always supposed to outperform the Loky backend? It's also noticeably slower, even though I'm using a cluster of 5 Dask workers with 12 CPUs each.

rileyhun · Apr 07 '20 20:04

I won't be able to help until you provide a minimal, reproducible example.

TomAugspurger · Apr 07 '20 20:04

The code under Code Example is copy-pasteable. You just need to change the cluster IP endpoint.

rileyhun · Apr 07 '20 21:04

@rileyhun Why ngram_range=(2, 10)? That's a ton of n-grams, which means a large memory and computation cost. I think ngram_range=(1, 4) is typical (or an upper bound smaller than 4). When I set ngram_range=(2, 4), the error disappears.
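A rough sketch of how strongly the upper bound of ngram_range drives the feature count: fit a char_wb TfidfVectorizer on the sample names from this thread with both ranges and compare vocabulary sizes. It uses the default min_df=1 rather than the post's min_df=2, so the numbers will not match the full dataset; the only point is that (2, 10) yields a strict superset of the (2, 4) features.

```python
from sklearn.feature_extraction.text import TfidfVectorizer

names = ["great tech", "xfone communication ltd", "coventrys", "pt invensys indonesia",
         "massillon cable tv inc", "city of New York", "police department",
         "ministry of finance", "US Navy", "US Army", "AFB"]


def n_features(ngram_range):
    # Vocabulary size = number of columns the fitted vectorizer will produce.
    vec = TfidfVectorizer(analyzer="char_wb", ngram_range=ngram_range)
    return len(vec.fit(names).vocabulary_)


wide, narrow = n_features((2, 10)), n_features((2, 4))
print(f"(2, 10): {wide} features, (2, 4): {narrow} features")
```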

It looks like the number of features is changing, which is alarming. I'm not sure why.

In a distributed context, a HashingVectorizer is often preferred over CountVectorizer/TfidfVectorizer because it's stateless.
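A sketch of what "stateless" means here: HashingVectorizer maps n-grams to a fixed number of columns (n_features, set to 2**18 below as an arbitrary choice; the default is 2**20), so it needs no fitting and its output width never depends on which documents a worker happens to see. The trade-off is that distinct n-grams can collide in the same column.

```python
from sklearn.feature_extraction.text import HashingVectorizer

# n_features is fixed when the object is created (2**18 is an arbitrary choice).
hasher = HashingVectorizer(analyzer="char_wb", ngram_range=(2, 10), n_features=2**18)

# No fit call: transform works directly, and two different "splits" of data
# come back with exactly the same number of columns.
fold_a = hasher.transform(["US Navy", "US Army"])
fold_b = hasher.transform(["great tech", "coventrys", "AFB"])
print(fold_a.shape, fold_b.shape)  # (2, 262144) (3, 262144)
```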

stsievert · Apr 08 '20 00:04

Keep in mind that I'm using character n-grams, not word n-grams. I've found that the (2, 10) range is good at picking up deviations in spelling. I could try a smaller range, though, and re-run to see whether it affects the accuracy.
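To illustrate the point about spelling deviations: compare how many features a misspelled name shares with the correct spelling under word tokens versus char_wb n-grams. The misspelling "polise departmnt" is invented for illustration; build_analyzer() returns the function a vectorizer applies to each string.

```python
from sklearn.feature_extraction.text import CountVectorizer


def ngram_set(text, **kwargs):
    # The analyzer turns one string into the list of features it contributes.
    return set(CountVectorizer(**kwargs).build_analyzer()(text))


correct, typo = "police department", "polise departmnt"

word_overlap = ngram_set(correct, analyzer="word") & ngram_set(typo, analyzer="word")
char_overlap = (ngram_set(correct, analyzer="char_wb", ngram_range=(2, 10))
                & ngram_set(typo, analyzer="char_wb", ngram_range=(2, 10)))
print(len(word_overlap), len(char_overlap))
```

With word tokens the two spellings share nothing, while the character n-grams still overlap heavily, which is what lets the model match near-miss spellings.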

I am not an expert, but during cross-validation, would the number of features change due to the different assortment of entity names in each split?

I'll also look into HashingVectorizer.

Thanks!

rileyhun · Apr 08 '20 00:04

> character n-grams, not word n-grams.

Whoops, I missed that. Never mind.

> during cross validation, would the number of features change due different assortment of entity names?

I would expect that, because different words will end up in different CV splits, but I don't see why that's an issue. The code runs fine when joblib.parallel_backend('dask') is commented out.
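A small sketch of both halves of that answer, using the sample names: vectorizers fitted on different training splits do learn different vocabularies, but inside a Pipeline each split refits the vectorizer and the model together, so the column counts stay consistent and cross-validation runs normally. The unshuffled KFold is an assumption made only so the splits are deterministic.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline

texts = ["great tech", "xfone communication ltd", "coventrys",
         "pt invensys indonesia", "city of new york", "police department",
         "ministry of finance", "us navy", "us army"]
labels = ["other", "other", "other", "other", "government", "government",
          "government", "military", "military"]
cv = KFold(n_splits=3)  # unshuffled, so the splits are deterministic

# Fit a standalone vectorizer on each training split: the vocabularies differ.
vocabularies = []
for train_idx, _ in cv.split(texts):
    vec = TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))
    vec.fit([texts[i] for i in train_idx])
    vocabularies.append(frozenset(vec.vocabulary_))
print([len(v) for v in vocabularies])

# Inside a Pipeline, each split refits vectorizer and model together,
# so the model always sees the columns its own vectorizer produced.
pipe = Pipeline([
    ("tfidf", TfidfVectorizer(analyzer="char_wb", ngram_range=(2, 4))),
    ("model", LogisticRegression()),
])
scores = cross_val_score(pipe, texts, labels, cv=cv)
print(scores)
```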

I think the next step will come down to finding a single representative example. http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports has some tips. I'd start by commenting various things out and seeing how far I can go.

stsievert · Apr 08 '20 01:04

I re-ran using a smaller n-gram range, and also using HashingVectorizer instead, and I haven't run into this error so far.
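The thread doesn't show the HashingVectorizer code that was used. A minimal sketch, assuming the aim is to keep TF-IDF weighting, pairs HashingVectorizer with a TfidfTransformer; n_features=2**18 and the tiny training set drawn from the sample names are illustrative assumptions.

```python
from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

hashing_pipeline = Pipeline([
    # Stateless: hashes char_wb n-grams into a fixed 2**18 columns.
    # alternate_sign=False keeps counts non-negative; norm=None leaves
    # normalization to the TF-IDF step.
    ("hash", HashingVectorizer(analyzer="char_wb", ngram_range=(2, 10),
                               n_features=2**18, alternate_sign=False, norm=None)),
    # Learns an idf vector, but its length is always n_features.
    ("tfidf", TfidfTransformer()),
    ("model", LogisticRegression()),
])

names = ["great tech", "coventrys", "city of New York",
         "police department", "US Navy", "US Army"]
labels = ["other", "other", "government", "government", "military", "military"]
hashing_pipeline.fit(names, labels)
print(hashing_pipeline.predict(["US Navy"]))
```

smooth_idf is left at its default (True) here, unlike the original smooth_idf=False: with hashing, most columns never occur in the training data, and without smoothing their idf would be computed from a zero document frequency.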

Thanks for these tips! Appreciate it!

rileyhun · Apr 08 '20 02:04

I ran into a similar bug with HyperbandSearchCV. It starts with client.compute(fit_params) and ends in the same error (KeyError: 'data'). Here's the traceback:

----------------------------------------------------------------
KeyError                       Traceback (most recent call last)
<ipython-input-15-d6b9e588cd8f> in async-def-wrapper()

~/anaconda3/envs/skorch/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
    981                 "Specify fits_per_score={} instead".format(self.scores_per_fit)
    982             )
--> 983         return super(IncrementalSearchCV, self).fit(X, y=y, **fit_params)
    984 
    985     def _get_params(self):

~/anaconda3/envs/skorch/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
    671 
    672         with context:
--> 673             return default_client().sync(self._fit, X, y, **fit_params)
    674 
    675     @if_delegate_has_method(delegate=("best_estimator_", "estimator"))

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    814         else:
    815             return sync(
--> 816                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    817             )
    818 

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    345     if error[0]:
    346         typ, exc, tb = error[0]
--> 347         raise exc.with_traceback(tb)
    348     else:
    349         return result[0]

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/utils.py in f()
    329             if callback_timeout is not None:
    330                 future = asyncio.wait_for(future, callback_timeout)
--> 331             result[0] = yield future
    332         except Exception as exc:
    333             error[0] = sys.exc_info()

~/anaconda3/envs/skorch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/anaconda3/envs/skorch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

~/anaconda3/envs/skorch/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
    623             random_state=self.random_state,
    624             verbose=self.verbose,
--> 625             prefix=self.prefix,
    626         )
    627         results = self._process_results(results)

~/anaconda3/envs/skorch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/anaconda3/envs/skorch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

~/anaconda3/envs/skorch/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
    156 
    157     # assume everything in fit_params is small and make it concrete
--> 158     fit_params = yield client.compute(fit_params)
    159 
    160     # Convert testing data into a single element on the cluster

~/anaconda3/envs/skorch/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/anaconda3/envs/skorch/lib/python3.7/asyncio/tasks.py in _wrap_awaitable(awaitable)
    628     that will later be wrapped in a Task by ensure_future().
    629     """
--> 630     return (yield from awaitable.__await__())
    631 
    632 

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/client.py in _result(self, raiseit)
    238                 return exception
    239         else:
--> 240             result = await self.client._gather([self])
    241             return result[0]
    242 

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1853                 else:
   1854                     self._gather_future = future
-> 1855                 response = await future
   1856 
   1857             if response["status"] == "error":

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/client.py in _gather_remote(self, direct, local_worker)
   1904 
   1905             else:  # ask scheduler to gather data for us
-> 1906                 response = await retry_operation(self.scheduler.gather, keys=keys)
   1907 
   1908         return response

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/utils_comm.py in retry_operation(coro, operation, *args, **kwargs)
    388         delay_min=retry_delay_min,
    389         delay_max=retry_delay_max,
--> 390         operation=operation,
    391     )

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/utils_comm.py in retry(coro, count, delay_min, delay_max, jitter_fraction, retry_on_exceptions, operation)
    368                 delay *= 1 + random.random() * jitter_fraction
    369             await asyncio.sleep(delay)
--> 370     return await coro()
    371 
    372 

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    748             name, comm.name = comm.name, "ConnectionPool." + key
    749             try:
--> 750                 result = await send_recv(comm=comm, op=key, **kwargs)
    751             finally:
    752                 self.pool.reuse(self.addr, comm)

~/anaconda3/envs/skorch/lib/python3.7/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    547         if comm.deserialize:
    548             typ, exc, tb = clean_exception(**response)
--> 549             raise exc.with_traceback(tb)
    550         else:
    551             raise Exception(response["text"])

/home/ubuntu/miniconda3/envs/skorch/lib/python3.7/site-packages/distributed/core.py in handle_comm()

/home/ubuntu/miniconda3/envs/skorch/lib/python3.7/site-packages/distributed/scheduler.py in gather()

/home/ubuntu/miniconda3/envs/skorch/lib/python3.7/site-packages/distributed/utils_comm.py in gather_from_workers()

KeyError: 'data'

I've done some debugging, and have resolved some issues (making sure valid parameters are passed, etc). I haven't seen this error since; I'll report again if I do.
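The comment doesn't say which parameters were invalid. One cheap pre-flight check, sketched here as a hypothetical helper (invalid_param_names is not part of scikit-learn or dask-ml), compares the grid's keys with the estimator's get_params() before anything is sent to the cluster. It only checks names, not values, and assumes a single dict rather than a list of dicts.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


def invalid_param_names(estimator, param_grid):
    """Hypothetical helper: grid keys that estimator.set_params would reject."""
    valid = set(estimator.get_params(deep=True))
    return sorted(name for name in param_grid if name not in valid)


search_pipeline = Pipeline([
    ("vectorizer", TfidfVectorizer(analyzer="char_wb")),
    ("model", LogisticRegression()),
])
grid = {"model__C": [1, 3], "model__tolerance": [0.001, 0.01]}  # second key is a typo
print(invalid_param_names(search_pipeline, grid))  # ['model__tolerance']
```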

stsievert · May 28 '20 19:05

I ran into this error as well. Have you made progress on getting around it, @stsievert or @rileyhun?

vinodshanbhag · Dec 20 '21 19:12

@vinodshanbhag as I mentioned in https://github.com/dask/dask-ml/issues/636#issuecomment-635544754, I got around it by cleaning up my workflow ("making sure valid parameters are passed, etc"). It'd be great if you could provide a minimal working example (http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports).

stsievert · Dec 20 '21 20:12