
GMM Training with Mini-Batch

Open hashim19 opened this issue 6 months ago • 0 comments

Hi, first of all, thank you for the amazing repository.

I am trying to do mini-batch training of a GMM. After going over #51, #19, and #7, I realized that I need to create my own dataset loader. Here is a sample of my custom dataset (each file of my dataset is stored as a .pkl file, so I wrote a `PKL_dataset` class):

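```python
import os
import pickle

import numpy as np
from torch.utils.data import Dataset


def open_pkl(file_path):
    # Assumed implementation: the original open_pkl helper is not shown in the issue
    with open(file_path, 'rb') as f:
        return pickle.load(f)


# Dataset class: one item per .pkl file, returned as a (2000, num_features) array
class PKL_dataset(Dataset):

    def __init__(self, dataset_pth, data_label):
        self.data_dir = dataset_pth
        self.files_ls = os.listdir(dataset_pth)
        self.len = len(self.files_ls)
        self.label = data_label

    def __len__(self):
        return self.len

    def transform(self, data):
        # Pad short files to 2000 frames with the per-feature mean; truncate longer ones
        if data.shape[0] < 2000:
            return np.pad(data, [(0, 2000 - data.shape[0]), (0, 0)], 'mean')
        return data[:2000]

    def __getitem__(self, idx):
        file_path = os.path.join(self.data_dir, self.files_ls[idx])
        pkl_data = open_pkl(file_path)
        return self.transform(pkl_data)
```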
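To make the `transform` step concrete, the sketch below applies the same pad-or-truncate rule to small synthetic arrays (random numbers, not real features; the 20-feature width and the helper name `pad_or_truncate` are arbitrary choices for illustration). With `np.pad(..., 'mean')`, every padded row equals the column-wise mean of the original frames, so each item ends up with shape `(2000, num_features)`.

```python
import numpy as np

TARGET_FRAMES = 2000  # same fixed length as in PKL_dataset.transform


def pad_or_truncate(data, target=TARGET_FRAMES):
    # Same rule as the dataset's transform method
    if data.shape[0] < target:
        return np.pad(data, [(0, target - data.shape[0]), (0, 0)], 'mean')
    return data[:target]


rng = np.random.default_rng(0)
short_clip = rng.standard_normal((500, 20))   # 500 frames, 20 features (synthetic)
long_clip = rng.standard_normal((3000, 20))   # 3000 frames, 20 features (synthetic)

padded = pad_or_truncate(short_clip)
truncated = pad_or_truncate(long_clip)

print(padded.shape, truncated.shape)
# Rows 500..1999 are copies of the per-feature mean of the 500 real frames
print(np.allclose(padded[500:], short_clip.mean(axis=0)))
```

One side effect worth keeping in mind: mean-padded rows are repeated identical frames, so a GMM fitted on them sees extra mass at each file's mean vector.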

Now I am creating the `GaussianMixture` estimator like this:

```python
from pycave.bayes import GaussianMixture

gmm = GaussianMixture(
    num_components=ncomp,
    covariance_type='diag',
    batch_size=32,
    covariance_regularization=0.1,
    init_strategy='kmeans++',
    trainer_params=dict(accelerator='gpu', devices=1, max_epochs=100),
)
```

and then passing my data loader to the `fit` method like this:

```python
history = gmm.fit(pkl_dataloader)
```
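The issue does not show how `pkl_dataloader` is built. Assuming it is a plain `torch.utils.data.DataLoader` over `PKL_dataset`, the default collate function stacks items into a 3-D batch of shape `(batch_size, 2000, num_features)`. Since the traceback shows `fit` inferring `num_features = len(data[0])`, a 3-D stack would yield 2000 there rather than the feature count. The hypothetical sketch below uses a synthetic stand-in dataset (`RandomUtterances`) and a custom `collate_fn` (`stack_frames`, also hypothetical) that concatenates frames into a 2-D `(frames, num_features)` batch. It only illustrates batch shapes in plain PyTorch; it does not show how pycave consumes batches.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomUtterances(Dataset):
    """Hypothetical stand-in for PKL_dataset: each item is a (num_frames, num_features) array."""

    def __init__(self, num_items=10, num_frames=2000, num_features=20, seed=0):
        rng = np.random.default_rng(seed)
        self.items = [
            rng.standard_normal((num_frames, num_features)).astype(np.float32)
            for _ in range(num_items)
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def stack_frames(batch):
    # Concatenate per-file matrices along the frame axis:
    # list of (num_frames, num_features) -> (len(batch) * num_frames, num_features)
    return torch.from_numpy(np.concatenate(batch, axis=0))


dataset = RandomUtterances()
default_batch = next(iter(DataLoader(dataset, batch_size=4)))
frame_batch = next(iter(DataLoader(dataset, batch_size=4, collate_fn=stack_frames)))

print(default_batch.shape)  # default collate: one extra leading batch axis
print(frame_batch.shape)    # custom collate: frames flattened into rows
```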

This raises the following error:

```
Traceback (most recent call last):
  File "asvspoof2021_baseline.py", line 65, in <module>
    gmm_bona = train_gmm(data_label=data_labels[0], features=features,
  File "/home/hashim/PhD/Audio_Spoof_Detection/Baseline-CQCC-GMM/python/gmm.py", line 218, in train_gmm
    history = gmm.fit(pkl_dataloader)
  File "/home/hashim/PhD/AsvSpoof2021/asvspoof_venv/lib/python3.8/site-packages/pycave/bayes/gmm/estimator.py", line 128, in fit
    num_features = len(data[0])
TypeError: 'DataLoader' object is not subscriptable
```

It looks like `fit` does not accept a `DataLoader`: it indexes `data[0]` to infer the number of features, and a `DataLoader` does not support indexing.
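The failure can be reproduced without pycave. The traceback shows `fit` evaluating `len(data[0])`; the sketch below wraps that same expression in a hypothetical helper, `infer_num_features`, and applies it to a tensor and to a `DataLoader` built on synthetic data. A tensor supports indexing, while a `DataLoader` only supports iteration, so indexing it raises the same `TypeError`.

```python
import torch
from torch.utils.data import DataLoader


def infer_num_features(data):
    # Same expression as estimator.py line 128 in the traceback above
    return len(data[0])


features = torch.randn(100, 20)  # 100 frames, 20 features (synthetic)
print(infer_num_features(features))

loader = DataLoader(features, batch_size=32)
try:
    infer_num_features(loader)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)

# A DataLoader is consumed by iteration instead of indexing
first_batch = next(iter(loader))
print(first_batch.shape)
```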

However, if I do not use a data loader and pass all the data at once, the training process gets killed because it runs out of memory.
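A rough way to see why loading everything at once runs out of memory is to compute the size of the full padded feature array, since `PKL_dataset` turns every file into a fixed `(2000, num_features)` block. The file count and feature dimension below are hypothetical placeholders (the issue does not state them), and `padded_dataset_bytes` is an illustrative helper. NumPy's default float dtype is `float64` (8 bytes per value), twice the size of `float32`.

```python
import numpy as np


def padded_dataset_bytes(num_files, num_features, frames_per_file=2000, dtype=np.float64):
    # Every file is padded/truncated to frames_per_file rows, so the stacked
    # array holds num_files * frames_per_file * num_features values.
    itemsize = np.dtype(dtype).itemsize
    return num_files * frames_per_file * num_features * itemsize


# Hypothetical sizes, for illustration only; substitute the real values
num_files = 10_000
num_features = 60

for dtype in (np.float64, np.float32):
    gib = padded_dataset_bytes(num_files, num_features, dtype=dtype) / 2**30
    print(f"{np.dtype(dtype).name}: {gib:.1f} GiB")
```

This counts only the final array; intermediate copies made while loading and stacking the files can push peak usage higher.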

System details: my GPU is an NVIDIA RTX A2000 with 12 GB of memory.

hashim19, Jan 02 '24 23:01