pomegranate
pomegranate copied to clipboard
[BUG] Incremental training of HMM with GMM crashes
Describe the bug: I am attempting to train a Hidden Markov Model with GMM incrementally, based on the last 100 data points. To do this, I reset the start weights before each iteration. After a few iterations, the parameters end up as NaN. I'm probably doing something wrong, so a little guidance would help.
Thanks!
NOTE: I've also tried max_iterations=1 and it runs ~135 times before crashing.
To Reproduce
from pomegranate import *
import numpy as np
# Sizes that control the experiment.
total_points = 1000  # length of the synthetic sequence
nstates = 4  # number of hidden states added to the model
nmix = 2  # NOTE(review): not referenced anywhere else in the visible script
window = 100  # number of trailing points passed to each fit() call

# Seed NumPy's global RNG so every run draws the same random numbers.
np.random.seed(100)
def reset_starts(old, start_probability=None):
    """Return a new model whose edges out of the start state are reset.

    The model is round-tripped through its JSON form: ``old.to_json()`` is
    parsed, every edge whose first entry equals ``js["start_index"]`` has its
    entries at positions 2 and 3 overwritten, and a fresh model is rebuilt
    with ``HiddenMarkovModel.from_json``.

    Args:
        old: Model to copy; only ``old.to_json()`` is called on it.
        start_probability: Value written to both positions of every start
            edge. When ``None`` (the default) a uniform value of
            ``1 / number_of_start_edges`` is used, which is 0.25 for a
            model with four start edges.

    Returns:
        The model produced by ``HiddenMarkovModel.from_json``.
    """
    # Imported here so the function does not rely on ``json`` leaking into
    # scope through ``from pomegranate import *``.
    import json

    js = json.loads(old.to_json())
    start_index = js["start_index"]
    start_edges = [edge for edge in js["edges"] if edge[0] == start_index]

    # A fixed 0.25 is only a uniform distribution for exactly four start
    # edges; derive it from the actual edge count instead.
    if start_probability is None and start_edges:
        start_probability = 1.0 / len(start_edges)

    for edge in start_edges:
        # presumably edge[2] is the probability and edge[3] the pseudocount;
        # confirm against pomegranate's serialization format
        edge[2] = start_probability
        edge[3] = start_probability

    return HiddenMarkovModel.from_json(json.dumps(js))
# Generate sequence
# Evenly spaced grid over [0, 100]; not used again in the visible script.
xs = np.linspace(0, 100, total_points)

# Each value is 0.7 * (previous value) + 0.2 * (one np.random.random() draw)
# + 0.0003, with the previous value starting at 0.
sequence = []
prev = 0
for _ in range(total_points):
    prev = 0.7 * prev + 0.2 * np.random.random() + .0003
    sequence.append(prev)
##
# Create model
model = HiddenMarkovModel()

# One State per hidden state, each wrapping a two-component mixture built
# from NormalDistribution(-1, 1) and NormalDistribution(1, 1).
states = [
    State(GeneralMixtureModel([NormalDistribution(-1, 1), NormalDistribution(1, 1)]))
    for _ in range(nstates)
]
model.add_states(states)

for target in states:
    # Every state gets the same weight on its edge from the start state.
    model.add_transition(model.start, target, 0.25)
    # One fresh random draw per target supplies the weights of all edges
    # leading into it (including its self-loop).
    weights = np.random.random_sample(nstates)
    for source, weight in zip(states, weights):
        model.add_transition(source, target, weight)

# The random outgoing weights are not normalized here; the output reported
# with this script shows bake() printing "summed to ..., normalized to 1.0".
model.bake(verbose=True)
##
# Train incrementally on a sliding window of the previous `window` points
for stop in range(window, total_points):
    first = stop - window
    recent = sequence[first:stop]

    # Refit on the current window only, then label that same window.
    model.fit([recent], max_iterations=100)
    labels = model.predict(recent)
    print("Iter %d\n%s" % (first, labels))

    # Replace the model with a copy whose start edges have been reset before
    # the next window is fitted.
    model = reset_starts(model)
