Deep-temporal-clustering
Are the centroids ever updated?
When the code clusters in the train_ClusterNET() function and computes Q and P in forward(), the similarity is computed from self.centroids, which always stay the same. Will the centroids ever be updated? I also tried adding steps to update them in training_function, but the loss does not seem to decrease. Can you help me? Thanks.
```python
def forward(self, x):
    z, x_reconstr = self.tae(x)
    z_np = z.detach().cpu()
    similarity = compute_similarity(
        z, self.centroids, similarity=self.similarity
    )
    ## Q (batch_size, n_clusters)
    Q = torch.pow((1 + (similarity / self.alpha_)), -(self.alpha_ + 1) / 2)
    sum_columns_Q = torch.sum(Q, dim=1).view(-1, 1)
    Q = Q / sum_columns_Q
    ## P : ground truth distribution
    P = torch.pow(Q, 2) / torch.sum(Q, dim=0).view(1, -1)
    sum_columns_P = torch.sum(P, dim=1).view(-1, 1)
    P = P / sum_columns_P
    return z, x_reconstr, Q, P
```
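To make the arithmetic in forward() concrete, here is a small NumPy sketch of the same Q/P computation on toy data. It assumes the similarity is the squared Euclidean distance, which may not match the metric compute_similarity actually uses, and the function name soft_assign is hypothetical. Each row of Q is one sample's soft assignment; P squares Q and divides by the column sums (the soft cluster sizes), then renormalizes each row.

```python
import numpy as np


def soft_assign(z, centroids, alpha=1.0):
    """Student's t soft assignment Q and sharpened target P (hypothetical helper).

    Assumes squared Euclidean distance as the similarity measure;
    compute_similarity in the repository may use a different metric.
    """
    # Squared distances, shape (batch_size, n_clusters)
    dist = ((z[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    # Unnormalized Student's t kernel, same expression as in forward()
    q = (1.0 + dist / alpha) ** (-(alpha + 1.0) / 2.0)
    # Normalize each row so every sample's assignments sum to 1
    q = q / q.sum(axis=1, keepdims=True)
    # Square Q and divide by the soft cluster frequencies (column sums)
    p = q ** 2 / q.sum(axis=0, keepdims=True)
    # Renormalize rows to obtain the target distribution
    p = p / p.sum(axis=1, keepdims=True)
    return q, p


z = np.array([[0.0, 0.0], [0.2, 0.1], [3.0, 3.0]])
centroids = np.array([[0.0, 0.0], [3.0, 3.0]])
Q, P = soft_assign(z, centroids)
print(Q.round(3))
print(P.round(3))
```

With these toy points, P assigns the first and last samples to their nearest centroid with higher probability than Q does; this sharpening is what the KL(P || Q) loss pulls the model toward.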
Excuse me, have you solved this problem?
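One possible cause, offered only as a guess because the issue does not show how self.centroids is created: if the centroids are stored as a plain tensor (for example, k-means centers wrapped with torch.tensor) rather than as an nn.Parameter, they are not returned by model.parameters(), so the optimizer never touches them. The minimal PyTorch sketch below uses a hypothetical ClusterHead module and assumes squared Euclidean distance to show a registered centroid tensor changing after one optimizer step. Unlike the forward() above, it computes P under torch.no_grad() so that P acts as a fixed target, as in DEC; whether the repository intends gradients to flow through P is not stated here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClusterHead(nn.Module):
    """Hypothetical module containing only the clustering part of the model."""

    def __init__(self, n_clusters, latent_dim, alpha=1.0):
        super().__init__()
        self.alpha_ = alpha
        # nn.Parameter registers the centroids with the module, so
        # model.parameters() (and therefore the optimizer) includes them.
        self.centroids = nn.Parameter(torch.randn(n_clusters, latent_dim))

    def forward(self, z):
        # Squared Euclidean distance (an assumption), shape (batch, n_clusters)
        dist = ((z[:, None, :] - self.centroids[None, :, :]) ** 2).sum(dim=-1)
        q = (1 + dist / self.alpha_) ** (-(self.alpha_ + 1) / 2)
        return q / q.sum(dim=1, keepdim=True)


class FrozenHead(nn.Module):
    """Hypothetical counterexample: a plain tensor attribute is NOT registered."""

    def __init__(self):
        super().__init__()
        self.centroids = torch.randn(3, 4)


def target_distribution(q):
    # No gradient through P: it is treated as a fixed target
    with torch.no_grad():
        p = q ** 2 / q.sum(dim=0, keepdim=True)
        return p / p.sum(dim=1, keepdim=True)


torch.manual_seed(0)
model = ClusterHead(n_clusters=3, latent_dim=4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
z = torch.randn(16, 4)

before = model.centroids.detach().clone()
q = model(z)
p = target_distribution(q)
# F.kl_div expects log-probabilities as input and probabilities as target
loss = F.kl_div(q.log(), p, reduction="batchmean")
optimizer.zero_grad()
loss.backward()
optimizer.step()

changed = not torch.equal(before, model.centroids.detach())
print("centroids changed:", changed)
print("FrozenHead parameters:", len(list(FrozenHead().parameters())))  # 0
```

If the real model keeps its centroids as a plain tensor, re-assigning them as nn.Parameter before the optimizer is constructed would be one way to make them trainable.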