
Save a trained model for future use?

Open EricKeenan opened this issue 3 years ago • 8 comments

Wow! Great project - thanks for your hard work.

Is there a way to save a trained model, load it into a new notebook, and run inference? Apologies if this is documented somewhere.

e.g.

rf_model = rf_model(X_train, Y_train)
rf_model.train()
rf_model.save("rf_model.pb") # <---- Is there anything like this?

and then in a new notebook

rf_model = esem.open("rf_model.pb") # <---- Is there anything like this?

EricKeenan avatar Feb 17 '22 19:02 EricKeenan

Thanks!

Currently there isn't a way to do this, no. It seems a very sensible thing to allow though. I think it would require each type of model (GPflow, scikit-learn and Keras) to have save and load methods that the Emulator can then just use.

I will flag this as a feature request and try to implement it when I have a chance, but I would also be very happy to review pull requests that implement it.
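As a rough illustration of the save/load idea, here is a minimal sketch. It is purely hypothetical: `PicklableModelAdaptor`, `save` and `load` are made-up names that do not exist in ESEm, and a plain dict stands in for a fitted model. A real implementation would need a different serialisation per backend.

```python
import pickle
import tempfile
from pathlib import Path


class PicklableModelAdaptor:
    """Hypothetical adaptor; NOT part of ESEm's API."""

    def __init__(self, model):
        self.model = model  # the underlying fitted model object

    def save(self, path):
        # Each backend would override this with its own serialisation;
        # plain pickle is used here only as the simplest case.
        with open(path, "wb") as f:
            pickle.dump(self.model, f)

    @classmethod
    def load(cls, path):
        # Rebuild the adaptor around the deserialised model
        with open(path, "rb") as f:
            return cls(pickle.load(f))


# Round trip with a stand-in "model" (a plain dict, for illustration only)
path = Path(tempfile.mkdtemp()) / "model.pkl"
PicklableModelAdaptor({"n_estimators": 10}).save(path)
restored = PicklableModelAdaptor.load(path)
```

The Emulator could then delegate to whichever adaptor it holds, without knowing which library sits underneath.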

In the meantime, for the random forest model (only) I think you should be able to just use pickle:

import pickle
pickle.dump(rf_model, 'rf_model.pb')
rf_model2 = pickle.load('rf_model.pb')
rf_model2.predict(...)

duncanwp avatar Feb 18 '22 10:02 duncanwp

Thanks for the reply.

That solution doesn't appear to work in my case

pickle.dump(rf_model, 'rf_model.pkl')

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_14674/1446057301.py in <module>
----> 1 pickle.dump(rf_model, 'rf_model.pkl')

TypeError: file must have a 'write' attribute

Likewise with

with open("rf_model.pkl","wb") as f:
    pickle.dump(rf_model, f)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_14674/851769721.py in <module>
      1 with open("rf_model.pkl","wb") as f:
----> 2     pickle.dump(rf_model, f)

TypeError: can't pickle tensorflow.python._pywrap_tfe.EagerContextThreadLocalData objects
    
  

EricKeenan avatar Feb 18 '22 17:02 EricKeenan

Ah OK, that's because of some of the TensorFlow functions on the Emulator... This will need a bit more thought, sorry.
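The general failure mode can be reproduced without TensorFlow. The sketch below is an analogy only: `Wrapper` is a hypothetical stand-in class (not ESEm code), and a `threading.Lock` plays the role of the unpicklable TensorFlow object. Pickling the whole object fails, while pickling just the plain inner attribute works.

```python
import pickle
import threading


class Wrapper:
    """Hypothetical stand-in for an emulator-like object; not ESEm code."""

    def __init__(self, model):
        self.model = model                # plain, picklable state
        self._context = threading.Lock()  # stand-in for an unpicklable handle


wrapper = Wrapper(model={"weights": [0.1, 0.2]})

# Pickling the whole wrapper fails because of the lock attribute
try:
    pickle.dumps(wrapper)
    whole_object_pickled = True
except TypeError:
    whole_object_pickled = False

# Pickling only the inner, picklable attribute succeeds
restored_model = pickle.loads(pickle.dumps(wrapper.model))
```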

duncanwp avatar Feb 18 '22 17:02 duncanwp

@duncanwp It seems like the esem random forest is a simple implementation of the scikit-learn random forest. In the meantime, would you recommend training my emulator with scikit-learn and saving the model using pickle?

EricKeenan avatar Feb 22 '22 23:02 EricKeenan

Yes, you're absolutely right.
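For reference, a minimal sketch of that plain scikit-learn route. The data here is synthetic and the file name is arbitrary, both chosen only for illustration; the hyperparameters are not a recommendation.

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in training data (assumption, for illustration only)
rng = np.random.default_rng(0)
X_train = rng.random((50, 3))
y_train = X_train.sum(axis=1)

# Train a scikit-learn random forest directly
model = RandomForestRegressor(n_estimators=10, random_state=0)
model.fit(X_train, y_train)

# Save it; note pickle.dump needs an open binary file, not a file name
model_path = Path(tempfile.mkdtemp()) / "sk_rf_model.pkl"
with open(model_path, "wb") as f:
    pickle.dump(model, f)

# Load it again (e.g. in a new notebook) and predict
with open(model_path, "rb") as f:
    loaded = pickle.load(f)

same = np.allclose(model.predict(X_train), loaded.predict(X_train))
```

Pickled scikit-learn models are generally only safe to load with the same scikit-learn version they were saved with, and only from sources you trust.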

It's also possible (but a little convoluted currently) to wrap the loaded model back into ESEm:

import pickle

# Save the sklearn model held internally in the esem wrapper
with open("rf_model.pkl","wb") as f:
    pickle.dump(esem_rf_model.model.model, f)

# Load it again
with open("rf_model.pkl","rb") as f:
    skmodel=pickle.load(f)

# Wrap the loaded model

from esem.wrappers import wrap_data
from esem.data_processors import Flatten
from esem.model_adaptor import SKLearnModel
from esem.emulator import Emulator
from sklearn.ensemble import RandomForestRegressor

wrapped_skmodel = SKLearnModel(skmodel)

# Note that we need to reload the data separately (this is used internally for post-processing)
data = wrap_data(y_train, data_processors=[Flatten()])
loaded_esem_rf_model = Emulator(wrapped_skmodel, x_train, data)

loaded_esem_rf_model.predict(...)

I made a full example here: https://gist.github.com/duncanwp/e4b96690da5bb0bf2505bb94d5450001

duncanwp avatar Feb 23 '22 10:02 duncanwp

Thanks @duncanwp ! I managed to save a RF model. I'll leave this issue open in case you want this as documentation for a feature request. Otherwise, feel free to close. Thanks again!

EricKeenan avatar Mar 01 '22 20:03 EricKeenan