
Add array protocol dispatch methods to top-level Funsor class

Open · eb8680 opened this issue on Aug 03, 2021 · 2 comments

Now that PyTorch supports tensor subtyping and function overloading with __torch_function__, should we add __array_function__ and __torch_function__ methods to funsor.terms.Funsor to allow evaluation of (some) PyTorch/NumPy code on Funsors?

Here is the meat of a Funsor.__torch_function__ implementation, modulo handling of edge cases; __array_function__ for the NumPy backend would be very similar:

import funsor.ops

class Funsor:
    ...
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Exploit our op registry: ops should know how to handle and
        # convert their arguments.
        kwargs = {} if kwargs is None else kwargs
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            # Handle e.g. nn.Module or dist.Transform instances.
            op = funsor.ops.make_op(func)
        return op(*args, **kwargs)
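For the NumPy backend, the analogous __array_function__ could reuse the same lookup. A minimal sketch, noting that NEP 18 passes args and kwargs positionally and kwargs is always a dict, so no None check is needed:

class Funsor:
    ...
    def __array_function__(self, func, types, args, kwargs):
        # Same registry trick: NumPy functions share names with funsor.ops,
        # and unknown callables get wrapped on the fly.
        try:
            op = getattr(funsor.ops, func.__name__)
        except AttributeError:
            op = funsor.ops.make_op(func)
        return op(*args, **kwargs)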

The motivating application is a much simpler and more general alternative to the dimension tracking done via the effectful to_data/to_funsor primitives in pyro.contrib.funsor, which is somewhat confusing. It would also simplify @ordabayevy's work in #543 and elsewhere by removing the need for special torch.Tensor subclasses that duplicate Funsor broadcasting semantics.
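For a sense of what this would enable, here is a hypothetical usage sketch, assuming the __torch_function__ method above is in place and that torch dispatches on non-Tensor argument types via the protocol:

import torch
import funsor
from funsor import Reals, Variable

funsor.set_backend("torch")

x = Variable("x", Reals[3])
# No to_data/to_funsor round trip: torch.exp would dispatch through
# Funsor.__torch_function__ to funsor.ops.exp and return a lazy Funsor.
y = torch.exp(x)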

eb8680 · Aug 03 '21, 14:08

I like the idea quite a lot! It might simplify things in funsor and make the code look cleaner. Is my current understanding right that __torch_function__ would replace all of the Funsor ops (such as Funsor.__add__, Funsor.sum, etc.)? And that contrib.funsor would compute everything as Funsors during model execution, instead of delegating to TraceMessenger and converting at the last moment?

ordabayevy · Aug 04 '21, 00:08

I don't think it would replace the basic Python operator overloads, but array-specific methods like sum() could probably be removed in favor of these generic methods.
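Concretely, a minimal sketch of that split (assuming the __torch_function__ method above): Python operator overloads never enter the protocol, while torch namespace functions would.

import torch
from funsor import Real, Variable

x = Variable("x", Real)
y = Variable("y", Real)

z = x + y            # Python dispatches straight to Funsor.__add__.
s = torch.add(x, y)  # Would route through __torch_function__ to funsor.ops.add,
                     # making array-specific methods like Funsor.sum redundant.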

eb8680 · Aug 04 '21, 00:08