Deep-Learning-TensorFlow

Problem using a trained model to predict

Open yangpanroy opened this issue 8 years ago • 0 comments

I trained my own network on custom data using the supervised stacked denoising autoencoder (sDAE). After training, I got good results on the test set, so I wanted to reuse the trained model later to predict on large amounts of data. Following the documentation, I ran `python run_stacked_autoencoder_supervised.py --restore_previous_model sdae --do_pretrain True --dae_num_epochs 0 --finetune_num_epochs 0`, but only got 24% accuracy.

So I checked the code carefully and found that, in /yadlt/core/supervised_model.py, building the model, initializing the TensorFlow ops, training the model, and saving the model all happen inside the fit() method.

Your code works well for the full pretraining, finetuning, and prediction pipeline. But when I restore an already-trained model and predict with it directly, I get the low accuracy described above.
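The difference between the two code paths can be illustrated without TensorFlow. Below is a minimal pure-Python sketch, assuming a hypothetical one-parameter `ToyClassifier` (not part of yadlt) whose "checkpoint" is a JSON file. Its `fit()` follows the same order described above (build, initialize, train, save), while `score_without_restore()` builds a fresh model but never loads the saved parameters, so it evaluates the untrained initial values. `score()` builds the model and then restores the saved parameters before evaluating.

```python
import json
import os
import tempfile


class ToyClassifier:
    """Hypothetical stand-in for a supervised model (not the yadlt API).

    Predicts class 1 when x >= threshold, else class 0.
    """

    def __init__(self, model_path):
        self.model_path = model_path
        self.params = None

    def build_model(self):
        # Create the parameter "slots" with an initial, untrained value
        # (analogous to building and initializing the graph).
        self.params = {"threshold": 0.0}

    def fit(self, xs, ys):
        # Same order as described for fit(): build/initialize, train, save.
        self.build_model()
        xs0 = [x for x, y in zip(xs, ys) if y == 0]
        xs1 = [x for x, y in zip(xs, ys) if y == 1]
        # "Training": put the threshold halfway between the two class means.
        self.params["threshold"] = (sum(xs0) / len(xs0) + sum(xs1) / len(xs1)) / 2
        with open(self.model_path, "w") as f:
            json.dump(self.params, f)

    def _accuracy(self, xs, ys):
        t = self.params["threshold"]
        preds = [1 if x >= t else 0 for x in xs]
        return sum(p == y for p, y in zip(preds, ys)) / len(ys)

    def score_without_restore(self, xs, ys):
        # Builds a fresh model but never loads the saved parameters.
        self.build_model()
        return self._accuracy(xs, ys)

    def score(self, xs, ys):
        # Builds the model, then restores the trained parameters first.
        self.build_model()
        with open(self.model_path) as f:
            self.params = json.load(f)
        return self._accuracy(xs, ys)


xs = [9.0, 10.0, 11.0, 19.0, 20.0, 21.0]
ys = [0, 0, 0, 1, 1, 1]
path = os.path.join(tempfile.mkdtemp(), "toy_model.json")

ToyClassifier(path).fit(xs, ys)              # train and save
fresh = ToyClassifier(path)                  # a new object, as in a later run
print(fresh.score_without_restore(xs, ys))   # 0.5 (untrained threshold)
print(fresh.score(xs, ys))                   # 1.0 (restored threshold)
```

The toy numbers say nothing about the 24% figure above; they only show that a scoring path which skips the restore step evaluates whatever the freshly built model was initialized with.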

So I changed the score() method to also build the model, initialize the TensorFlow ops, and restore the saved model before evaluating. The code is as follows:

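def score(self, test_X, test_Y):
    """Restore the saved model from self.model_path and return its accuracy on (test_X, test_Y)."""
    # Relies on the module-level imports already used in supervised_model.py
    # (tensorflow as tf and yadlt's tf_utils).
    # Labels must be one-hot encoded, i.e. a 2-D array with one column per class
    if len(test_Y.shape) != 1:
        num_classes = test_Y.shape[1]
    else:
        raise Exception("Please convert the labels with one-hot encoding.")
    with self.tf_graph.as_default():
        # Build the model graph
        self.build_model(test_X.shape[1], num_classes)
        with tf.Session() as self.tf_session:
            # Initialize TF stuff: merged summaries, summary writer, saver
            summary_objs = tf_utils.init_tf_ops(self.tf_session)
            self.tf_merged_summaries = summary_objs[0]
            self.tf_summary_writer = summary_objs[1]
            self.tf_saver = summary_objs[2]
            # Restore the trained weights from the saved checkpoint
            self.tf_saver.restore(self.tf_session, self.model_path)
            feed = {
                self.input_data: test_X,
                self.input_labels: test_Y,
                self.keep_prob: 1
            }
            # Evaluate accuracy in the current (default) session
            return self.accuracy.eval(feed)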

With this change, I can restore the trained model and predict with it directly. Hope this helps others.
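The modified score() raises an exception unless test_Y is one-hot encoded (a 2-D array with one column per class). Below is a minimal sketch of that conversion, plus an argmax-based accuracy that compares predicted and true class indices. It assumes NumPy and integer labels in 0..num_classes-1. `one_hot` and `argmax_accuracy` are hypothetical helpers, not yadlt functions, and whether yadlt's `self.accuracy` is computed exactly this way is an assumption.

```python
import numpy as np


def one_hot(labels, num_classes=None):
    """Convert integer class labels (0..num_classes-1) to a one-hot matrix."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = labels.max() + 1
    encoded = np.zeros((labels.shape[0], num_classes), dtype=np.float32)
    # Set a single 1.0 per row, in the column given by that row's label
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded


def argmax_accuracy(probs, one_hot_labels):
    """Fraction of rows where the most probable class equals the true class."""
    predicted = np.argmax(probs, axis=1)
    actual = np.argmax(one_hot_labels, axis=1)
    return float(np.mean(predicted == actual))


test_Y = one_hot([2, 0, 1, 2])
print(test_Y.shape)  # (4, 3)
probs = np.array([[0.1, 0.2, 0.7],
                  [0.8, 0.1, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.2, 0.1, 0.7]])
print(argmax_accuracy(probs, test_Y))  # 0.75
```

Passing num_classes explicitly is safer when a batch of test labels might not contain the highest class index.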

yangpanroy, Dec 19 '17 03:12