Output:
None : 011e197d-dabc-4fb3-a528-52b477b19aad summed to 1.81130714, normalized to 1.0
None : a77297b3-d9c5-4ac1-bac0-3b2323ed1378 summed to 1.68065266, normalized to 1.0
None : cdf60bf0-db0a-473e-ae8f-324d57431a0c summed to 2.52873565, normalized to 1.0
None : 35ab1de7-98a4-4685-a99a-381a2ad7aef6 summed to 0.26911914, normalized to 1.0
Iter 0
[1, 1, 1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3]
Iter 1
[1, 1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2]
Iter 2
[1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2]
Iter 3
[0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 3, 0, 1, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2]
Iter 4
[1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2, 2]
Iter 5
[1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2, 2, 2]
Iter 6
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
Traceback (most recent call last):
File "broken.py", line 46, in <module>
That's weird. I'll try to look at it soon. Do you see the same thing when you use a normal distribution instead of a mixture?
It also dies with normal distributions, although when training with max_iterations=1, I now see a ZeroDivisionError. Printing out the model, I see that this time one of the standard deviations converged to 0. (This didn't happen with the GMM.)
Iter 253
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
{'class': 'HiddenMarkovModel', 'name': 'None', 'start': {'class': 'State', 'distribution': None, 'name': 'None-start', 'weight': 1.0}, 'end': {'class': 'State', 'distribution': None, 'name': 'None-end', 'weight': 1.0}, 'states': [{'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.2662727171605095, 0.061107540771398065], 'frozen': False}, 'name': '1fe12462-6437-4888-bfa2-a7f2bdd1ca4b', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.36716326683682204, 0.0711085093316527], 'frozen': False}, 'name': '82dfd2be-f57e-44d1-b61a-3fdb03d45f3c', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.3997194162239017, 0.0], 'frozen': False}, 'name': 'de26c027-bcb7-4216-9e31-d4b4b223178a', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.3114115376978282, 0.00019355405244530423], 'frozen': False}, 'name': 'f1c7e842-4d03-4ebd-863e-cc972ddf298d', 'weight': 1.0}, {'class': 'State', 'distribution': None, 'name': 'None-start', 'weight': 1.0}, {'class': 'State', 'distribution': None, 'name': 'None-end', 'weight': 1.0}], 'end_index': 5, 'start_index': 4, 'silent_index': 4, 'edges': [[0, 2, 0.03030301088373653, 0.22219852722772315, None], [0, 1, 3.214115095853174e-101, 0.027617184503319403, None], [0, 3, 0.0, 0.016469626666622994, None], [0, 0, 0.9696969891162635, 0.002833802567491883, None], [1, 2, 3.473786840805535e-129, 0.38223726411047276, None], [1, 1, 0.9717355600928262, 0.8193198019179616, None], [1, 3, 0.02826443990717378, 0.31213882462326004, None], [1, 0, 2.0737210449513234e-177, 0.1669567665617232, None], [2, 2, 0.0, 0.027732277525769122, None], [2, 1, 1.0, 0.3051261201740513, None], [2, 3, 0.0, 0.9543877717509077, None], [2, 0, 0.0, 0.5240609701557668, None], [3, 2, 0.0, 0.953250912037412, None], [3, 
1, 0.44030526978656925, 0.5784698159078305, None], [3, 3, 0.0, 0.18407287867767852, None], [3, 0, 0.5596947302134309, 0.8129420468290551, None], [4, 2, 0.25, 0.25, None], [4, 1, 0.25, 0.25, None], [4, 3, 0.25, 0.25, None], [4, 0, 0.25, 0.25, None]], 'distribution ties': []}
Traceback (most recent call last):
File "broken.py", line 50, in <module>
Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.