multi-class-text-classification-cnn

How to get class for given text input?

Open GraphGrailAi opened this issue 8 years ago • 7 comments

Could you provide some example code showing how to get the class output for a given text input?

I was able to get all the code working with ./data/small_samples.json, but the output is an accuracy percentage. I need the exact class name for every text.

GraphGrailAi avatar Jan 22 '17 20:01 GraphGrailAi

  1. When you run train.py, labels.json will be saved. labels.json is a list of all labels.

  2. When you run predict.py, take a look at line 63. If you print batch_predictions, it is a list of numbers, and each number is an index into labels.json.

For example, I printed batch_predictions: [6 6 6 4 6 4 3 6 4 6 3 4 1 2 3 2 3 2 4 0 4 4 4 3 4 6 4 4 1 4 0 6 2 4 4 6 3 3 1 3 4 4 3 4 3 6 3 6 6 6]. The first number in batch_predictions is 6, so the corresponding label is labels.json[6], which is mortgage.

Hope this will help you find the corresponding labels.

jiegzhan avatar Jan 25 '17 18:01 jiegzhan

Thanks for the answer. I had already guessed this myself, and I tested that the list of predicted label indices is all_predictions (not batch_predictions). When I printed batch_predictions, it returned an empty list [].

predict.py from line 63:

			for x_test_batch in batches:
				batch_predictions = sess.run(predictions, {input_x: x_test_batch, dropout_keep_prob: 1.0})
				all_predictions = np.concatenate([all_predictions, batch_predictions])

	if y_test is not None:
		y_test = np.argmax(y_test, axis=1)
		correct_predictions = sum(all_predictions == y_test)
		logging.critical('The batch_predictions is: {}'.format(batch_predictions))
		logging.critical('The all_predictions is: {}'.format(all_predictions))
		logging.critical('The y_test is: {}'.format(y_test)) # y_test is label list in labels.json
		logging.critical('The correct_predictions is: {}'.format(correct_predictions))
		logging.critical('The accuracy is: {}'.format(correct_predictions / float(len(y_test))))

output:

d:\Django\multi-class-text-classification-cnn>python predict.py ./trained_model_1485334811/ ./data/small_samples_my.json
CRITICAL:root:Loaded the trained model: d:\Django\multi-class-text-classification-cnn\trained_model_1485334811\checkpoints\model-300
INFO:root:The number of x_test: 5
INFO:root:The number of y_test: 5
CRITICAL:root:The batch_predictions is: []
CRITICAL:root:The all_predictions is: [ 10.  10.  10.   8.  10.]
CRITICAL:root:The y_test is: [10  6 10  8  9]
CRITICAL:root:The correct_predictions is: 3
CRITICAL:root:The accuracy is: 0.6

GraphGrailAi avatar Jan 25 '17 18:01 GraphGrailAi

Actually, for each batch, there will be a batch_predictions list, which will be appended to all_predictions.

Eventually, if you have 100 test examples, all_predictions will have 100 numbers. Each number is the corresponding index in labels.json. You can get the actual label by referring to labels.json[index].
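
A minimal sketch of that lookup, not code from the repository. It assumes labels.json has already been loaded into a Python list (the label names below are hypothetical placeholders, and `indices_to_labels` is a helper name made up for illustration). The log output earlier in this thread shows all_predictions as floats, so the indices are cast to int before indexing.

```python
import numpy as np

def indices_to_labels(predictions, labels):
    """Map predicted class indices to label names (hypothetical helper)."""
    # all_predictions is printed as floats in the log above, so cast to int first
    indices = np.asarray(predictions).astype(int)
    return [labels[i] for i in indices]

# Placeholder labels standing in for the real contents of labels.json.
# In practice you would load the saved file instead, for example:
#   import json
#   with open('labels.json') as f:   # path is an assumption, adjust to where train.py saved it
#       labels = json.load(f)
labels = ['label_%d' % i for i in range(11)]

# Values copied from the all_predictions log line earlier in this thread
all_predictions = np.array([10., 10., 10., 8., 10.])

predicted_labels = indices_to_labels(all_predictions, labels)
print(predicted_labels)
```

With the real labels.json loaded, the same call returns one class name per test example, in the same order as the input texts.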

jiegzhan avatar Jan 25 '17 18:01 jiegzhan

Thank you! I will create another issue for my other question.

GraphGrailAi avatar Jan 26 '17 09:01 GraphGrailAi

Has anyone figured this out? I need the prediction score for each class it predicts. For example, if the text belongs to a single class, I need to know the probability that the text belongs to that class. Any help would keep me moving.

akki2825 avatar Apr 28 '17 09:04 akki2825

@akki2825 Were you able to find a solution for predicting the probability of the classified text?

@GraphGrailAi did you mean that you got the accuracy for each class prediction, or the accuracy of the whole model?

vijaysaimutyala avatar Jan 29 '18 06:01 vijaysaimutyala

Has anyone found a way to print the probability of each sentence prediction?

Chinguun8 avatar Sep 10 '19 19:09 Chinguun8
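
One common way to get per-class probabilities, sketched here under assumptions: it requires the raw per-class scores (logits) that the model computes before taking the argmax. This thread does not say which tensor holds them in predict.py, so the `scores` array below is made-up example data and the fetch step is left out. Applying a softmax to each row turns the scores into probabilities that sum to 1.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax: converts raw class scores into probabilities."""
    scores = np.asarray(scores, dtype=float)
    # Subtract the row maximum for numerical stability; the result is unchanged
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Made-up scores for 2 texts and 3 classes (assumption: one row per input text)
scores = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])

probabilities = softmax(scores)
predicted_index = probabilities.argmax(axis=1)   # same index the argmax over scores gives
confidence = probabilities.max(axis=1)           # probability of the predicted class

for i, (idx, conf) in enumerate(zip(predicted_index, confidence)):
    print('text %d: class index %d, probability %.3f' % (i, idx, conf))
```

The predicted index can then be looked up in labels.json exactly as described earlier in the thread.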