tensorrec
Getting error when loading on flask | Session issue
I am getting a session management issue when I try to use it in a Flask application. Can you please help me?
ValueError: Tensor("TensorSliceDataset:0", shape=(), dtype=variant) must be from the same graph as Tensor("Iterator:0", shape=(), dtype=resource).
Sample code structure
```python
from io import StringIO

import flask
import pandas as pd
import tensorflow as tf
import tensorrec as tr


class ScoringService(object):
    model = None
    model_path = 'model path'  # placeholder: directory the model was saved to

    @classmethod
    def get_model(cls):
        # Load the model once and cache it on the class
        if cls.model is None:
            cls.model = tr.TensorRec.load_model(cls.model_path)
        return cls.model

    @classmethod
    def get_reco(cls, model, use_ft, ite_ft):
        tf.reset_default_graph()
        predictions = model.predict(use_ft, ite_ft)
        return predictions

    @classmethod
    def predict(cls, input):
        clf = cls.get_model()
        # use_ft / ite_ft are the user and item feature matrices;
        # preparing them is omitted from this sample
        n_reco_test = cls.get_reco(clf, use_ft, ite_ft)
        return n_reco_test


# The flask app for serving predictions
app = flask.Flask(__name__)


@app.route('/invocations', methods=['POST'])
def transformation():
    # Convert from CSV to pandas
    if flask.request.content_type == 'text/csv':
        data = flask.request.data.decode('utf-8')
        data = pd.read_csv(StringIO(data), header=None)
    else:
        return flask.Response(response='This predictor only supports CSV data',
                              status=415, mimetype='text/plain')

    # Do the prediction
    predictions = ScoringService.predict(data)
    # predict() returns a NumPy array; serialize it to CSV for the response body
    body = pd.DataFrame(predictions).to_csv(header=False, index=False)
    return flask.Response(response=body, status=200, mimetype='text/csv')
```
Hey @jaiswalvineet! Thanks for reporting this. What is the purpose of the `tf.reset_default_graph()` inside of `get_reco()`? This is likely the culprit. Is the call to `model.predict()` the line that is failing?
Hey @jfkirk, thanks for replying. I just wanted to override the graph so that I would not hit this error, but it does not work. I trained the model, saved it to a pickle file, and get this error when I try to predict using that pickle file. Yes, `model.predict` is the call that fails.

```
"POST /invocations HTTP/1.1" 500 291 "-" "AHC/2.0"
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/op_def_library.py", line 350, in _apply_op_helper
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5637, in _get_graph_from_inputs
    _assert_same_graph(original_graph_element, graph_element)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 5573, in _assert_same_graph
    original_item)
```
It seems tensorrec does something with the session that overrides the default one. Here is the full traceback, which shows tensorrec's `predict()` doing some magic internally:
```
Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 2292, in wsgi_app
    response = self.full_dispatch_request()
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1815, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1718, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "/usr/local/lib/python3.5/dist-packages/flask/_compat.py", line 35, in reraise
    raise value
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1813, in full_dispatch_request
    rv = self.dispatch_request()
  File "/usr/local/lib/python3.5/dist-packages/flask/app.py", line 1799, in dispatch_request
    return self.view_functions[rule.endpoint](**req.view_args)
  File "/opt/ml/code/predictor.py", line 143, in transformation
    predictions = ScoringService.predict(data)
  File "/opt/ml/code/predictor.py", line 105, in predict
    n_reco_test = cls.get_reco(clf, input, n_rec, user_features, ite_ft, u_map, i_map, i_order)
  File "/opt/ml/code/predictor.py", line 71, in get_reco
    predictions = model.predict(use_ft, ite_ft)
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 663, in predict
    item_features=item_features)
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 256, in _create_datasets_and_initializers
    for dataset in user_features_datasets]
  File "/usr/local/lib/python3.5/dist-packages/tensorrec/tensorrec.py", line 256, in <listcomp>
```
Hey @jaiswalvineet, calling `tf.reset_default_graph()` after loading the model blows away the graph and creates a new one. I've reproduced this locally: when you then call `predict()`, you get the error you're seeing, because the reset graph does not contain the model.

If you remove the call to `tf.reset_default_graph()`, do you see the same error?
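To make the explanation above concrete, here is a toy, pure-Python model of TensorFlow 1.x graph ownership. `ToyGraph`, `ToyTensor`, `reset_default_graph`, `make_tensor` and `combine` are hypothetical stand-ins written for illustration, not TensorFlow or tensorrec APIs; they only mimic the rule that every tensor belongs to one graph and an op cannot mix tensors from different graphs. The tensor names loosely mirror the ones in the error message; which tensors tensorrec actually creates at load time versus in `predict()` is an assumption.

```python
# Toy model of TF 1.x graph ownership; every name here is hypothetical.


class ToyGraph:
    """Stands in for tf.Graph; tensors only need its identity."""


class ToyTensor:
    """Stands in for tf.Tensor: remembers the graph it was created in."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph


_default_graph = ToyGraph()


def get_default_graph():
    return _default_graph


def reset_default_graph():
    # Mimics tf.reset_default_graph(): install a fresh, empty default graph.
    # Tensors created earlier keep pointing at the old graph.
    global _default_graph
    _default_graph = ToyGraph()


def make_tensor(name):
    # New tensors always land in whatever graph is currently the default.
    return ToyTensor(name, get_default_graph())


def combine(a, b):
    # Mimics TF refusing to build one op from tensors of different graphs.
    if a.graph is not b.graph:
        raise ValueError("Tensor(%r) must be from the same graph as Tensor(%r)."
                         % (b.name, a.name))
    return ToyTensor("combined", a.graph)


# 1. Loading the model creates its tensors in the current default graph.
model_tensor = make_tensor("Iterator:0")

# 2. get_reco() in the question resets the default graph...
reset_default_graph()

# 3. ...so the input tensors built at prediction time go into the new graph.
input_tensor = make_tensor("TensorSliceDataset:0")

try:
    combine(model_tensor, input_tensor)
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```

Without the `reset_default_graph()` call, both toy tensors share one graph and `combine` succeeds.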
I removed `tf.reset_default_graph()`, but I still get the same error. It seems the Flask app persists the graph between requests. It's strange that it works when I run it directly but not when I call it from the Flask app. If it runs once for you from Flask, refresh the page in the browser and you will get the error; if not, please share your Flask app. Thanks in advance!
@jaiswalvineet I had the same problem. The TF graph is not thread-safe; it needs to be defined globally and reused for every prediction. Here is how I solved it (the solution actually comes from https://github.com/keras-team/keras/issues/2397#issuecomment-306687500):
I added a global model variable, loaded the model, and stored the graph right after loading it:
```python
import tensorflow as tf
from tensorrec import TensorRec

model = TensorRec.load_model(directory_path=model_path)
graph = tf.get_default_graph()
```
Then, in every call, I used the same graph stored in `graph`, with a new session:
```python
with graph.as_default():
    session = tf.Session()
    with session.as_default():
        # Now do the prediction
        predictions = ScoringService.predict(data)
```