pytorch-forecasting
Unable to use WeightedRandomSampler in TimeSeriesDataset (classification)
- PyTorch-Forecasting version: 0.10.3
- Operating System: linux
Expected behavior
I am using the pytorch-forecasting TFT model for a classification problem, so my target is categorical. I am trying to build a dataloader from a TimeSeriesDataSet with WeightedRandomSampler as the batch_sampler, because the dataset is imbalanced (10% positive class).
Actual behavior
After generating the dataloader, I can neither iterate over it nor train a model with it. Both fail with the following error: TypeError: 'int' object is not iterable
Code to reproduce the problem
import numpy as np
from pytorch_forecasting import TimeSeriesDataSet
from torch.utils.data.sampler import WeightedRandomSampler

# Per-sample weights from inverse class frequency
targets = dataset['targets'].tolist()  # here targets are float (0.0, 1.0)
class_counts = np.bincount(targets)
class_weights = 1.0 / class_counts
class_weights = class_weights.tolist()
sample_weights = [class_weights[int(x)] for x in targets]

# Make the target categorical for classification
dataset['targets'] = dataset['targets'].astype('str')

dataset_options = dict()
dataloader_options = dict(
    batch_size=128,
    num_workers=0,
    batch_sampler=WeightedRandomSampler(weights=sample_weights, num_samples=len(targets), replacement=True),
)

train_data = TimeSeriesDataSet(dataset, **dataset_options)
train_dataloader = train_data.to_dataloader(train=True, **dataloader_options)

for i, x in enumerate(train_dataloader):  # this raises the TypeError; model training fails the same way
    pass
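A likely cause, sketched with plain PyTorch only: the batch_sampler argument of DataLoader expects an iterable that yields lists of indices, while WeightedRandomSampler yields one integer index at a time, so the loader ends up trying to iterate over an int. The toy data and the names toy_dataset, direct_error and loader below are made up for illustration, and this is a minimal sketch of the mechanism rather than a confirmed fix for pytorch-forecasting.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced data standing in for the real dataset (hypothetical values).
labels = torch.tensor([0] * 90 + [1] * 10)
features = torch.arange(100, dtype=torch.float32).unsqueeze(1)
toy_dataset = TensorDataset(features, labels)

# One weight per sample: the inverse frequency of that sample's class.
class_counts = torch.bincount(labels)
sample_weights = (1.0 / class_counts.float())[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=len(labels), replacement=True
)

# Passing the sampler directly as batch_sampler fails: each "batch" it
# yields is a single int, which the loader then tries to iterate over.
try:
    next(iter(DataLoader(toy_dataset, batch_sampler=sampler)))
    direct_error = None
except TypeError as exc:
    direct_error = str(exc)

# Wrapping it in BatchSampler groups the sampled indices into lists,
# which is the shape batch_sampler expects.
batch_sampler = BatchSampler(sampler, batch_size=32, drop_last=False)
loader = DataLoader(toy_dataset, batch_sampler=batch_sampler)
first_features, first_labels = next(iter(loader))
```

If to_dataloader forwards batch_sampler to DataLoader unchanged (an assumption here, not verified against 0.10.3), the same BatchSampler wrapper may work there as well. Note also that the weights would presumably need one entry per sample of the TimeSeriesDataSet rather than one per dataframe row.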