
[BUG] GMM.sample(): probabilities do not sum to 1

Open profPlum opened this issue 2 years ago • 0 comments

Describe the bug

When fitting a GMM to a dataset using Bernoulli and Normal distributions and then calling sample(), it throws an error: probabilities do not sum to 1

To Reproduce

import pandas as pd
import numpy as np
import pomegranate as pgm
from pomegranate import *
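# dataset from here: https://www.kaggle.com/datasets/alexteboul/diabetes-health-indicators-dataset
# NOTE: `dir` must be defined beforehand as the path prefix of the folder holding the CSV
# (as written it shadows the builtin dir(), which cannot be concatenated with a string)
data = pd.read_csv(dir + 'diabetes_binary_health_indicators_BRFSS2015.csv')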

# verified to work! 4/3/22
# prepare df for PGMs which require discrete data only!
def make_df_categorical(data, max_cols=9, required_cols=[], max_vals_for_categorical=15):
  unique_values = {col: len(data[col].unique()) for col in data}
  print(unique_values)
  categorical_candidates = [col for col in data if len(data[col].unique()) < max_vals_for_categorical]
  categorical_candidates = list(set(categorical_candidates) - set(required_cols))
  categorical_candidates = list(np.random.choice(categorical_candidates, max_cols-len(required_cols), replace=False)) + list(required_cols)
  data_down_sample = data[categorical_candidates].apply(lambda x: x.factorize()[0])
  return data_down_sample

def fit_GMM(data, model=None, n_components=5, default_dist=NormalDistribution):
  is_binom = lambda x: np.all(np.isin(x, [1, 0, True, False]))
  distributions = [(BernoulliDistribution if is_binom(data[x]) else default_dist) for x in data.columns]
  GMM = GeneralMixtureModel.from_samples(distributions, n_components=n_components, X=data)
  return GMM
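# data_down_sample was never assigned in the original snippet; this call is assumed to be the intended one
data_down_sample = make_df_categorical(data)
GMM = fit_GMM(data_down_sample)#, default_dist=DiscreteDistribution)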
GMM.sample()

Most of the time this code fails with: probabilities do not sum to 1. It appears to happen only when at least one column falls back to default_dist (NormalDistribution); it does not happen when, for example, all of the down-sampled columns are binary. NOTE: the dataset contains only binary and integer data.
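
For anyone trying to narrow this down: the message is identical to the ValueError that NumPy's random.choice raises when its p argument does not sum to 1. Whether sample() reaches that call with the mixture weights is an assumption on my part, not something I have confirmed in the pomegranate source. The sketch below reproduces the message with plain NumPy and adds a small check for a vector of log weights; check_mixture_weights is a hypothetical helper, not part of pomegranate.

```python
import numpy as np

def check_mixture_weights(log_weights, atol=1e-8):
    """Hypothetical helper: exponentiate log weights and report whether they sum to 1.

    Returns (ok, total), where total is the sum of the exponentiated weights.
    """
    probs = np.exp(np.asarray(log_weights, dtype=float))
    total = probs.sum()
    # A NaN or infinite total also counts as a failure.
    ok = bool(np.isfinite(total) and abs(total - 1.0) <= atol)
    return ok, total

# Probabilities that do not sum to 1 make NumPy raise the same message.
bad = np.array([0.5, 0.3])
try:
    np.random.choice(len(bad), p=bad)
    message = None
except ValueError as err:
    message = str(err)

print(message)
print(check_mixture_weights(np.log([0.25, 0.75])))
print(check_mixture_weights(np.log(bad)))
```

On a fitted model, something like check_mixture_weights(GMM.weights) could show whether the weights are the problem, assuming GMM.weights holds log weights in this pomegranate version.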

profPlum avatar May 10 '22 19:05 profPlum