tffm

fix: exception sklearn.exceptions.NotFittedError is not imported

Open · kman0 opened this issue 6 years ago · 4 comments

Currently the code raises "NameError: name 'sklearn' is not defined" instead of the intended sklearn.exceptions.NotFittedError, because sklearn is never imported in that module. This PR adds the missing import statement.
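A minimal sketch of this failure mode, using hypothetical helper functions rather than the actual tffm source. A function that names sklearn.exceptions.NotFittedError without sklearn having been imported only fails when the raise line executes, and it fails with NameError, which hides the intended exception. The fixed variant assumes an import along the lines of "from sklearn.exceptions import NotFittedError"; the exact form of the import in this PR may differ. The try/except fallback exists only so the sketch runs without scikit-learn installed.

```python
# Hypothetical reproduction of the bug (not the actual tffm code).
try:
    from sklearn.exceptions import NotFittedError  # the kind of import the fix adds
except ImportError:
    # Demo-only fallback so this sketch runs without scikit-learn installed;
    # scikit-learn's real NotFittedError also subclasses ValueError and AttributeError.
    class NotFittedError(ValueError, AttributeError):
        pass


def check_fitted_broken(initialized):
    """Before the fix: refers to the module name 'sklearn', which was never bound."""
    if not initialized:
        # Looking up 'sklearn' fails first, so callers see NameError
        # instead of the intended NotFittedError.
        raise sklearn.exceptions.NotFittedError("call fit() first")  # noqa: F821


def check_fitted_fixed(initialized):
    """After the fix: the exception class is imported and in scope."""
    if not initialized:
        raise NotFittedError("call fit() first")


def raised_by(check):
    """Return the type of exception raised by check(initialized=False), or None."""
    try:
        check(initialized=False)
    except Exception as exc:
        return type(exc)
    return None


print(raised_by(check_fitted_broken).__name__)  # NameError
print(raised_by(check_fitted_fixed).__name__)   # NotFittedError
```

Note that "from sklearn.exceptions import NotFittedError" binds only NotFittedError, not the name sklearn, so code written as sklearn.exceptions.NotFittedError would instead need "import sklearn.exceptions".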

kman0, Feb 17 '18 01:02

Hi! Thanks for the contribution! I will check it carefully over the weekend.

geffy, Feb 20 '18 13:02

Hi, @kman0! It seems that, as is, this PR doesn't work. This minimal example produces an AssertionError:

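import numpy as np
import tensorflow as tf  # TF 1.x API (tf.train.AdamOptimizer)
from tffm import TFFMClassifier

# Toy binary-classification data: 10,000 dense samples with 23 features
X_tr = np.random.randn(10000, 23)
y_tr = np.zeros(10000)
y_tr[::2] = 1  # every other sample is labeled positive

model = TFFMClassifier(
    order=2,
    rank=10,
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
    n_epochs=50,
    batch_size=1024,
    init_std=0.001,
    reg=0.01,
    input_type='dense',
    seed=42
)
model.fit(X_tr, y_tr, show_progress=True)

# Save, free the TF session, then reload with the save/load methods from this PR
model.save_model('./tmp/model')
model.destroy()
del model

model = TFFMClassifier.load_model(TFFMClassifier, './tmp/model')  # fails with AssertionError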

Or is it supposed to be used in some other way?

geffy, Mar 04 '18 20:03

@geffy My bad. I meant to open a pull request for the first commit only! I was mainly working on regression and didn't check the classifier.

The AssertionError occurs because TFFMClassifier hard-codes the rule that loss_function cannot be passed as a parameter:

        assert 'loss_function' not in init_params, """Parameter 'loss_function' is
        not supported for TFFMClassifier. For custom loss function, extend the
        base class TFFMBaseModel."""

One way to work around this is to check whether the model is a TFFMClassifier instance and handle that case accordingly. I will post the changes shortly.
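A rough, TensorFlow-free sketch of the problem and of the proposed type check. Everything below (BaseModel, Classifier, log_loss, load_naive, load_fixed) is a hypothetical stand-in, not tffm's API. The sketch assumes the classifier injects its own loss_function into the parameters it stores, so a loader that feeds all saved parameters back into the constructor trips the assertion quoted above. Because load_model receives the class itself in the example above (TFFMClassifier.load_model(TFFMClassifier, ...)), the sketch uses issubclass, the class-level counterpart of an isinstance check.

```python
# Hypothetical stand-ins for the tffm classes (no TensorFlow needed); only the
# loss_function assertion mirrors the snippet quoted above.


def log_loss(outputs, targets):
    """Placeholder for the classifier's built-in loss; never called here."""
    raise NotImplementedError


class BaseModel:
    def __init__(self, **init_params):
        # Keep every constructor argument so a saved model can be rebuilt.
        self.init_params = dict(init_params)

    def get_init_params(self):
        return dict(self.init_params)


class Classifier(BaseModel):
    def __init__(self, **init_params):
        assert 'loss_function' not in init_params, (
            "Parameter 'loss_function' is not supported for the classifier.")
        # The classifier injects its own loss, so it ends up in init_params too.
        super().__init__(loss_function=log_loss, **init_params)


def load_naive(cls, saved_params):
    # Passes loss_function back into the constructor, triggering AssertionError.
    return cls(**saved_params)


def load_fixed(cls, saved_params):
    params = dict(saved_params)
    if issubclass(cls, Classifier):
        # Drop the saved loss; the classifier sets it again in __init__.
        params.pop('loss_function', None)
    return cls(**params)


saved = Classifier(order=2, rank=10).get_init_params()
try:
    load_naive(Classifier, saved)
except AssertionError:
    print("naive load: AssertionError")
restored = load_fixed(Classifier, saved)
print(restored.init_params['rank'])  # 10
```

An alternative design would be to leave loss_function out of the saved parameters for classifiers in the first place, which keeps the loader free of type checks; which option fits tffm better depends on how its save path is structured.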

kman0, Mar 18 '18 23:03

@geffy Your usage is correct. The main reason I wanted this was to train a model once and evaluate it later. If you merge this, your sample code could be included in the example notebook.

One addition I would make is a call to the predict function in the example, for completeness:

X_test = np.random.randn(1000, 23)  # hypothetical held-out data with the same 23 features as X_tr
model.predict(X_test)

kman0, Mar 19 '18 00:03