Incorrect Behavior of `get_splits()` with `balance=True` and `test_size` Specified
Description:
Hello everyone,
I've encountered an issue with the `get_splits()` function when using `balance=True` together with `test_size`. In that case, the training split is no longer balanced across classes.
Reproduction Steps:
Environment:
os : Linux-5.4.0-100-generic-x86_64-with-glibc2.31
python : 3.10.10
tsai : 0.3.6
fastai : 2.7.12
fastcore : 1.5.29
torch : 2.0.0
Here's a code snippet to reproduce the issue:
import numpy as np
from tsai.all import *
# Define the target array
y = np.array([0]*10 + [1]*10 + [2]*30)
# Call `get_splits()` with `balance=True` and `valid_size` specified
splits = get_splits(y, valid_size=0.2, balance=True, random_state=2023)
print(np.unique(y[splits[0]], return_counts=True))
This works as expected; every class has the same number of training samples:
(array([0, 1, 2]), array([24, 24, 24]))
However, when the test_size parameter is also specified, the behavior is incorrect:
y = np.array([0]*10 + [1]*10 + [2]*30)
splits = get_splits(y, valid_size=0.2, test_size=0.2, balance=True, random_state=2023)
print(np.unique(y[splits[0]], return_counts=True))
The output is:
(array([0, 1, 2]), array([8, 7, 27]))
The training split should still contain the same number of items per class when `test_size` is specified, but here the class counts are 8, 7, and 27. Something seems to go wrong when `test_size` is used in conjunction with `balance=True`.
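To make the expected behavior concrete, here is a minimal sketch of what I would expect the training indices to look like after balancing. It assumes that `balance=True` is meant to oversample minority classes up to the size of the largest class, which is what the first output suggests. `oversample_to_balance` is a hypothetical helper written for this report, not part of tsai, and the `train_idx` array is only a stand-in with the same class counts (8, 7, 27) as the unbalanced output above.

```python
import numpy as np

def oversample_to_balance(y, train_idx, random_state=None):
    """Hypothetical helper (not a tsai function): oversample training
    indices with replacement so every class matches the largest class."""
    rng = np.random.default_rng(random_state)
    train_idx = np.asarray(train_idx)
    classes, counts = np.unique(y[train_idx], return_counts=True)
    target = counts.max()
    balanced = []
    for cls, count in zip(classes, counts):
        cls_idx = train_idx[y[train_idx] == cls]
        # Draw the missing samples from this class only (may be zero draws)
        extra = rng.choice(cls_idx, size=target - count, replace=True)
        balanced.append(np.concatenate([cls_idx, extra]))
    # Shuffle so the classes are not grouped together
    return rng.permutation(np.concatenate(balanced))

y = np.array([0]*10 + [1]*10 + [2]*30)
# Stand-in training split with class counts 8, 7, 27 (assumed indices)
train_idx = np.concatenate([np.arange(0, 8), np.arange(10, 17), np.arange(20, 47)])

balanced_idx = oversample_to_balance(y, train_idx, random_state=2023)
classes, counts = np.unique(y[balanced_idx], return_counts=True)
print(classes, counts)  # [0 1 2] [27 27 27]
```

Because the helper only resamples indices that are already in the training split, it cannot leak validation or test items into training. It could serve as a temporary workaround applied to `splits[0]`, but it is not a fix for `get_splits()` itself.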
I would appreciate any help in resolving this issue. Thank you!