deeplearning4j-examples
PredictGender test investigate
Looks like predictions are very skewed towards one class even though training metrics look okay. Assign me.
Also adding a note on reworking the example to better align with best practices.
It predicts only one class because it needs tuning. Keeping the issue open for reworking the example to better align with best practices.
The root of the issue is maxNameLength = 88; in the code. It affects the number of units in the network, and most of these have a 0 input value because of padding. In the data I used, the average name length is 10 characters. If we cap the input name length at 88, 50, 25, and 10, the confusion matrix changes each time: 88 causes the network to predict only one category because most of the input values are 0. The best predictions came from a maxNameLength of 10, which was the average name length in the data I used.

To cap the input name length, I set maxNameLength and touched two lines of code in the method nameToBinary() in class GenderRecordReader. Here are the two lines (one added, one changed):

    int nameLength = maxLengthName < name.length() ? maxLengthName : name.length(); //rgw cap name length
    for (int j = 0; j < nameLength; j++) { //rgw name.length()

maxNameLength = 88;
=========================Confusion Matrix=========================
     0     1
     0 13142 | 0 = 0
     0 13142 | 1 = 1
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
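To make the padding argument concrete, here is a minimal sketch of how much of the input is padding for a typical 10-character name at each cap. The class NamePaddingSketch and its methods are hypothetical and are not part of GenderRecordReader; the sketch counts character positions only and assumes every unused position is filled with zeros.

```java
// Hypothetical helper for illustration; not the example's GenderRecordReader.
class NamePaddingSketch {

    // Cap the name at maxLength characters, using the same ternary as the fix above.
    static String capName(String name, int maxLength) {
        int nameLength = maxLength < name.length() ? maxLength : name.length();
        return name.substring(0, nameLength);
    }

    // Fraction of the maxLength character positions that would be zero padding
    // (assumption: unused positions are filled with zeros).
    static double paddingFraction(int nameLength, int maxLength) {
        int used = Math.min(nameLength, maxLength);
        return (maxLength - used) / (double) maxLength;
    }

    public static void main(String[] args) {
        int typicalLength = 10; // average name length reported in this thread
        for (int maxLength : new int[] {88, 50, 25, 10}) {
            System.out.println("maxLength=" + maxLength
                    + " paddingFraction=" + paddingFraction(typicalLength, maxLength));
        }
    }
}
```

Under these assumptions, a 10-character name leaves roughly 89% of the positions as padding at a cap of 88, 80% at 50, 60% at 25, and none at 10.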
maxNameLength = 50;
=========================Confusion Matrix=========================
     0     1
  9907  3235 | 0 = 0
  2910 10232 | 1 = 1
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
maxNameLength = 25;
=========================Confusion Matrix=========================
     0     1
 10132  3010 | 0 = 0
  3477  9665 | 1 = 1
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
maxNameLength = 10; //average name length in the data I used
=========================Confusion Matrix=========================
     0     1
 10784  2358 | 0 = 0
  3512  9630 | 1 = 1
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
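For a single number to compare the four runs, overall accuracy can be computed from each matrix as the diagonal sum divided by the total. The sketch below is a standalone illustration using the counts posted above; the class name ConfusionMatrixAccuracy is hypothetical and it does not use the DL4J evaluation classes.

```java
// Hypothetical standalone helper; the counts are copied from the matrices in this thread.
class ConfusionMatrixAccuracy {

    // Rows are actual classes, columns are predicted classes.
    static double accuracy(long[][] matrix) {
        long correct = 0;
        long total = 0;
        for (int row = 0; row < matrix.length; row++) {
            for (int col = 0; col < matrix[row].length; col++) {
                total += matrix[row][col];
                if (row == col) {
                    correct += matrix[row][col]; // diagonal = correct predictions
                }
            }
        }
        return (double) correct / total;
    }

    public static void main(String[] args) {
        int[] caps = {88, 50, 25, 10};
        long[][][] matrices = {
            {{0, 13142}, {0, 13142}},
            {{9907, 3235}, {2910, 10232}},
            {{10132, 3010}, {3477, 9665}},
            {{10784, 2358}, {3512, 9630}},
        };
        for (int i = 0; i < caps.length; i++) {
            System.out.println("maxNameLength=" + caps[i]
                    + " accuracy=" + accuracy(matrices[i]));
        }
    }
}
```

With these counts the accuracy is 0.50 at 88, about 0.77 at 50, about 0.75 at 25, and about 0.78 at 10, which matches the observation that a cap of 10 gave the best predictions.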