pomegranate
[BUG] GeneralMixtureModel of ICD of DiscreteDistributions and numeric distributions does not fit
First of all, thanks for all the work on pomegranate!
Possibly related to #402, there seems to be an issue (or a failure of understanding on my part, which is also likely) with a `GeneralMixtureModel` of `IndependentComponentsDistribution`s that combine `DiscreteDistribution`s with another continuous/numeric distribution (`PoissonDistribution`, `GammaDistribution`, etc.). Initializing the model works fine, but calling `fit` fails with a `KeyError`. Specifically, I get the error:
```
Traceback (most recent call last):
  File "pomegranate/utils.pyx", line 429, in pomegranate.utils._check_input
ValueError: could not convert string to float: 'A'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "pomegranate/gmm.pyx", line 220, in pomegranate.gmm.GeneralMixtureModel.fit
  File "pomegranate/gmm.pyx", line 231, in pomegranate.gmm.GeneralMixtureModel.fit
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/parallel.py", line 1043, in __call__
    if self.dispatch_one_batch(iterator):
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/parallel.py", line 861, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/parallel.py", line 779, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/_parallel_backends.py", line 572, in __init__
    self.results = batch()
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/parallel.py", line 262, in __call__
    return [func(*args, **kwargs)
  File "/home/leland/anaconda3/envs/pomegranate/lib/python3.9/site-packages/joblib/parallel.py", line 262, in <listcomp>
    return [func(*args, **kwargs)
  File "pomegranate/gmm.pyx", line 332, in pomegranate.gmm.GeneralMixtureModel.summarize
  File "pomegranate/utils.pyx", line 445, in pomegranate.utils._check_input
KeyError: 0
```
To Reproduce
The following seems to be a relatively simple reproducer:
```python
import numpy
import pomegranate

# Each component: a Poisson count column plus two categorical (string) columns.
d1 = pomegranate.IndependentComponentsDistribution([
    pomegranate.PoissonDistribution(4.0),
    pomegranate.DiscreteDistribution({'A': 0.5, 'B': 0.5}),
    pomegranate.DiscreteDistribution({'0': 0.2, '1': 0.2, '2': 0.2, '3': 0.4}),
])
d2 = pomegranate.IndependentComponentsDistribution([
    pomegranate.PoissonDistribution(1.0),
    pomegranate.DiscreteDistribution({'A': 0.1, 'B': 0.9}),
    pomegranate.DiscreteDistribution({'0': 0.1, '1': 0.6, '2': 0.1, '3': 0.2}),
])
gmm = pomegranate.GeneralMixtureModel([d1, d2], weights=numpy.array([0.4, 0.6]))

# Mixed-type rows, so the array has to use dtype=object.
tmp = numpy.array(
    [
        [0, 'A', '0'],
        [5, 'A', '1'],
        [4, 'A', '3'],
        [5, 'B', '0'],
    ],
    dtype=object,
)
gmm.fit(tmp)  # raises the KeyError: 0 traceback shown above
```
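For reference, this is the density that `gmm` is meant to represent, written out by hand without pomegranate. The sketch assumes that an independent-components distribution multiplies the per-column probabilities and that the mixture is a weighted sum of its components; `poisson_logpmf`, `component_logpdf`, and `mixture_logpdf` are illustrative names introduced here, not pomegranate functions.

```python
import math

import numpy

# Parameters copied from the reproducer (assumption: same semantics as pomegranate):
# (Poisson rate, probabilities for column 1, probabilities for column 2).
components = [
    (4.0, {'A': 0.5, 'B': 0.5}, {'0': 0.2, '1': 0.2, '2': 0.2, '3': 0.4}),
    (1.0, {'A': 0.1, 'B': 0.9}, {'0': 0.1, '1': 0.6, '2': 0.1, '3': 0.2}),
]
weights = [0.4, 0.6]


def poisson_logpmf(k, lam):
    # log(lam**k * exp(-lam) / k!)
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def component_logpdf(row, lam, p1, p2):
    # Independent columns: the joint log-probability is the sum of per-column terms.
    count, label1, label2 = row
    return poisson_logpmf(int(count), lam) + math.log(p1[label1]) + math.log(p2[label2])


def mixture_logpdf(row):
    # log(sum_k w_k * p_k(row)), evaluated with the log-sum-exp trick for stability.
    terms = numpy.array([math.log(w) + component_logpdf(row, *c)
                         for w, c in zip(weights, components)])
    top = terms.max()
    return float(top + numpy.log(numpy.exp(terms - top).sum()))


rows = [[0, 'A', '0'], [5, 'A', '1'], [4, 'A', '3'], [5, 'B', '0']]
for row in rows:
    print(row, mixture_logpdf(row))
```

Each row mixes an integer with string labels, which is why the reproducer needs `dtype=object`; the string columns are where `_check_input` first fails (`could not convert string to float: 'A'`).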
Any guidance on why this is not the right way to use `GeneralMixtureModel`, or on how to work around it, would be greatly appreciated.
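If only the fitted parameters are needed while `fit` is broken, one possible workaround is to run expectation-maximization directly. This is a minimal sketch, not pomegranate's implementation: `encode` and `fit_poisson_categorical_mixture` are hypothetical helpers introduced here, the random Dirichlet initialization is an assumption, and the pseudocount `eps` plus the floors on rates and component sizes are added only to avoid `log(0)`. It has not been compared against pomegranate's own results.

```python
import math

import numpy


def encode(column):
    # Hypothetical helper: map each distinct label to an integer code 0..n-1.
    levels = sorted(set(column))
    lookup = {label: i for i, label in enumerate(levels)}
    return numpy.array([lookup[label] for label in column]), levels


def fit_poisson_categorical_mixture(X, n_components=2, n_iter=100, seed=0, eps=1e-10):
    # Column 0 is a Poisson count; every other column is categorical.
    counts = X[:, 0].astype(float)
    cat_columns = [encode(X[:, j]) for j in range(1, X.shape[1])]
    n = len(counts)
    log_fact = numpy.array([math.lgamma(c + 1.0) for c in counts])

    # Start from random soft assignments (responsibilities).
    rng = numpy.random.default_rng(seed)
    resp = rng.dirichlet(numpy.ones(n_components), size=n)
    history = []
    for _ in range(n_iter):
        # M-step: weighted maximum-likelihood updates for every component.
        weights = resp.sum(axis=0) / n
        nk = numpy.maximum(resp.sum(axis=0), 1e-12)      # guard against empty components
        lam = numpy.maximum(resp.T @ counts / nk, 1e-9)  # Poisson rate = weighted mean
        probs = []
        for codes, levels in cat_columns:
            # Weighted category frequencies per component, plus a tiny pseudocount.
            table = numpy.stack([
                numpy.bincount(codes, weights=resp[:, k], minlength=len(levels))
                for k in range(n_components)
            ]) + eps
            probs.append(table / table.sum(axis=1, keepdims=True))

        # E-step: log w_k + log Poisson(x0 | lam_k) + categorical log-probabilities.
        with numpy.errstate(divide="ignore"):
            log_p = numpy.log(weights) + counts[:, None] * numpy.log(lam) - lam - log_fact[:, None]
        for (codes, _), p in zip(cat_columns, probs):
            log_p = log_p + numpy.log(p[:, codes]).T
        row_max = log_p.max(axis=1, keepdims=True)
        log_norm = row_max + numpy.log(numpy.exp(log_p - row_max).sum(axis=1, keepdims=True))
        resp = numpy.exp(log_p - log_norm)
        history.append(float(log_norm.sum()))  # total log-likelihood at these parameters

    return {"weights": weights, "rates": lam, "probs": probs,
            "levels": [levels for _, levels in cat_columns], "log_likelihood": history}


tmp = numpy.array(
    [[0, 'A', '0'], [5, 'A', '1'], [4, 'A', '3'], [5, 'B', '0']],
    dtype=object,
)
result = fit_poisson_categorical_mixture(tmp)
print(result["weights"], result["rates"])
```

`encode` only sees labels that occur in the data, so the unused category `'2'` from the third column gets no parameter here. Because EM never decreases the likelihood, `result["log_likelihood"]` should be non-decreasing, which is a cheap sanity check.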
This should be resolved by your PR and incorporated into 0.14.7, but LITERALLY every time I go to release a new version, something new is failing in my wheel-building script.
@jmschrei Was this ever merged?
Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.