Add array protocol dispatch methods to top-level Funsor class
Now that PyTorch supports tensor subtyping and function overloading with `__torch_function__`, should we add `__array_function__` and `__torch_function__` methods to `funsor.terms.Funsor` to allow evaluation of (some) PyTorch/NumPy code on Funsors?

Here is the meat of a `Funsor.__torch_function__` implementation, modulo handling of edge cases; `__array_function__` for the NumPy backend would be very similar:
```python
import funsor.ops


class Funsor:
    ...

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        # Exploit our op registry: ops should know how to handle and convert their arguments.
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            # Handle e.g. nn.Module or dist.Transform instances.
            op = funsor.ops.make_op(func)
        return op(*args, **kwargs)
```
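For concreteness, here is a hypothetical usage sketch of what this would enable. `funsor.Variable`, `funsor.Real`, and `funsor.set_backend` are existing funsor API, but the `torch.*` calls succeeding on Funsors is exactly what this issue proposes, so none of this runs today:

```python
import torch

import funsor

funsor.set_backend("torch")

# A free real-valued variable; the torch calls below are illustrative only.
x = funsor.Variable("x", funsor.Real)

# func.__name__ == "exp" resolves to funsor.ops.exp in the registry,
# so this would build the same lazy term as funsor.ops.exp(x).
y = torch.exp(x)

# torch.add resolves to funsor.ops.add, preserving lazy Funsor semantics.
z = torch.add(y, 1.0)

# Anything without a registered op would fall through to funsor.ops.make_op.
```

The `__array_function__` analog would route e.g. `np.exp(x)` through the same op registry on the NumPy backend.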
The motivating application is a much simpler and more general alternative to the somewhat confusing dimension tracking done via the effectful `to_data`/`to_funsor` primitives in `pyro.contrib.funsor`. This would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special `torch.Tensor` subclasses that duplicate Funsor broadcasting semantics.
I like the idea quite a lot! It might simplify things in funsor and make the code look cleaner. Is my current understanding correct that `__torch_function__` would replace all Funsor ops (such as `Funsor.__add__`, `Funsor.sum`, etc.)? And that `contrib.funsor` would calculate everything as Funsors during model execution, instead of delegating it to `TraceMessenger` and converting at the last moment?
I don't think it would replace the basic Python operator overloads, but array-specific methods like `Funsor.sum()` could probably be removed in favor of these generic dispatch methods; see the sketch below.
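To illustrate that division of labor, here is a hypothetical sketch assuming the `__torch_function__` implementation above lands; the operator line works in funsor today, the `torch.*` lines do not:

```python
import torch

import funsor

funsor.set_backend("torch")

x = funsor.Variable("x", funsor.Real)
y = funsor.Variable("y", funsor.Real)

z1 = x + y            # still goes through Funsor.__add__, unchanged
z2 = torch.add(x, y)  # would be intercepted by __torch_function__ and
                      # routed to funsor.ops.add

# A method like Funsor.sum() could then be retired: users would write
# torch.sum(f) (or np.sum(f) on the NumPy backend) and reach the same
# op registry through the generic dispatch.
```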