dask-ml
Provide wrappers for popular ML libraries
It'd be convenient to provide support for use of Keras or PyTorch models in model selection. There are two issues:
- Keras/PyTorch models don't conform to the Scikit-learn API.
- Keras models are not pickle-able.
I'm imagining this interface:
from torchvision.models import resnet18
import torch.optim as optim
from dask_ml.wrappers import PyTorchClassifier
pytorch_model = resnet18()
sklearn_model = PyTorchClassifier(
model=pytorch_model,
model__alpha=1e-2, # if resnet18 had a kwarg `alpha`
optimizer=optim.SGD,
optimizer__lr=0.1,
)
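For comparison, roughly this interface already exists today by wrapping the PyTorch model with skorch directly; a minimal sketch (parameter names follow skorch's NeuralNetClassifier, and the values are placeholders):
from torchvision.models import resnet18
import torch.optim as optim
from skorch import NeuralNetClassifier
# skorch wraps any nn.Module in a scikit-learn compatible estimator
sklearn_model = NeuralNetClassifier(
    module=resnet18,      # the module class; module__* kwargs go to its constructor
    optimizer=optim.SGD,
    lr=0.1,               # passed through to the optimizer
    max_epochs=10,
)
# sklearn_model.fit(X, y) / sklearn_model.predict(X) now behave like scikit-learn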
Related issues/PRs: the same complaint has been raised in dask/distributed: https://github.com/dask/distributed/issues/3873
I think this is possible with these wrappers:
- skorch, a library that brings the Scikit-learn API to PyTorch. This is what's used in the blog post "Better and faster hyperparameter optimization with Dask."
- There's an implementation of __reduce_ex__ for Keras models in https://github.com/tensorflow/tensorflow/pull/39609/
Edit: these libraries are discussed below:
- adadamp, which provides distributed training and a Scikit-learn API for PyTorch models.
- saturncloud/dask-pytorch-ddp, which allows usage of Dask clusters with native PyTorch distributed code.
- SciKeras, which provides a Scikit-learn API to Keras.
SciKeras and Skorch are now mentioned in Dask-ML's documentation on wrappers (see 1 and 2).
If you're looking for a way to make Keras models conform to the scikit-learn API, check out SciKeras (full disclosure: I'm the author).
Thanks for the link, Adrian. To minimize our maintenance burden, I'd aim for the goal that our model_selection estimators work with any model implementing the scikit-learn interface, and encourage the development and use of wrappers like Skorch and SciKeras.
On top of that, we have the additional burden of these models needing to work well with distributed's serialization. To the extent possible, that functionality should be in the projects themselves (making Keras models picklable) or in distributed.
How hard is it to support PyTorch/Keras fit/predict APIs? If this is as simple as making a function like the following, then I would be in favor:
def fit(estimator, X, y=None):
if hasattr(estimator, "fit"):
return estimator.fit(X, y)
elif hasattr(estimator, ...): # pytorch-like
return ...
elif hasattr(estimator, ...): # keras-like
return ...
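As a rough sketch of what those branches might look like (the duck-typing checks here are my guesses, not a settled API):
import torch.nn as nn

def fit(estimator, X, y=None, **kwargs):
    if hasattr(estimator, "train_on_batch"):
        # Keras-like: tf.keras models expose train_on_batch/fit (after compile())
        return estimator.fit(X, y, **kwargs)
    elif hasattr(estimator, "fit"):
        # scikit-learn-like, including skorch and SciKeras wrappers
        return estimator.fit(X, y, **kwargs)
    elif isinstance(estimator, nn.Module):
        # PyTorch-like: a bare nn.Module carries no training loop, so a loss,
        # optimizer, and batching logic would have to be supplied separately
        raise NotImplementedError("bare torch.nn.Module needs an external training loop")
    raise TypeError(f"Don't know how to fit {type(estimator)!r}")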
For serialization I think that we have a decent PyTorch serializer in distributed (early work from @stsievert if I recall correctly). I don't think that we have anything for Keras today.
Serialization is also maybe something that we could ask for help with from the RAPIDS folks like @quasiben, @jakirkham, and @pentschev. It's not RAPIDS, obviously, but these are often GPU-related and that team is familiar with these sorts of issues.
I'm not 100% sure what the goal is here (I just came from the discussion in tensorflow/tensorflow#39609) but SciKeras adds serialization support to the Keras models it wraps. Ex:
from scikeras.wrappers import KerasRegressor
keras_model = ... # some keras model object, can be Sequential or Functional
wrapped_model = KerasRegressor(keras_model) # a serializable, scikit-learn API compliant estimator
So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?
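Concretely, the point is that a plain pickle round trip works on the wrapped estimator (a minimal check; wrapped_model comes from the snippet above):
import pickle

# the wrapped estimator should survive a plain pickle round trip, which is
# what dask.distributed needs when shipping it between workers
payload = pickle.dumps(wrapped_model)
restored = pickle.loads(payload)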
Serialization is definitely something RAPIDS cares about. SciKeras looks interesting -- @adriangb, do you know if it forces a host-to-device transfer? Does it support the __cuda_array_interface__? If so, I believe things are a lot easier for us.
cc @JohnZed, maybe PyTorch serialization is something cuML would also care about.
The current PyTorch serialization is here: https://github.com/dask/distributed/blob/master/distributed/protocol/torch.py
It looks like it forces things to numpy though, and so may not be GPU-optimized.
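For illustration, the host transfer is forced by PyTorch itself, since a CUDA tensor can't be viewed as a NumPy array directly (this snippet needs a GPU and shows standard PyTorch behavior, not distributed's code):
import torch

t = torch.ones(1000, device="cuda")
# t.numpy() raises a TypeError for CUDA tensors; they must be copied to host
# memory first, so a numpy-based serialization path implies a device-to-host copy
host_copy = t.cpu().numpy()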
Rather than SciKeras, I'm still curious whether we can make things more torch/tf/keras-native cheaply.
> So I guess you could just tell your users to wrap their Keras models before using them with dask-ml?
That's the plan. To do that, models need to support serialization and implement partial_fit (see https://github.com/adriangb/scikeras/pull/17).
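For context, Dask-ML's incremental model-selection estimators (e.g. HyperbandSearchCV, IncrementalSearchCV) repeatedly call partial_fit on successive chunks of data and score between calls, so a wrapper roughly needs to satisfy this contract (a schematic of the scikit-learn interface, not Dask-ML's actual code):
class IncrementallyTrainable:
    # schematic of the scikit-learn incremental-learning contract
    def partial_fit(self, X_chunk, y_chunk, classes=None):
        # update internal state from this chunk only; called many times
        return self

    def score(self, X, y):
        # used between partial_fit calls to decide which candidates keep training
        return 0.0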
> pytorch serialization is something cuML would also care about
> may not be GPU-optimized.
PyTorch has serialization support, even though they recently tried to remove it! https://github.com/pytorch/pytorch/issues/38597 Skorch wraps PyTorch, and it looks like they support GPUs: skorch/net.py#L1608.
> do you know if it forces a host-to-device transfer? Does it support the __cuda_array_interface__? If so, I believe things are a lot easier for us.
To be honest, I am not familiar with these terms. All SciKeras does is implement serialization compatible with copy.deepcopy and pickle. It does not have a __cuda_array_interface__ attribute, so I think the answer is no.
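For context, __cuda_array_interface__ is an attribute describing an array's device memory (pointer, shape, dtype) so that GPU libraries can hand data to each other without copying through the host. CuPy arrays expose it, for example:
import cupy

x = cupy.arange(4)
# a dict with 'shape', 'typestr', 'data' (a device pointer), 'version', ...
print(x.__cuda_array_interface__)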
> models need to support serialization and implement partial_fit (see adriangb/scikeras#17).
Will take a look tonight!
How difficult would it be to implement pickling (like Matt did for PyTorch in https://github.com/pytorch/pytorch/pull/9184) for Keras as well? There's a lot of value gained by supporting standard Python protocols. Not to say there may not be additional gains with Dask serialization. Just that having this standard protocol working would make interop with various distributed computing libraries (including Dask) easier.
> How difficult would it be to implement pickling ... for Keras as well?
SciKeras has an implementation at scikeras/wrappers.py#L87. There's currently an open PR to merge this into TensorFlow/Keras master: https://github.com/tensorflow/tensorflow/pull/39609
Does that answer your question?
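For reference, the general idea (a rough sketch only, not SciKeras's or the TensorFlow PR's actual code) is to round-trip the model through Keras's own save format from __reduce__:
import os
import tempfile
import tensorflow as tf

def _pack_keras_model(model):
    # serialize architecture + weights + optimizer state via Keras's HDF5 format
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.h5")
        model.save(path)
        with open(path, "rb") as f:
            return f.read()

def _unpack_keras_model(raw):
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.h5")
        with open(path, "wb") as f:
            f.write(raw)
        return tf.keras.models.load_model(path)

def _reduce(self):
    # what a __reduce__/__reduce_ex__ implementation boils down to
    return _unpack_keras_model, (_pack_keras_model(self),)

# patched here only for illustration; the linked PR adds this upstream instead
tf.keras.Model.__reduce__ = _reduce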
Here are two more PyTorch wrappers:
- adadamp, which provides usage of Dask clusters with PyTorch models and presents a Scikit-learn interface
- saturncloud/dask-pytorch-ddp, which allows use of a Dask cluster with PyTorch distributed code.