pyhsmm
SVI plotting broken
Tried to use hmm.plot() in the SVI example and got the following plot:
Despite the different emissions, everything is drawn in one color, because only the latest state sequence seems to be recorded at the end of training. After digging into it a bit, it looks like the HMMSVI class uses states_list as temporary storage for each minibatch instead of for storing the global HMMStatesEigen object. I tried changing _get_mb_states_list in the _HMMSVI definition in models.py so that it keeps the last states object in states_list instead of popping it off, as follows:
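```python
def _get_mb_states_list(self,minibatch,**kwargs):
    minibatch = minibatch if isinstance(minibatch,list) else [minibatch]
    mb_states_list = []
    for mb in minibatch:
        self.add_data(mb,generate=False,**kwargs)
        # was: mb_states_list.append(self.states_list.pop())
        # now the new states object is left in self.states_list instead of being removed
        mb_states_list.append(self.states_list[-1])
    return mb_states_list
```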
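To make the bookkeeping difference concrete, here is a toy sketch that mimics only the states_list handling, not the real pyhsmm classes. ToyStates, ToyModel, keep_states and run are hypothetical names made up for illustration, and add_data here just appends a stand-in object (the real add_data does much more). With pop() the list is emptied again after every minibatch; with [-1] every minibatch's states object piles up in states_list.

```python
class ToyStates:
    """Hypothetical stand-in for a per-sequence states object (not pyhsmm's HMMStatesEigen)."""
    def __init__(self, data):
        self.data = data


class ToyModel:
    """Hypothetical model that mimics only the states_list bookkeeping."""
    def __init__(self, keep_states):
        self.states_list = []
        self.keep_states = keep_states

    def add_data(self, data, **kwargs):
        # Assumption: like the real method, create a states object and append it
        self.states_list.append(ToyStates(data))

    def _get_mb_states_list(self, minibatch, **kwargs):
        minibatch = minibatch if isinstance(minibatch, list) else [minibatch]
        mb_states_list = []
        for mb in minibatch:
            self.add_data(mb, **kwargs)
            if self.keep_states:
                # modified version: the states object stays in states_list
                mb_states_list.append(self.states_list[-1])
            else:
                # original version: the states object is removed again
                mb_states_list.append(self.states_list.pop())
        return mb_states_list


def run(keep_states, n_steps=5):
    """Run n_steps fake SVI steps with two sequences per minibatch."""
    model = ToyModel(keep_states)
    for step in range(n_steps):
        model._get_mb_states_list([f"batch{step}a", f"batch{step}b"])
    return len(model.states_list)


print(run(keep_states=False))
print(run(keep_states=True))
```

The first call leaves states_list empty, while the second ends up holding all ten minibatch states objects, which is why the modified version plots every sequence seen during training.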
This resulted in the following plot, where the left side looks as expected and the right side is a mess, likely because all of the overlapping state sequences are being plotted on top of each other:
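A rough sketch of why the right side gets crowded, assuming (hypothetically) that each minibatch is a random fixed-length window of one long sequence: once every window's states object is kept, most time steps are covered by several stored sequences. window_coverage and its parameters are illustrative names, not part of pyhsmm or the SVI example.

```python
import random


def window_coverage(seq_len, window, n_batches, seed=0):
    """Count how many accumulated minibatch windows cover each time index."""
    rng = random.Random(seed)
    counts = [0] * seq_len
    for _ in range(n_batches):
        # pick a random window that fits inside the sequence
        start = rng.randrange(0, seq_len - window + 1)
        for t in range(start, start + window):
            counts[t] += 1
    return counts


counts = window_coverage(seq_len=1000, window=100, n_batches=50)
print(sum(counts))  # total time steps stored across all kept minibatches
print(max(counts))  # worst-case number of overlapping sequences at one time step
```

With 50 windows of 100 steps, 5000 time steps are stored for a 1000-step sequence, so on average each time index would be drawn about five times.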