pytorch-forecasting

Unable to use WeightedRandomSampler in TimeSeriesDataset (classification)

Open • adityamtr opened this issue 2 years ago • 0 comments

  • PyTorch-Forecasting version: 0.10.3
  • Operating System: Linux

Expected behavior

I am using the pytorch-forecasting TFT (Temporal Fusion Transformer) model for a classification problem, so my target is categorical. Because the dataset is imbalanced (about 10% positive class), I am trying to build a dataloader from a TimeSeriesDataSet that uses a WeightedRandomSampler as its batch_sampler.
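The balancing idea can be made concrete: with inverse-frequency weights every class carries the same total weight, so a weighted draw with replacement picks each class with roughly equal probability. The NumPy sketch below is illustrative only (the helper name `inverse_frequency_weights` is hypothetical, and it assumes every class id 0..K-1 occurs at least once). It computes such weights for a toy label vector with 10% positives and checks the resulting positive share.

```python
import numpy as np

def inverse_frequency_weights(labels):
    """Hypothetical helper: one weight per sample, 1 / (size of that sample's class)."""
    # np.bincount accepts only non-negative integers, so float labels such as
    # 0.0 / 1.0 are cast to int first. Assumes every class 0..K-1 is present.
    labels = np.asarray(labels).astype(int)
    class_counts = np.bincount(labels)
    return (1.0 / class_counts)[labels]

# Toy labels with a 10% positive class: 90 negatives, 10 positives.
labels = np.array([0.0] * 90 + [1.0] * 10)
weights = inverse_frequency_weights(labels)

# Probability that a single weighted draw selects a positive sample.
positive_share = weights[labels == 1].sum() / weights.sum()
```

With these weights, a `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` would draw positives about half the time; the actual counts vary from draw to draw.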

Actual behavior

After creating the dataloader, I can neither iterate over it nor train a model with it. Both fail with the following error: TypeError: 'int' object is not iterable
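The error message is consistent with how `torch.utils.data.DataLoader` treats `batch_sampler`: it expects every item the sampler yields to be a list of indices forming one batch, whereas `WeightedRandomSampler` yields one integer index at a time. The toy example below uses plain PyTorch with a made-up `TensorDataset` (not pytorch-forecasting). It is a sketch that reproduces the same `TypeError` and shows the usual remedy of wrapping the sampler in `BatchSampler`.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset (hypothetical data): 10 samples, 10% positive class.
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.tensor([0] * 9 + [1])
toy_dataset = TensorDataset(features, labels)

# Inverse-frequency weights: 1/9 for each negative, 1 for the single positive.
sample_weights = torch.tensor([1 / 9] * 9 + [1.0], dtype=torch.double)
weighted_sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# WeightedRandomSampler yields single indices (Python ints), not lists of indices.
first_item = next(iter(weighted_sampler))

# Passing it directly as batch_sampler reproduces the reported error: the
# DataLoader treats each int as a "batch" and tries to iterate over it.
error_message = None
try:
    next(iter(DataLoader(toy_dataset, batch_sampler=weighted_sampler)))
except TypeError as exc:
    error_message = str(exc)

# BatchSampler groups the sampled indices into lists, which is the form
# the batch_sampler argument expects.
batch_sampler = BatchSampler(weighted_sampler, batch_size=4, drop_last=True)
loader = DataLoader(toy_dataset, batch_sampler=batch_sampler)
batch_features, batch_labels = next(iter(loader))
```

An alternative with the same effect in plain PyTorch is to pass the weighted sampler as `sampler=` together with `batch_size=` and leave `shuffle` unset, since `sampler` and `shuffle=True` cannot be combined.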

Code to reproduce the problem

```python
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler
from pytorch_forecasting import TimeSeriesDataSet

targets = dataset['targets'].tolist()  # here targets are float (0.0, 1.0)
class_counts = np.bincount(np.asarray(targets, dtype=int))  # bincount needs integer input
class_weights = (1.0 / class_counts).tolist()
sample_weights = [class_weights[int(x)] for x in targets]

dataset['targets'] = dataset['targets'].astype('str')  # string target -> categorical
dataset_options = dict()  # left empty here; TimeSeriesDataSet needs at least time_idx, target, group_ids
dataloader_options = dict(
    batch_size=128,
    num_workers=0,
    batch_sampler=WeightedRandomSampler(weights=sample_weights, num_samples=len(targets), replacement=True),
)

train_data = TimeSeriesDataSet(dataset, **dataset_options)
train_dataloader = train_data.to_dataloader(train=True, **dataloader_options)

for i, x in enumerate(train_dataloader):  # raises the TypeError; model training fails the same way
    pass
```
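Assuming, as the error above suggests, that `to_dataloader` hands a `Sampler` passed as `batch_sampler` straight to `DataLoader`, one possible workaround is to wrap the weighted sampler in `BatchSampler`. A second point worth checking in the reproduction code is that `sample_weights` has one entry per DataFrame row, while the dataloader indexes `TimeSeriesDataSet` samples (subsequences), whose count `len(train_data)` generally differs from the number of rows. The sketch below adds a length check for that. `make_balanced_batch_sampler` is a hypothetical helper, and how to derive one label per `TimeSeriesDataSet` sample is left open.

```python
import numpy as np
import torch
from torch.utils.data import BatchSampler, WeightedRandomSampler

def make_balanced_batch_sampler(sample_labels, dataset_length, batch_size, drop_last=True):
    """Hypothetical helper: class-balanced batches for a map-style dataset.

    sample_labels must hold one integer class label per dataset sample. For a
    TimeSeriesDataSet that means len(train_data) labels, which in general is
    not the number of rows in the source DataFrame.
    """
    labels = np.asarray(sample_labels).astype(int)
    if len(labels) != dataset_length:
        raise ValueError(f"expected {dataset_length} labels, got {len(labels)}")
    class_weights = 1.0 / np.bincount(labels)
    weights = torch.as_tensor(class_weights[labels], dtype=torch.double)
    # WeightedRandomSampler yields one index at a time ...
    sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
    # ... and BatchSampler groups them into lists, the form batch_sampler expects.
    return BatchSampler(sampler, batch_size=batch_size, drop_last=drop_last)

# Hypothetical wiring into pytorch-forecasting (not executed here):
# sample_labels = ...  # one label per TimeSeriesDataSet sample
# train_dataloader = train_data.to_dataloader(
#     train=True,
#     batch_sampler=make_balanced_batch_sampler(sample_labels, len(train_data), batch_size=128),
# )

# Toy check: 90 negatives, 10 positives, batches of 16 with drop_last=True.
toy_labels = [0] * 90 + [1] * 10
batches = list(make_balanced_batch_sampler(toy_labels, dataset_length=100, batch_size=16))
```

With 100 draws and `drop_last=True`, the toy sampler produces 6 full batches of 16 indices; the trailing 4 indices are dropped.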

adityamtr • Oct 03 '23 14:10