# Importance funsor
Importance sampling is represented by an `Importance` funsor.
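As background, importance sampling estimates an expectation under a target density by drawing from a proposal and reweighting by the density ratio. A minimal plain-Python sketch, independent of funsor (all names here are illustrative, not library API):

```python
import math
import random

random.seed(0)

def log_p(x):
    # target: standard normal log-density
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)

def log_q(x):
    # proposal: normal with mean 1, log-density
    return -0.5 * (x - 1.0) ** 2 - 0.5 * math.log(2 * math.pi)

def estimate(f, n=100_000):
    """Importance-sampling estimate of E_p[f] using samples from q."""
    total = 0.0
    for _ in range(n):
        x = random.gauss(1.0, 1.0)          # sample from the proposal q
        w = math.exp(log_p(x) - log_q(x))   # importance weight p(x)/q(x)
        total += w * f(x)
    return total / n

mean_est = estimate(lambda x: x)      # E_p[x] = 0
norm_est = estimate(lambda x: 1.0)    # E_q[w] = 1
```

The second estimate illustrates that the importance weight has expectation 1 under the proposal, which is the identity the dice factor below exploits.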
- Signature: `Importance(model, guide, sampled_vars)`.
- When `guide` is a `Delta`, it eagerly evaluates to `guide + model - guide`.
- `Importance.reduce` is delegated to `Importance.model.reduce`.
- (not implemented) consider implementing a `MonteCarlo` interpretation when `guide` is not a `Delta`.
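The eager rule can be sketched with hypothetical stand-in classes (not the funsor API) that track only a scalar log-density:

```python
# Toy stand-ins for illustration only; real funsor terms carry named inputs.
class Term:
    def __init__(self, log_density):
        self.log_density = log_density

    def __add__(self, other):
        return Term(self.log_density + other.log_density)

    def __sub__(self, other):
        return Term(self.log_density - other.log_density)

class Delta(Term):
    """Stand-in for a point-mass term."""

def importance(model, guide, sampled_vars):
    if isinstance(guide, Delta):
        # eager rule: guide + model - guide
        return guide + (model - guide)
    raise NotImplementedError("a non-Delta guide would need a MonteCarlo interpretation")

m = Term(-1.2)
g = Delta(-1.2)
result = importance(m, g, {"x"})  # log-density equals the model's, -1.2
```

Numerically the result equals the model's log-density; the decomposition matters because in the real library the guide term is detached, as the dice factor below shows.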
## Dice factor as an importance weight
```python
model = Delta(name, point, log_prob)
guide = Delta(name, point, ops.detach(log_prob))

Importance(model, guide, name)
== guide + model - guide
== guide + log_prob - ops.detach(log_prob)
== guide + dice_factor
```
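Numerically, the dice factor is `exp(log_prob - detach(log_prob))`: its value is exactly 1, but because the detached copy is treated as a constant, gradients still flow through `log_prob`. A plain-Python sketch, with a finite difference standing in for autodiff:

```python
import math

def dice_factor(log_prob, detached):
    # `detached` plays the role of ops.detach(log_prob): a constant copy
    # that no gradient flows through.
    return math.exp(log_prob - detached)

lp = -0.7
value = dice_factor(lp, lp)  # exp(0.0) == 1.0 exactly

# Finite-difference derivative w.r.t. log_prob, holding the detached copy fixed:
eps = 1e-6
grad = (dice_factor(lp + eps, lp) - dice_factor(lp - eps, lp)) / (2 * eps)
# grad is ~1.0: the factor is multiplicatively neutral but carries a gradient.
```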
## Lazy interpretation
```python
lazy_importance = DispatchedInterpretation("lazy_importance")

@lazy_importance.register(Importance, Funsor, Delta, frozenset)
def _lazy_importance(model, guide, sampled_vars):
    return reflect.interpret(Importance, model, guide, sampled_vars)
```
It is used for lazy importance sampling:
```python
with lazy_importance:
    sampled_dist = dist.sample(msg["name"], sample_inputs)
```
and for the adjoint algorithm:
```python
with lazy_importance:
    marginals = adjoint(funsor.ops.logaddexp, funsor.ops.add, logzq)
```
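The dispatch-and-reflect pattern behind `DispatchedInterpretation` can be sketched in plain Python (a toy registry, not funsor's implementation): a handler table keyed by operation and argument types, plus a context manager that switches the active interpretation, so that calls with no matching handler "reflect" into an unevaluated node:

```python
from contextlib import contextmanager

_active = {"mode": "eager"}
_registry = {}

def register(mode, op, *types):
    """Register a handler for `op` under interpretation `mode`."""
    def decorator(fn):
        _registry[(mode, op, types)] = fn
        return fn
    return decorator

def interpret(op, *args):
    key = (_active["mode"], op, tuple(type(a) for a in args))
    handler = _registry.get(key)
    if handler is None:
        # reflect: no handler matched, so build an unevaluated node
        return ("lazy", op, args)
    return handler(*args)

@contextmanager
def interpretation(mode):
    old, _active["mode"] = _active["mode"], mode
    try:
        yield
    finally:
        _active["mode"] = old

@register("eager", "add", int, int)
def _eager_add(x, y):
    return x + y

eager_result = interpret("add", 1, 2)      # eagerly evaluated to 3
with interpretation("lazy_importance"):
    lazy_result = interpret("add", 1, 2)   # reflected: ("lazy", "add", (1, 2))
```

This mirrors how the `with lazy_importance:` blocks above defer `Importance` construction so the expression graph survives for `reduce` and `adjoint`.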
## Separability
```python
model_a = Delta("a", point_a["i"], log_prob_a)
guide_a = Delta("a", point_a["i"], ops.detach(log_prob_a))
q_a = Importance(model_a, guide_a, {"a"})

model_b = Delta("b", point_b["i"], log_prob_b)
guide_b = Delta("b", point_b["i"], ops.detach(log_prob_b))
q_b = Importance(model_b, guide_b, {"b"})

with lazy_importance:
    (q_a.exp() * q_b.exp() * cost_b).reduce(add, {"a", "b", "i"})
    == [q_a.exp().reduce(add, "a") * (q_b.exp() * cost_b).reduce(add, "b")].reduce(add, "i")
    == [1("i") * (q_b.exp().reduce(add, "b") * cost_b(b=point_b))].reduce(add, "i")
    == [1("i") * 1("i") * cost_b(b=point_b)].reduce(add, "i")
    == cost_b(b=point_b).reduce(add, "i")
```
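The identity can be checked numerically with hypothetical toy factors (plain arrays, not funsor terms): once each exp'd importance term sums to 1 over its own sampled variable, the `a` factor drops out of the joint reduction entirely.

```python
import itertools

A, B, I = 3, 4, 2
# q_a[i][a] and q_b[i][b] play the role of exp'd importance terms:
# each row sums to 1 over its sampled variable.
q_a = [[0.2, 0.3, 0.5],
       [0.1, 0.6, 0.3]]
q_b = [[0.25, 0.25, 0.25, 0.25],
       [0.4, 0.1, 0.2, 0.3]]
cost = [[1.0, 2.0, 3.0, 4.0],
        [5.0, 6.0, 7.0, 8.0]]  # cost_b[i][b]

# Joint reduction over {"a", "b", "i"}:
lhs = sum(q_a[i][a] * q_b[i][b] * cost[i][b]
          for a, b, i in itertools.product(range(A), range(B), range(I)))

# Separated form: reduce the a-factor and b-factor independently per i.
rhs = sum(sum(q_a[i][a] for a in range(A))                 # == 1 for each i
          * sum(q_b[i][b] * cost[i][b] for b in range(B))
          for i in range(I))
```

Since the `a` sums are 1, the whole expression collapses to the cost reduced over the remaining variables, matching the last line of the derivation above.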