dask-xgboost
Aligns training and testing data
This PR ensures that training and testing data have balanced partitions.
Closes #32
I think that this is a good start. However, I think we've seen cases where the divisions are the same and yet the number of rows in each partition still differs. I believe that in that case we still raise a non-informative error.
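To make that failure mode concrete, here is a small constructed example (not from the PR itself): duplicate index values let two collections share identical divisions while their per-partition row counts differ.

```python
import pandas as pd
import dask.dataframe as dd

# ``labels`` has a duplicated index value (1), so equal divisions do
# not imply equal partition lengths.
data = dd.from_pandas(
    pd.DataFrame({"x": range(5)}, index=[0, 1, 2, 3, 4]), npartitions=1
).repartition(divisions=[0, 2, 4])
labels = dd.from_pandas(
    pd.Series(range(6), index=[0, 1, 1, 2, 3, 4]), npartitions=1
).repartition(divisions=[0, 2, 4])

assert data.divisions == labels.divisions              # identical divisions
print(data.map_partitions(len).compute().tolist())     # [2, 3]
print(labels.map_partitions(len).compute().tolist())   # [3, 3]
```

A divisions-only check passes here, and the row-count mismatch only surfaces later as the non-informative error described above.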
Thanks for the feedback @mrocklin!
I've added a new align_training_data function to rechunk/repartition labels so that they have the same number of rows per partition as data. Since all of the training data can be loaded into distributed memory, we can compute the chunk sizes for both data and labels. If they differ, .rechunk is called on labels accordingly.
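For reference, a minimal sketch of the idea (the `map_partitions(len)` line matches the traceback below; the dask-array branch and overall structure are a paraphrase, not the exact implementation):

```python
import dask.array as da

def align_training_data(client, data, labels):
    # Number of rows in each partition of ``data``.
    data_chunks = tuple(data.map_partitions(len).compute())

    if isinstance(labels, da.Array):
        # Rechunk the label array so that each block lines up with a
        # partition of ``data``.
        if labels.chunks[0] != data_chunks:
            labels = labels.rechunk(chunks=(data_chunks,) + labels.chunks[1:])
    # (The dataframe/series case would repartition instead; elided here.)
    return data, labels
```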
I've also added some tests, but I'm running into test failures. Some seem related to changes in this PR, while others also occur on master.
For example, running pytest dask_xgboost/tests/test_core.py::test_classifier fails with a ChildProcessError in both this PR and master.
test_classifier traceback
[gw0] darwin -- Python 3.6.6 /Users/jbourbeau/miniconda/envs/quansight/bin/python
loop = <tornado.platform.asyncio.AsyncIOLoop object at 0x1c26981048>
def test_classifier(loop): # noqa
> with cluster() as (s, [a, b]):
dask_xgboost/tests/test_core.py:38:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda/envs/quansight/lib/python3.6/contextlib.py:81: in __enter__
return next(self.gen)
../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:626: in cluster
scheduler_q = mp_context.Queue()
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/context.py:102: in Queue
return Queue(maxsize, ctx=self.get_context())
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/queues.py:42: in __init__
self._rlock = ctx.Lock()
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/context.py:67: in Lock
return Lock(ctx=self.get_context())
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/synchronize.py:163: in __init__
SemLock.__init__(self, SEMAPHORE, 1, 1, ctx=ctx)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/synchronize.py:81: in __init__
register(self._semlock.name)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:83: in register
self._send('REGISTER', name)
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:90: in _send
self.ensure_running()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <multiprocessing.semaphore_tracker.SemaphoreTracker object at 0xa221cf390>
def ensure_running(self):
'''Make sure that semaphore tracker process is running.
This can be run from any process. Usually a child process will use
the semaphore created by its parent.'''
with self._lock:
if self._pid is not None:
# semaphore tracker was launched before, is it still running?
> pid, status = os.waitpid(self._pid, os.WNOHANG)
E ChildProcessError: [Errno 10] No child processes
../../miniconda/envs/quansight/lib/python3.6/multiprocessing/semaphore_tracker.py:46: ChildProcessError
Meanwhile, pytest dask_xgboost/tests/test_core.py::test_basic passes on master but fails in this PR with an AssertionError: yield from wasn't used with future. Clearly I'm doing something wrong with the futures interface, but I'm not sure where I'm going wrong.
test_basic traceback
[gw0] darwin -- Python 3.6.6 /Users/jbourbeau/miniconda/envs/quansight/bin/python
def test_func():
del _global_workers[:]
_global_clients.clear()
active_threads_start = set(threading._active)
reset_config()
dask.config.set({'distributed.comm.timeouts.connect': '5s'})
# Restore default logging levels
# XXX use pytest hooks/fixtures instead?
for name, level in logging_levels.items():
logging.getLogger(name).setLevel(level)
result = None
workers = []
with pristine_loop() as loop:
with check_active_rpc(loop, active_rpc_timeout):
@gen.coroutine
def coro():
with dask.config.set(config):
s = False
for i in range(5):
try:
s, ws = yield start_cluster(
ncores, scheduler, loop, security=security,
Worker=Worker, scheduler_kwargs=scheduler_kwargs,
worker_kwargs=worker_kwargs)
except Exception as e:
logger.error("Failed to start gen_cluster, retryng", exc_info=True)
else:
workers[:] = ws
args = [s] + workers
break
if s is False:
raise Exception("Could not start cluster")
if client:
c = yield Client(s.address, loop=loop, security=security,
asynchronous=True, **client_kwargs)
args = [c] + args
try:
future = func(*args)
if timeout:
future = gen.with_timeout(timedelta(seconds=timeout),
future)
result = yield future
if s.validate:
s.validate_state()
finally:
if client:
yield c._close(fast=s.status == 'closed')
yield end_cluster(s, workers)
yield gen.with_timeout(timedelta(seconds=1),
cleanup_global_workers())
try:
c = yield default_client()
except ValueError:
pass
else:
yield c._close(fast=True)
raise gen.Return(result)
> result = loop.run_sync(coro, timeout=timeout * 2 if timeout else timeout)
../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:909:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/ioloop.py:576: in run_sync
return future_cell[0].result()
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:1147: in run
yielded = self.gen.send(value)
../../miniconda/envs/quansight/lib/python3.6/site-packages/distributed/utils_test.py:890: in coro
result = yield future
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:1133: in run
value = future.result()
../../miniconda/envs/quansight/lib/python3.6/site-packages/tornado/gen.py:326: in wrapper
yielded = next(result)
dask_xgboost/tests/test_core.py:144: in test_basic
dbst = yield dxgb.train(c, param, ddf, dlabels)
dask_xgboost/core.py:244: in train
data, labels = align_training_data(client, data, labels)
dask_xgboost/core.py:191: in align_training_data
data_chunks = tuple(data.map_partitions(len).compute())
../dask/dask/base.py:156: in compute
(result,) = compute(self, traverse=False, **kwargs)
../dask/dask/base.py:398: in compute
return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
../dask/dask/base.py:398: in <listcomp>
return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
../dask/dask/dataframe/core.py:74: in finalize
return _concat(results)
../dask/dask/dataframe/core.py:58: in _concat
if isinstance(first(core.flatten(args)), np.ndarray):
../../miniconda/envs/quansight/lib/python3.6/site-packages/toolz/itertoolz.py:368: in first
return next(iter(seq))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
seq = <Future finished exception=CancelledError("('len-ebcfd41fecadb4b7f7d33c5221f4960b', 2)",)>
container = <class 'list'>
def flatten(seq, container=list):
"""
>>> list(flatten([1]))
[1]
>>> list(flatten([[1, 2], [1, 2]]))
[1, 2, 1, 2]
>>> list(flatten([[[1], [2]], [[1], [2]]]))
[1, 2, 1, 2]
>>> list(flatten(((1, 2), (1, 2)))) # Don't flatten tuples
[(1, 2), (1, 2)]
>>> list(flatten((1, 2, [3, 4]))) # support heterogeneous
[1, 2, 3, 4]
"""
if isinstance(seq, str):
yield seq
else:
> for item in seq:
E AssertionError: yield from wasn't used with future
../dask/dask/core.py:272: AssertionError
Any thoughts you may have here would be greatly appreciated!
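One plausible reading of the test_basic traceback (an assumption, not confirmed in this thread): gen_cluster-style tests run the test body as a coroutine on the event loop, and the blocking .compute() call in align_training_data (dask_xgboost/core.py:191 above) then waits on a future from inside that same loop. A sketch of a non-blocking variant, assuming an asynchronous distributed client as in these tests (the helper name _compute_data_chunks is hypothetical):

```python
from tornado import gen

@gen.coroutine
def _compute_data_chunks(client, data):
    # Hand the graph to the client and yield the resulting future,
    # rather than blocking the event loop with ``.compute()``.
    lengths = yield client.compute(data.map_partitions(len))
    raise gen.Return(tuple(lengths))
```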
It would be good to verify that we compute things only once; otherwise we may load and preprocess our data many times, which in practice can be annoying. There are currently two issues preventing this:
- Within align_training_data we call compute on the shape twice, once for data and once for labels. In the common case where these share a common history, that history will be recomputed unnecessarily (see the sketch after this list).
- We then call client.compute (which today behaves more like persist) within _train.
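For the first point, a single dask.compute call over both length computations merges the two graphs, so any tasks that data and labels share run only once. A minimal sketch, assuming both are dask dataframes/series:

```python
import dask

# One call to dask.compute merges both graphs; tasks shared by data
# and labels (e.g. a common load/preprocess step) execute only once.
data_lengths, label_lengths = dask.compute(
    data.map_partitions(len),
    labels.map_partitions(len),
)
data_chunks = tuple(data_lengths)
label_chunks = tuple(label_lengths)
```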
We have to persist the data in memory in the _train function. Ideally we would verify alignment only after this stage when we know that it's cheap and won't result in any additional recomputation.
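A sketch of that ordering (the surrounding _train signature is assumed, not copied from the codebase):

```python
def _train(client, params, data, labels, **kwargs):
    # Persist once, up front, so both collections live in distributed
    # memory before any alignment checks.
    data, labels = client.persist([data, labels])

    # Verifying alignment now is cheap: the per-partition length
    # computations only touch data that is already in memory.
    data, labels = align_training_data(client, data, labels)
    ...
```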
Generally I find things like this by trying them out on a small problem and watching the diagnostic dashboard.