
[BUG] Incremental training of HMM with GMM crashes

shawrkbait opened this issue 4 years ago • 2 comments

Describe the bug

I am attempting to train a Hidden Markov Model with GMM emissions incrementally, on a sliding window of the last 100 data points. To do this, I reset the start probabilities before each iteration. After a few iterations, the parameters end up as NaN. I'm probably doing something wrong, so a little guidance would help.

Thanks!

NOTE: I've also tried max_iterations=1, and the loop gets through ~135 windows before crashing.

To Reproduce

from pomegranate import *
import json
import numpy as np

np.random.seed(100)

total_points = 1000
nstates = 4
nmix = 2
window = 100

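def reset_starts(old):
  # Round-trip the model through JSON and reset every start -> state edge
  # back to a uniform 0.25 before the next fit.
  js = json.loads(old.to_json())
  for e in js["edges"]:
    # Each edge is [from_index, to_index, probability, pseudocount, group]
    if e[0] == js["start_index"]:
      e[2] = 0.25
      e[3] = 0.25
  #print(js)
  return HiddenMarkovModel.from_json(json.dumps(js))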

# Generate an AR(1)-style sequence: a_t = 0.7 * a_(t-1) + 0.2 * U(0, 1) + 0.0003
xs = np.linspace(0, 100, total_points)
sequence = []
prev = 0
for x in range(total_points):
  a = 0.7 * prev + 0.2 * np.random.random() + .0003
  sequence.append(a)
  prev = a
##

# Create model
model = HiddenMarkovModel()
states = []
for i in range(nstates):
  states.append(State(GeneralMixtureModel([NormalDistribution(-1,1), NormalDistribution(1,1)])))

model.add_states(states)
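for x in range(nstates):
  # Uniform start probabilities.
  model.add_transition(model.start, states[x], 0.25)
  # Random transitions into states[x]; bake() normalizes each state's outgoing
  # probabilities (the "normalized to 1.0" lines in the output).
  sample = np.random.random_sample(nstates)
  for y in range(nstates):
    model.add_transition(states[y], states[x], sample[y])
model.bake(verbose=True)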
##

# Train incrementally from the previous 100 points
for i in range(window,total_points):
  seq = sequence[i-window:i]
  model.fit( [seq], max_iterations=100)
  print("Iter %d\n%s" %(i-window,model.predict(seq)))
  model = reset_starts(model)
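To see what range of values the states are being fitted to, the generating loop above can be summarized by its stationary mean and standard deviation. The sketch below is an illustration added for this purpose and is not from the original report. The function names are hypothetical. It uses the standard AR(1) formulas, mean = E[noise] / (1 - phi) and variance = Var[noise] / (1 - phi^2), and compares them against a simulation of the same recursion.

```python
import numpy as np

def ar1_stationary_moments(phi=0.7, scale=0.2, offset=0.0003):
    """Stationary mean/std of a_t = phi * a_{t-1} + scale * U_t + offset,
    with U_t ~ Uniform(0, 1) i.i.d. and |phi| < 1 (hypothetical helper)."""
    noise_mean = scale * 0.5 + offset      # E[scale * U + offset]
    noise_var = scale ** 2 / 12.0          # Var[scale * U]
    mean = noise_mean / (1.0 - phi)
    std = np.sqrt(noise_var / (1.0 - phi ** 2))
    return mean, std

def simulate_ar1(n, phi=0.7, scale=0.2, offset=0.0003, seed=100):
    """Same recursion as the generation loop in the reproduction, started at 0."""
    rng = np.random.RandomState(seed)
    out = np.empty(n)
    prev = 0.0
    for t in range(n):
        prev = phi * prev + scale * rng.random_sample() + offset
        out[t] = prev
    return out

mean, std = ar1_stationary_moments()
print(f"stationary mean ~ {mean:.4f}, std ~ {std:.4f}")
x = simulate_ar1(50_000)[100:]             # drop the start-up transient
print(f"simulated  mean ~ {x.mean():.4f}, std ~ {x.std():.4f}")
```

With these formulas the data sit around 0.334, with a standard deviation of about 0.081. That is far from the initial component means of -1 and 1. The normal-distribution means printed later in the thread (roughly 0.27 to 0.40) fall inside this narrow band. Four states are therefore competing for a small range of values, which plausibly leaves some of them with almost no data to explain.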

Output: None : 011e197d-dabc-4fb3-a528-52b477b19aad summed to 1.81130714, normalized to 1.0 None : a77297b3-d9c5-4ac1-bac0-3b2323ed1378 summed to 1.68065266, normalized to 1.0 None : cdf60bf0-db0a-473e-ae8f-324d57431a0c summed to 2.52873565, normalized to 1.0 None : 35ab1de7-98a4-4685-a99a-381a2ad7aef6 summed to 0.26911914, normalized to 1.0 Iter 0 [1, 1, 1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3] Iter 1 [1, 1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2] Iter 2 [1, 0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2] Iter 3 [0, 1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 3, 0, 1, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2] Iter 4 [1, 1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2, 2] Iter 5 [1, 0, 3, 0, 3, 2, 3, 0, 1, 0, 3, 2, 3, 2, 2, 3, 2, 2, 2, 3, 0, 1, 1, 0, 1, 0, 3, 0, 1, 0, 3, 2, 3, 2, 2, 2, 
2, 2, 3, 0, 3, 2, 3, 0, 3, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 0, 3, 0, 3, 2, 2, 2, 2, 3, 0, 1, 0, 3, 2, 2, 2, 2, 3, 0, 3, 0, 3, 2, 2, 3, 2, 2, 2, 2, 2, 3, 0, 1, 0, 3, 0, 3, 0, 3, 2, 2, 2, 2, 2] Iter 6 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] Traceback (most recent call last): File "broken.py", line 46, in model = reset_starts(model) File "broken.py", line 18, in reset_starts return HiddenMarkovModel.from_json(json.dumps(js)) File "pomegranate/hmm.pyx", line 3283, in pomegranate.hmm.HiddenMarkovModel.from_json File "pomegranate/base.pyx", line 480, in pomegranate.base.State.from_json File "pomegranate/gmm.pyx", line 416, in pomegranate.gmm.GeneralMixtureModel.from_json File "pomegranate/distributions/distributions.pyx", line 337, in pomegranate.distributions.distributions.Distribution.from_json File "", line 1, in NameError: name 'nan' is not defined
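In both tracebacks in this thread, the crash happens inside `from_json`, after fitting has already produced a bad parameter. Here the bad value is NaN, and in the follow-up comment it is a standard deviation of 0.0. One way to catch this one step earlier is to inspect the dict from `json.loads(model.to_json())` before rebuilding the model. The helper below is a hypothetical debugging aid and is not part of pomegranate. It relies only on the structure visible in the model dump later in this thread (a `"name"` of `"NormalDistribution"` with `"parameters": [mean, std]`). It also assumes NaN survives `json.loads` as a Python float, which is how Python's `json` module reads the `NaN` token by default.

```python
import json
import math

def find_degenerate_values(node, path="model"):
    """Recursively walk a model dict (e.g. json.loads(model.to_json())).

    Returns (path, value) pairs for:
      * any float that is NaN or infinite, and
      * any NormalDistribution whose second parameter (the std) is <= 0.
    """
    problems = []
    if isinstance(node, dict):
        if node.get("name") == "NormalDistribution" and "parameters" in node:
            std = node["parameters"][1]
            if isinstance(std, (int, float)) and std <= 0:
                problems.append((path + ".parameters[1]", std))
        for key, value in node.items():
            problems.extend(find_degenerate_values(value, f"{path}.{key}"))
    elif isinstance(node, list):
        for i, value in enumerate(node):
            problems.extend(find_degenerate_values(value, f"{path}[{i}]"))
    elif isinstance(node, float) and not math.isfinite(node):
        problems.append((path, node))
    return problems


# Example: one state copied from the model dump later in this thread
# (std == 0.0), plus a hypothetical state whose parameters became NaN.
example = {
    "states": [
        {"class": "State",
         "distribution": {"class": "Distribution", "name": "NormalDistribution",
                          "parameters": [0.3997194162239017, 0.0], "frozen": False},
         "name": "de26c027-bcb7-4216-9e31-d4b4b223178a", "weight": 1.0},
        {"class": "State",
         "distribution": {"class": "Distribution", "name": "NormalDistribution",
                          "parameters": [float("nan"), float("nan")], "frozen": False},
         "name": "hypothetical-nan-state", "weight": 1.0},
    ]
}

# Python's json module writes NaN as the bare token NaN and reads it back as a float.
roundtrip = json.loads(json.dumps(example))
for p, v in find_degenerate_values(roundtrip):
    print(p, v)
```

In the reproduction loop, this could be called on `json.loads(model.to_json())` right after `model.fit(...)`, stopping before `reset_starts` whenever it returns anything.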

shawrkbait, Jun 13 '20 21:06

That's weird. I'll try to look at it soon. Do you see the same thing when you use a normal distribution instead of a mixture?

jmschrei, Jun 16 '20 04:06

It also fails with plain normal distributions, although when training with max_iterations=1 I now get a ZeroDivisionError instead. Printing out the model shows that this time one of the standard deviations converged to 0 (this didn't happen with the GMM).

Iter 253 [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] {'class': 'HiddenMarkovModel', 'name': 'None', 'start': {'class': 'State', 'distribution': None, 'name': 'None-start', 'weight': 1.0}, 'end': {'class': 'State', 'distribution': None, 'name': 'None-end', 'weight': 1.0}, 'states': [{'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.2662727171605095, 0.061107540771398065], 'frozen': False}, 'name': '1fe12462-6437-4888-bfa2-a7f2bdd1ca4b', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.36716326683682204, 0.0711085093316527], 'frozen': False}, 'name': '82dfd2be-f57e-44d1-b61a-3fdb03d45f3c', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.3997194162239017, 0.0], 'frozen': False}, 'name': 'de26c027-bcb7-4216-9e31-d4b4b223178a', 'weight': 1.0}, {'class': 'State', 'distribution': {'class': 'Distribution', 'name': 'NormalDistribution', 'parameters': [0.3114115376978282, 0.00019355405244530423], 'frozen': False}, 'name': 'f1c7e842-4d03-4ebd-863e-cc972ddf298d', 'weight': 1.0}, {'class': 'State', 'distribution': None, 'name': 'None-start', 'weight': 1.0}, {'class': 'State', 'distribution': None, 'name': 'None-end', 'weight': 1.0}], 'end_index': 5, 'start_index': 4, 'silent_index': 4, 'edges': [[0, 2, 0.03030301088373653, 0.22219852722772315, None], [0, 1, 3.214115095853174e-101, 0.027617184503319403, None], [0, 3, 0.0, 0.016469626666622994, None], [0, 0, 0.9696969891162635, 0.002833802567491883, None], [1, 2, 
3.473786840805535e-129, 0.38223726411047276, None], [1, 1, 0.9717355600928262, 0.8193198019179616, None], [1, 3, 0.02826443990717378, 0.31213882462326004, None], [1, 0, 2.0737210449513234e-177, 0.1669567665617232, None], [2, 2, 0.0, 0.027732277525769122, None], [2, 1, 1.0, 0.3051261201740513, None], [2, 3, 0.0, 0.9543877717509077, None], [2, 0, 0.0, 0.5240609701557668, None], [3, 2, 0.0, 0.953250912037412, None], [3, 1, 0.44030526978656925, 0.5784698159078305, None], [3, 3, 0.0, 0.18407287867767852, None], [3, 0, 0.5596947302134309, 0.8129420468290551, None], [4, 2, 0.25, 0.25, None], [4, 1, 0.25, 0.25, None], [4, 3, 0.25, 0.25, None], [4, 0, 0.25, 0.25, None]], 'distribution ties': []} Traceback (most recent call last): File "broken.py", line 50, in model = reset_starts(model) File "broken.py", line 18, in reset_starts return HiddenMarkovModel.from_json(json.dumps(js)) File "pomegranate/hmm.pyx", line 3283, in pomegranate.hmm.HiddenMarkovModel.from_json File "pomegranate/base.pyx", line 480, in pomegranate.base.State.from_json File "pomegranate/distributions/distributions.pyx", line 337, in pomegranate.distributions.distributions.Distribution.from_json File "", line 1, in File "pomegranate/distributions/NormalDistribution.pyx", line 37, in pomegranate.distributions.NormalDistribution.NormalDistribution.init ZeroDivisionError: float division by zero
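The dump shows one state with a standard deviation of exactly 0.0 and another at about 0.0002. In EM for Gaussian emissions, two guards are common:

- floor the standard deviation, so a state that covers (nearly) identical values cannot collapse to zero;
- skip the update for a state that receives (almost) no posterior weight, because its weighted mean is then 0/0, i.e. NaN.

The sketch below shows both ideas for a single state's M-step in NumPy. `weighted_normal_mstep`, `min_std` and `fallback` are hypothetical names, and this is not pomegranate's internal update.

```python
import numpy as np

def weighted_normal_mstep(x, w, min_std=1e-3, fallback=(0.0, 1.0)):
    """Hypothetical M-step for one Gaussian emission.

    x: observations in the window; w: responsibilities, i.e. the posterior
    probability of this state at each time step (e.g. from forward-backward).
    Returns (mean, std).
    """
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    total = w.sum()
    if not np.isfinite(total) or total <= 1e-12:
        # Without this check, mean = sum(w * x) / sum(w) would be 0/0 = NaN:
        # the state explained none of the window, so keep fallback parameters.
        return fallback
    mean = np.dot(w, x) / total
    var = np.dot(w, (x - mean) ** 2) / total
    # A state that covers a single value (or identical values) has var == 0;
    # floor the std so densities and divisions by std stay finite.
    std = max(float(np.sqrt(var)), min_std)
    return float(mean), std

x = np.array([0.31, 0.35, 0.40, 0.29])
print(weighted_normal_mstep(x, [0.0, 0.0, 1.0, 0.0]))  # one point: std floored
print(weighted_normal_mstep(x, [0.0, 0.0, 0.0, 0.0]))  # no weight: fallback
```

The floor value and the fallback are judgment calls. A state with a standard deviation of 1e-3 is still extremely narrow, so this only keeps the numbers finite. It does not stop states from over-specializing on a few points.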

shawrkbait, Jun 16 '20 13:06

Thank you for opening an issue. pomegranate has recently been rewritten from the ground up to use PyTorch instead of Cython (v1.0.0), and so all issues are being closed as they are likely out of date. Please re-open or start a new issue if a related issue is still present in the new codebase.

jmschrei, Apr 16 '23 06:04