xlearn

How to save/load scikit-learn API model

Open · ethen8181 opened this issue on May 28, 2018 · 3 comments

Hi team, what's the best way right now to save and load a model built with the scikit-learn-style API (e.g. `xl.LRModel`)? Pickling doesn't seem to work. Thanks!

```python
# the example from the tutorial
import numpy as np
import xlearn as xl
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load dataset
iris_data = load_iris()
X = iris_data['data']
y = (iris_data['target'] == 2)

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, random_state=0)

# param:
#  0. binary classification
#  1. model scale: 0.1
#  2. epoch number: 10 (auto early-stop)
#  3. learning rate: 0.1
#  4. regularization lambda: 1.0
#  5. use sgd optimization method
linear_model = xl.LRModel(task='binary', init=0.1,
                          epoch=10, lr=0.1,
                          reg_lambda=1.0, opt='sgd')

# Start training
linear_model.fit(X_train, y_train,
                 eval_set=[X_val, y_val],
                 is_lock_free=False)

# attempting to save the model
from joblib import dump
dump(linear_model, 'temp.pkl')
# raises ValueError: ctypes objects containing pointers cannot be pickled
```
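The error is not specific to joblib or xLearn: the standard `pickle` module refuses any ctypes object that is, or contains, a pointer, so a single pointer anywhere in an object's attributes makes the whole dump fail. A minimal standard-library sketch, assuming (as the error message suggests) that the wrapper keeps a ctypes `void*` handle next to its ordinary hyperparameters; `wrapper_state` is a hypothetical stand-in for such an object's attribute dict:

```python
import ctypes
import pickle


def pickle_error(obj):
    """Return None if `obj` pickles cleanly, otherwise the ValueError message."""
    try:
        pickle.dumps(obj)
    except ValueError as exc:
        return str(exc)
    return None


# Hypothetical attribute dict of a wrapper object: plain hyperparameters
# stored next to a void* handle to native memory.
wrapper_state = {"task": "binary", "lr": 0.1, "handle": ctypes.c_void_p()}

# Pointer-free ctypes values pickle fine...
print(pickle_error(ctypes.c_int(5)))  # None
# ...but one pointer-typed value in the object graph makes the dump fail.
print(pickle_error(wrapper_state))
```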

ethen8181, May 28, 2018 16:05

@randxie Can you check out this issue?

aksnzhy, Jun 3, 2018 20:06

@aksnzhy The problem is that `_XLearnModel` cannot be pickled. We could either add `__getstate__` and `__setstate__` methods to the `XLearn` class, or provide `save_model` and `load_model` methods in the sklearn interface. What do you think?
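The first proposal could look roughly like this: `__getstate__` swaps the unpicklable ctypes handle for the model's serialized bytes, and `__setstate__` rebuilds the native model from them. This is a sketch under stated assumptions, not xLearn code: `PicklableNativeModel`, `_save_native` and `_load_native` are hypothetical, a ctypes `c_double` array stands in for memory owned by the C++ library, and the temp-file round trip assumes the native side can only save to and load from a file path.

```python
import ctypes
import os
import pickle
import tempfile


class PicklableNativeModel:
    """Hypothetical sklearn-style wrapper around a native model (not xLearn code)."""

    def __init__(self, task="binary", lr=0.1):
        self.task = task
        self.lr = lr
        self._buffer = None  # stand-in for memory owned by the native library
        self._handle = None  # c_void_p into that memory; not picklable

    def fit(self, weights):
        # Stand-in for training: put the learned weights into "native" memory.
        self._buffer = (ctypes.c_double * len(weights))(*weights)
        self._handle = ctypes.cast(self._buffer, ctypes.c_void_p)
        return self

    def weights(self):
        return list(self._buffer)

    # Hypothetical file-based native save/load (a real binding would call C here).
    def _save_native(self, path):
        with open(path, "wb") as f:
            f.write(bytes(self._buffer))

    def _load_native(self, path):
        with open(path, "rb") as f:
            data = f.read()
        n = len(data) // ctypes.sizeof(ctypes.c_double)
        self._buffer = (ctypes.c_double * n).from_buffer_copy(data)
        self._handle = ctypes.cast(self._buffer, ctypes.c_void_p)

    def __getstate__(self):
        # Keep plain attributes; replace the ctypes objects with raw model bytes.
        state = {k: v for k, v in self.__dict__.items()
                 if k not in ("_buffer", "_handle")}
        state["_model_bytes"] = None
        if self._buffer is not None:
            fd, path = tempfile.mkstemp()
            os.close(fd)
            try:
                self._save_native(path)
                with open(path, "rb") as f:
                    state["_model_bytes"] = f.read()
            finally:
                os.remove(path)
        return state

    def __setstate__(self, state):
        state = dict(state)  # don't mutate the caller's dict
        model_bytes = state.pop("_model_bytes")
        self.__dict__.update(state)
        self._buffer = None
        self._handle = None
        if model_bytes is not None:
            # Write the bytes back to a file and let the native loader read them.
            fd, path = tempfile.mkstemp()
            try:
                with os.fdopen(fd, "wb") as f:
                    f.write(model_bytes)
                self._load_native(path)
            finally:
                os.remove(path)


model = PicklableNativeModel().fit([0.5, -1.25, 2.0])
restored = pickle.loads(pickle.dumps(model))
print(restored.weights())  # [0.5, -1.25, 2.0]
```

Because `joblib.dump` goes through the same pickling protocol, a wrapper written this way should also survive a call like `dump(linear_model, 'temp.pkl')`. The second proposal, explicit `save_model`/`load_model` methods, would call the same two native functions directly with a user-supplied path instead of hiding them inside the pickle hooks.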

randxie, Jun 5, 2018 23:06

I have the same problem. Is there any progress on this?

Superhzf, Jun 12, 2019 16:06