Mini-batch training on GMM

Open Daisy-GENG opened this issue 2 years ago • 3 comments

Hi,

I want to implement mini-batch training for a GMM as discussed in #7. However, I am a little confused by the code gmm.reset_parameters(torch.Tensor(fvectors[:500].astype(np.float32))). I am not sure whether the problem is related to my version of pycave or whether I have misunderstood the code in #7, but my code doesn't work.

My code is as follows:

import torch
from pycave.bayes.gmm import GaussianMixture as GM
from dataloader.gmm_dataset import gmm_dataset

train_gmm_dataset = gmm_dataset(data_path)
train_dataset_loader = torch.utils.data.DataLoader(dataset=train_gmm_dataset,
                                                        batch_size=train_dataloader_config["batch_size"],
                                                        shuffle=train_dataloader_config["shuffle"],
                                                        num_workers=train_dataloader_config["num_workers"])

for i, data in enumerate(train_dataset_loader):  # data:[1, pt, 3]
    data = torch.squeeze(data, 0)
    gmm = GM(num_components=2, covariance_type="diag", init_strategy="kmeans")
    gmm.model_.reset_parameters(data)  
    history = gmm.fit(train_dataset_loader)

And the error is:

`GaussianMixture` has not been fitted yet

Thank you so much!

Best regards, Daisy

Daisy-GENG, May 24 '22 16:05

Issue #7 still referred to PyCave version 2. In PyCave v3, you don't need to call gmm.model_.reset_parameters: the model_ attribute will only be available once fit has returned without error.

I believe that this should be the line that causes your error.
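To illustrate why the early access fails, here is a toy sketch of the general pattern where fitted attributes (names ending in an underscore) only exist after fit has returned. ToyEstimator is a hypothetical stand-in written for this example and is not PyCave's actual implementation; the exception type PyCave raises may differ.

```python
class ToyEstimator:
    """Toy stand-in (not PyCave code) for an estimator with fitted attributes."""

    def __init__(self, num_components=2):
        # Hyperparameters are stored right away.
        self.num_components = num_components

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. before fit() set the attribute.
        if name.endswith("_") and not name.startswith("_"):
            raise AttributeError(f"`{type(self).__name__}` has not been fitted yet")
        raise AttributeError(name)

    def fit(self, data):
        # A real estimator would learn parameters here; the toy stores the mean.
        self.model_ = sum(data) / len(data)
        return self


est = ToyEstimator()
try:
    est.model_  # accessed before fit(), like gmm.model_ in the loop above
except AttributeError as err:
    message = str(err)

est.fit([1.0, 2.0, 3.0])
fitted_value = est.model_  # available now that fit() has returned
```

The takeaway for the loop above is to drop the reset_parameters line and only touch model_ after fit has completed.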

borchero, May 24 '22 16:05

So is there a similar way to implement batch training in PyCave version 3 using a DataLoader? My whole dataset is large, so I cannot load all of the data into memory at once.

Thank you so much!

Best regards, Daisy

Daisy-GENG, May 24 '22 16:05

Ah, sorry! Yes, you can simply set the batch size when initializing the GMM. In your case, you might, for example, use:

gmm = GM(..., batch_size=8192)

This will automatically take care of loading data in batches, both for initialization and GMM training. Note that you might be better off with init_strategy='kmeans++' since kmeans is quite costly to run. You'll need PyCave 3.1.3 for that, though (there was a bug in kmeans++ initialization before).
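Putting the pieces together, a minimal sketch might look like the following. It assumes PyCave 3.1.3 or later is installed and that data is whatever tensor or array you would otherwise pass to fit. The helper name fit_gmm_in_batches and its defaults are made up for illustration; the constructor arguments are the ones already used in this thread.

```python
def fit_gmm_in_batches(data, num_components=2, batch_size=8192):
    """Fit a diagonal GMM and let the estimator split the data into batches.

    Hypothetical helper for illustration; not part of PyCave.
    """
    # Imported lazily so this snippet can be loaded without PyCave installed.
    from pycave.bayes.gmm import GaussianMixture

    gmm = GaussianMixture(
        num_components=num_components,
        covariance_type="diag",
        init_strategy="kmeans++",  # cheaper than "kmeans"; needs PyCave >= 3.1.3
        batch_size=batch_size,     # batching for both initialization and training
    )
    # No reset_parameters call: the estimator initializes itself inside fit(),
    # and gmm.model_ becomes available only after fit() has returned.
    gmm.fit(data)
    return gmm
```

Note that the estimator is created once and fitted once on the whole dataset, rather than being re-created inside a loop over DataLoader batches.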

borchero, May 24 '22 16:05