pyro
pyro copied to clipboard
Inferring discrete variables with funsor HMM
I am trying to extract the discrete hidden states from the funsor HMM example, but I get an odd error when using this pattern to infer discrete sites. My model exactly matches the example posted here, and I added the following chunk of code, which is supposed to extract the hidden states, to the end of `main()`:
guide_trace = handlers.trace(guide).get_trace(sequences, lengths)
trained_model = handlers.replay(model, trace=guide_trace)
inferred_model = infer.infer_discrete(
trained_model, temperature=0,
first_available_dim=first_available_dim)
trace = handlers.trace(inferred_model).get_trace(sequences, lengths)
The error arises when `.get_trace()` is invoked, specifically in its downstream call to `forward_backward()` in `adjoint.py`. This error occurs no matter which model structure I pick within the funsor HMM example (i.e. the issue affects all of the funsor HMMs, not just `model_7()` with its vectorized time dimension). Also, this routine for inferring discrete sites works fine when I analyze the Bach chorales using the standard Pyro HMM example.
I still get the same error when I use the following `condition`-based workaround, which avoids any compatibility issues between `replay()` and funsor's `infer_discrete()`:
# get trace of discrete params
guide_trace = handlers.trace(guide).get_trace(sequences, lengths)
guide_data = {
name: site["value"]
for name, site in guide_trace.nodes.items()
if site["type"] == "sample"
}
# MAP estimate discretes, conditioned on posterior-sampled continuous latents.
actual_trace = handlers.trace(
infer.infer_discrete(
handlers.condition(infer.config_enumerate(model), guide_data),
temperature=0,
)
).get_trace(sequences, lengths)
I’m running on a CPU on a Linux cluster with these package versions:
funsor==0.4.3
pyro-api==0.1.2
pyro-ppl==1.8.2
torch==1.12.1
The link to my initial discussion board post can be found here.
Finally, here is the full error traceback:
AssertionError Traceback (most recent call last)
<ipython-input-7-3ccf5e1cbc19> in <module>
8
9 with pyro_backend(PYRO_BACKEND):
---> 10 trace, sequences, lengths = main(args)
<ipython-input-4-8c5832ced74d> in main(args)
169 temperature=0, first_available_dim=first_available_dim
170 )
--> 171 ).get_trace(sequences, lengths)
172
173 # trained_model = handlers.replay(model, trace=guide_trace)
/juno/work/venv3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/juno/work/venv3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/juno/work/venv3/lib/python3.7/site-packages/pyro/contrib/funsor/infer/discrete.py in _sample_posterior(model, first_available_dim, temperature, *args, **kwargs)
44
45 with approx:
---> 46 approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
47
48 # construct a result trace to replay against the model
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(sum_op, bin_op, expr)
140
141 def adjoint(sum_op, bin_op, expr):
--> 142 forward, backward = forward_backward(sum_op, bin_op, expr)
143 return backward
144
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in forward_backward(sum_op, bin_op, expr, batch_vars)
135 # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
136 forward = stack_reinterpret(expr)
--> 137 backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
138 return forward, backward
139
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(self, sum_op, bin_op, root, targets, batch_vars)
113 self._eager_to_lazy[output] = lazy_output
114
--> 115 in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
116 for v, adjv in in_adjs:
117 # Marginalize out message variables that don't appear in recipients.
/juno/work/venv3/lib/python3.7/site-packages/funsor/registry.py in __call__(self, key, *args)
104
105 def __call__(self, key, *args):
--> 106 return self[key](*args)
107
108 def dispatch(self, key, *args):
/juno/work/venv3/lib/python3.7/site-packages/funsor/registry.py in __call__(self, *args)
61
62 def __call__(self, *args):
---> 63 return self.partial_call(*args)(*args)
64
65
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint_contract_generic(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms)
215 adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms
216 ):
--> 217 assert len(terms) == 1 or len(terms) == 2
218 return adjoint_ops(
219 Contraction,
AssertionError:
@ordabayevy, any idea what might be happening?