tsai
Split plot refinement (accurate labeling)
I called `get_splits(train_labels, valid_size=.2, stratify=True, random_state=23, shuffle=True)`.
In this case, I would expect the second label to be "Valid" instead of "Test". I'm not specifying a test split (`test_size` defaults to 0), but I am specifying `valid_size`, so the labels should be "Train" and "Valid". The plot, however, labels them "Train" and "Test".
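To make the expected result concrete, here is a minimal, stdlib-only sketch of what a stratified, shuffled train/validation split produces conceptually. It is not tsai's `get_splits` implementation, and the helper `stratified_train_valid_split` is hypothetical. The point is that a call with only `valid_size` yields two lists, which should be read as (train, valid).

```python
import random
from collections import defaultdict


def stratified_train_valid_split(labels, valid_size=0.2, random_state=23):
    """Hypothetical helper: return (train_idx, valid_idx), taking roughly
    `valid_size` of each class for validation (not tsai's get_splits)."""
    rng = random.Random(random_state)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)  # group sample indices by class label
    train_idx, valid_idx = [], []
    for idxs in by_class.values():
        idxs = idxs[:]  # copy before shuffling in place
        rng.shuffle(idxs)
        n_valid = round(len(idxs) * valid_size)  # per-class validation count
        valid_idx.extend(idxs[:n_valid])
        train_idx.extend(idxs[n_valid:])
    rng.shuffle(train_idx)
    rng.shuffle(valid_idx)
    return train_idx, valid_idx


labels = ["a"] * 50 + ["b"] * 30 + ["c"] * 20
train, valid = stratified_train_valid_split(labels, valid_size=.2)
print(len(train), len(valid))  # 80 20
```

Each class keeps its share in both lists (here 10/6/4 validation samples for classes a/b/c), and no test list exists at all.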
I made a small change to `plot_splits()` to handle this case. It previously assumed that one split, i.e., two lists, always means (train, test). Since validation data is not optional in the split generation function, I treat it as mandatory, so a combination of only "Train" and "Test" is not possible.
Behavior now:
- specify `valid_size` and `test_size` -> three lists, labeled "Train", "Valid" and "Test"
- specify only `valid_size` -> two lists, labeled "Train" and "Valid"
- specify only `test_size` -> three lists, labeled "Train", "Valid" and "Test" (since validation data is mandatory)
- specify `test_size` and set `valid_size` to 0 -> two lists, labeled "Valid" and "Test" (in this case valid == train)
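The labeling rules above can be sketched as a small function. This is an illustration only: `split_labels` and its `valid_is_train` flag are hypothetical names, since the PR text does not name the new `plot_splits()` parameter. The flag stands in for the `valid_size=0` case, which is the only way two lists can mean (valid, test) instead of (train, valid).

```python
def split_labels(n_lists, valid_is_train=False):
    """Hypothetical sketch of the labeling rules described above.

    n_lists: number of index lists returned by the split function.
    valid_is_train: hypothetical flag for test_size > 0 with valid_size == 0,
        where the first list serves as both training and validation data.
    """
    if n_lists == 3:
        # Both valid and test lists are present (test_size > 0).
        return ["Train", "Valid", "Test"]
    if n_lists == 2:
        # Validation data is treated as mandatory, so two lists are either
        # (train, valid) or, when valid_size == 0, (valid == train, test).
        return ["Valid", "Test"] if valid_is_train else ["Train", "Valid"]
    raise ValueError(f"expected 2 or 3 split lists, got {n_lists}")


print(split_labels(2))                       # ['Train', 'Valid']
print(split_labels(2, valid_is_train=True))  # ['Valid', 'Test']
```

Giving the flag a default means calls that never pass it still work, which mirrors the compatibility concern mentioned in this PR.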
In my opinion, this is reasonable labeling behavior, given that validation data is mandatory.
I also gave the new parameter in `plot_splits()` a default value so that it does not break existing code.