funsor
Implement optimized GEMM-like ops for Tensor and Gaussian
- [x] #281 For `Tensor`, use a GEMM op, fixing the bugs in https://github.com/pyro-ppl/funsor/blob/d18c78e/funsor/torch.py#L495 (which only works for scalar event shape).
- [ ] For `Gaussian`, see `pyro.ops.gaussian.gaussian_tensordot()` (pyro-ppl/pyro#1980). This is lower priority than `Tensor` because the naive `(x+y).reduce()` version is only slightly more expensive.
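To illustrate what a GEMM-style op buys for `Tensor`, here is a minimal NumPy sketch. It is not funsor code, and the shapes and names are made up for illustration. It compares the naive approach, which materializes the full broadcast product before summing, with a single `numpy.einsum` call. It also includes a trailing non-scalar event dimension, the case the linked implementation does not handle.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical shapes: x depends on (i, j), y on (j, k); j is summed out.
# Each value also carries a trailing event dim e of size 3 (a non-scalar
# event shape), which is combined elementwise rather than contracted.
I, J, K, E = 4, 5, 6, 3
x = rng.normal(size=(I, J, E))
y = rng.normal(size=(J, K, E))

# Naive version: build the full (I, J, K, E) product, then reduce over j.
naive = (x[:, :, None, :] * y[None, :, :, :]).sum(axis=1)

# GEMM-like version: hand the whole sum-product to einsum, which does not
# need to allocate the (I, J, K, E) intermediate.
fused = np.einsum("ije,jke->ike", x, y)

assert naive.shape == fused.shape == (I, K, E)
assert np.allclose(naive, fused)
```

The saving grows with the size of the summed dimension, since the naive version's intermediate is `J` times larger than the result.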
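@eb8680 can you confirm that I can implement these as something like:

```python
# Written as if inside funsor, where eager, Contraction, the op classes,
# Variadic and Tensor are already in scope.
import pyro.ops.einsum


@eager.register(Contraction, AddOp, MulOp, frozenset, Variadic[Tensor])
def eager_contract_tensor(red_op, bin_op, reduced_vars, *operands):
    # Sum-product: sum out reduced_vars from the product of the operands.
    equation = "TODO define an einsum string"
    data = pyro.ops.einsum.contract(equation, *(x.data for x in operands))
    inputs = "TODO"
    output = "TODO"
    return Tensor(data, inputs, output)


@eager.register(Contraction, LogaddexpOp, AddOp, frozenset, Variadic[Tensor])
def eager_contract_tensor_log(red_op, bin_op, reduced_vars, *operands):
    # Same as above, but passing backend="pyro.ops.einsum.torch_log" to
    # pyro.ops.einsum.contract so the contraction is done in log space.
    ...
```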
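The einsum string left as a TODO could be derived from the operands' named inputs. Below is a minimal NumPy sketch of one way to do it. `build_einsum_equation` and `log_contract` are hypothetical helpers, not funsor or Pyro API. The sketch assumes, as for funsor `Tensor`, that each operand's data lays out its batch dimensions first, in `inputs` order, followed by its event shape. Under that assumption, appending `...` to every term lets einsum broadcast non-scalar event shapes. `log_contract` stands in for a log-space backend using a simple global-max shift, which can underflow when one operand's values span a very wide range.

```python
import string
from collections import OrderedDict

import numpy as np


def build_einsum_equation(operand_inputs, reduced_vars):
    """Hypothetical helper: turn named batch dims into an einsum equation.

    operand_inputs: one OrderedDict {name: size} per operand, in the order
        of that operand's leading (batch) dimensions.
    reduced_vars: names to sum out.
    Returns (equation, output_inputs); output_inputs keeps first-seen order.
    """
    letters = {}
    output_inputs = OrderedDict()
    for inputs in operand_inputs:
        for name, size in inputs.items():
            if name not in letters:
                letters[name] = string.ascii_letters[len(letters)]
            if name not in reduced_vars:
                output_inputs[name] = size
    # A trailing "..." on every term broadcasts any (possibly non-scalar)
    # event shape elementwise; it matches zero dims when events are scalar.
    terms = ["".join(letters[n] for n in inputs) + "..." for inputs in operand_inputs]
    out = "".join(letters[n] for n in output_inputs) + "..."
    return ",".join(terms) + "->" + out, output_inputs


def log_contract(equation, *log_operands):
    """Log-space sum-product: logsumexp over reduced dims of the summed operands.

    Shifts each operand by its global max before exponentiating
    (a sketch, not Pyro's torch_log backend).
    """
    shifts = [np.max(x) for x in log_operands]
    scaled = [np.exp(x - s) for x, s in zip(log_operands, shifts)]
    return np.log(np.einsum(equation, *scaled)) + sum(shifts)


rng = np.random.default_rng(0)
eq, out_inputs = build_einsum_equation(
    [OrderedDict(i=2, j=3), OrderedDict(j=3, k=4)], reduced_vars={"j"}
)

# (AddOp, MulOp) case: ordinary sum-product.
x = rng.normal(size=(2, 3))
y = rng.normal(size=(3, 4))
prod = np.einsum(eq, x, y)

# Non-scalar event shape (trailing size-5 dim) goes through the same equation.
xe = rng.normal(size=(2, 3, 5))
ye = rng.normal(size=(3, 4, 5))
prod_e = np.einsum(eq, xe, ye)
assert np.allclose(prod_e, (xe[:, :, None, :] * ye[None, :, :, :]).sum(axis=1))

# (LogaddexpOp, AddOp) case: compare with the naive broadcast-then-reduce.
naive_log = np.log(np.exp(x[:, :, None] + y[None, :, :]).sum(axis=1))
assert np.allclose(log_contract(eq, x, y), naive_log)
```

Assigning letters in first-seen order keeps the surviving dimensions in a deterministic order, which the caller could reuse as the result's `inputs`.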
> can you confirm
Yep, this should work as expected.
@fehiepsi did you already complete this?