dask-sql
[BUG] Predictions failing to compute on results of `LIMIT` & `OFFSET`
What happened:
When attempting to make a prediction on the result of a query with LIMIT / OFFSET, I get a failure during computation:
ValueError Traceback (most recent call last)
Input In [5], in <cell line: 1>()
----> 1 c.sql("""
2 SELECT * FROM PREDICT (
3 MODEL my_model,
4 SELECT x, y FROM timeseries LIMIT 100 OFFSET 100
5 )
6 """).compute()
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/base.py:315, in DaskMethodsMixin.compute(self, **kwargs)
291 def compute(self, **kwargs):
292 """Compute this dask collection
293
294 This turns a lazy Dask collection into its in-memory equivalent.
(...)
313 dask.base.compute
314 """
--> 315 (result,) = compute(self, traverse=False, **kwargs)
316 return result
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/base.py:598, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
595 keys.append(x.__dask_keys__())
596 postcomputes.append(x.__dask_postcompute__())
--> 598 results = schedule(dsk, keys, **kwargs)
599 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/threaded.py:89, in get(dsk, keys, cache, num_workers, pool, **kwargs)
86 elif isinstance(pool, multiprocessing.pool.Pool):
87 pool = MultiprocessingPoolExecutor(pool)
---> 89 results = get_async(
90 pool.submit,
91 pool._max_workers,
92 dsk,
93 keys,
94 cache=cache,
95 get_id=_thread_get_id,
96 pack_exception=pack_exception,
97 **kwargs,
98 )
100 # Cleanup pools associated to dead threads
101 with pools_lock:
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/local.py:511, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
509 _execute_task(task, data) # Re-execute locally
510 else:
--> 511 raise_exception(exc, tb)
512 res, worker_id = loads(res_info)
513 state["cache"][key] = res
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/local.py:319, in reraise(exc, tb)
317 if exc.__traceback__ is not tb:
318 raise exc.with_traceback(tb)
--> 319 raise exc
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/local.py:224, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
222 try:
223 task, data = loads(task_info)
--> 224 result = _execute_task(task, data)
225 id = get_id()
226 result = dumps((result, id))
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/optimization.py:990, in SubgraphCallable.__call__(self, *args)
988 if not len(args) == len(self.inkeys):
989 raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 990 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/core.py:149, in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/utils.py:41, in apply(func, args, kwargs)
39 def apply(func, args, kwargs=None):
40 if kwargs:
---> 41 return func(*args, **kwargs)
42 else:
43 return func(*args)
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask/dataframe/core.py:6626, in apply_and_enforce(*args, **kwargs)
6624 func = kwargs.pop("_func")
6625 meta = kwargs.pop("_meta")
-> 6626 df = func(*args, **kwargs)
6627 if is_dataframe_like(df) or is_series_like(df) or is_index_like(df):
6628 if not len(df):
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/dask_ml/wrappers.py:630, in _predict(part, estimator)
629 def _predict(part, estimator):
--> 630 return estimator.predict(part)
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/sklearn/ensemble/_gb.py:1452, in GradientBoostingClassifier.predict(self, X)
1437 def predict(self, X):
1438 """Predict class for X.
1439
1440 Parameters
(...)
1450 The predicted values.
1451 """
-> 1452 raw_predictions = self.decision_function(X)
1453 encoded_labels = self._loss._raw_prediction_to_decision(raw_predictions)
1454 return self.classes_.take(encoded_labels, axis=0)
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/sklearn/ensemble/_gb.py:1405, in GradientBoostingClassifier.decision_function(self, X)
1386 def decision_function(self, X):
1387 """Compute the decision function of ``X``.
1388
1389 Parameters
(...)
1403 array of shape (n_samples,).
1404 """
-> 1405 X = self._validate_data(
1406 X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
1407 )
1408 raw_predictions = self._raw_predict(X)
1409 if raw_predictions.shape[1] == 1:
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/sklearn/base.py:577, in BaseEstimator._validate_data(self, X, y, reset, validate_separately, **check_params)
575 raise ValueError("Validation should be done on X, y or both.")
576 elif not no_val_X and no_val_y:
--> 577 X = check_array(X, input_name="X", **check_params)
578 out = X
579 elif no_val_X and not no_val_y:
File /raid/charlesb/mambaforge/envs/dask-sql/lib/python3.9/site-packages/sklearn/utils/validation.py:909, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
907 n_samples = _num_samples(array)
908 if n_samples < ensure_min_samples:
--> 909 raise ValueError(
910 "Found array with %d sample(s) (shape=%s) while a"
911 " minimum of %d is required%s."
912 % (n_samples, array.shape, ensure_min_samples, context)
913 )
915 if ensure_min_features > 0 and array.ndim == 2:
916 n_features = array.shape[1]
ValueError: Found array with 0 sample(s) (shape=(0, 2)) while a minimum of 1 is required by GradientBoostingClassifier.
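For reference, the ValueError at the bottom is sklearn's standard ensure_min_samples check, which fires whenever predict is handed a zero-row input. A minimal standalone sketch (no dask-sql involved; the data here is made up for illustration) reproducing just that last frame:

import numpy as np
import pandas as pd
from sklearn.ensemble import GradientBoostingClassifier

df = pd.DataFrame({"x": np.random.randn(50), "y": np.random.randn(50)})
clf = GradientBoostingClassifier().fit(df, df.x * df.y > 0)
# Predicting on an empty slice raises the same error as above:
# ValueError: Found array with 0 sample(s) (shape=(0, 2)) while a minimum of 1 is required ...
clf.predict(df.iloc[:0])

So the question is why the predict plugin ends up calling the estimator on a zero-row chunk of the limited table.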
What you expected to happen:
I would expect this prediction to succeed, as it succeeds without the LIMIT / OFFSET; and glancing at the input table in the predict plugin, the limit appears to be applied as expected.
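For what it's worth, the limited table looks right on its own (using the c and timeseries registered in the MCVE below); the one thing worth checking, as a diagnostic sketch under the assumption that the slicing can leave zero-row partitions behind, is the per-partition row counts:

limited = c.sql("SELECT x, y FROM timeseries LIMIT 100 OFFSET 100")
print(len(limited))  # 100 rows, so the LIMIT / OFFSET itself behaves
# Assumption to verify: if the slicing leaves zero-row partitions behind,
# they show up here as 0s, and those are what predict would choke on.
print(limited.map_partitions(len).compute())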
Minimal Complete Verifiable Example:
from dask_sql import Context
from dask.datasets import timeseries
ddf = timeseries()
c = Context()
c.create_table("timeseries", ddf)
c.sql(
"""
CREATE MODEL my_model WITH (
model_class = 'sklearn.ensemble.GradientBoostingClassifier',
wrap_predict = True,
target_column = 'target'
) AS (
SELECT x, y, x*y > 0 AS target
FROM timeseries
LIMIT 100
)
"""
)
res = c.sql("""
SELECT * FROM PREDICT (
MODEL my_model,
SELECT x, y FROM timeseries LIMIT 100 OFFSET 100
)
""")
res.compute()
Anything else we need to know?:
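A guess at the cause, not verified: the LIMIT / OFFSET slicing leaves zero-row partitions in the frame, and the predict plugin then maps estimator.predict over every partition, empty ones included (the dask_ml/wrappers.py _predict frame in the traceback points that way), which trips sklearn's ensure_min_samples check. If that guess is right, materializing the limited selection into its own table first should sidestep the failure:

# Hypothetical workaround sketch, assuming the empty-partition theory above:
# compute the LIMIT / OFFSET result eagerly and re-register it, so PREDICT
# never sees a zero-row partition. The table name is made up for illustration.
subset = c.sql("SELECT x, y FROM timeseries LIMIT 100 OFFSET 100").compute()
c.create_table("timeseries_subset", subset)
c.sql("""
SELECT * FROM PREDICT (
    MODEL my_model,
    SELECT x, y FROM timeseries_subset
)
""").compute()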
Environment:
- dask-sql version: latest main
- Python version: 3.9
- Operating System: Ubuntu 20.04
- Install method (conda, pip, source): source