Support for an "or" condition, or another way to accomplish this pattern?

Open · SeanEaster opened this issue 2 years ago · 1 comment

n00b to this very cool project here. I'm looking to enforce a broadcastability pattern where a dimension in one tensor either matches the corresponding dimension in another tensor or can be broadcast to it (i.e. equals 1).

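import torchtyping
import typeguard

# Needed so typeguard checks TensorType dimension names consistently across arguments.
torchtyping.patch_typeguard()

@typeguard.typechecked
def mwe(
    x: torchtyping.TensorType[
        ...,
        "foo",
        "bar",  # How do we make this "match bar from y, or equal 1"?
    ],
    y: torchtyping.TensorType[
        "bar",
    ],
) -> torchtyping.TensorType[..., "foo", "bar"]:
    return x * y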

Am I missing an existing way to do this in torchtyping out of the box? Would this need an extension?
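For reference, the rule being asked for can be written as a plain shape check. This is a minimal pure-Python sketch of the intended semantics only; `dim_broadcasts_to` and `mwe_shapes_ok` are hypothetical helpers, not part of torchtyping, and they only cover the one-directional case in the question (x's last dim must equal y's dim or be 1).

```python
from typing import Sequence


def dim_broadcasts_to(x_dim: int, y_dim: int) -> bool:
    """True if x_dim matches y_dim or can be broadcast to it (x_dim == 1)."""
    return x_dim == y_dim or x_dim == 1


def mwe_shapes_ok(x_shape: Sequence[int], y_shape: Sequence[int]) -> bool:
    """Check x: (..., foo, bar-or-1) against y: (bar,)."""
    if len(x_shape) < 2 or len(y_shape) != 1:
        return False
    return dim_broadcasts_to(x_shape[-1], y_shape[0])


print(mwe_shapes_ok((4, 3, 5), (5,)))  # True: "bar" matches
print(mwe_shapes_ok((4, 3, 1), (5,)))  # True: 1 broadcasts to 5
print(mwe_shapes_ok((4, 3, 2), (5,)))  # False
```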

SeanEaster · Nov 24 '21 14:11

Yep, this is possible: it can be done with Union[TensorType[..., "foo", 1], TensorType[..., "foo", "bar"]].

One caveat: switching the order of the elements of the Union will cause a spurious failure (the 1 case has to go before the "bar" case). That's a bug, really, but probably a thorny one to fix.
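One plausible way such an order dependence can arise is if the Union's alternatives are tried in order and the first match's dimension bindings are committed without backtracking. The toy model below assumes exactly that; it is not torchtyping's or typeguard's actual internals, and all names in it are hypothetical. Patterns list only trailing dims (leading dims act like `...`). With "bar" first, an x of shape (2, 3, 1) matches `("foo", "bar")` and binds bar=1, so a y of shape (5,) is then rejected.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

Dim = Union[int, str]  # int = fixed size, str = named dimension
Pattern = Tuple[Dim, ...]
Bindings = Dict[str, int]


def match_pattern(shape: Sequence[int], pattern: Pattern, bindings: Bindings) -> Optional[Bindings]:
    """Match the trailing dims of `shape` against `pattern`.

    Returns the updated name -> size bindings on success, or None on failure.
    """
    if len(shape) < len(pattern):
        return None
    new = dict(bindings)
    for size, spec in zip(shape[len(shape) - len(pattern):], pattern):
        if isinstance(spec, int):
            if size != spec:  # fixed-size dim must match exactly
                return None
        elif spec in new:
            if new[spec] != size:  # name already bound to a different size
                return None
        else:
            new[spec] = size  # first occurrence binds the name
    return new


def match_first(shape: Sequence[int], alternatives: List[Pattern], bindings: Bindings) -> Optional[Bindings]:
    """Toy ordered Union: commit to the first alternative that matches (no backtracking)."""
    for pattern in alternatives:
        result = match_pattern(shape, pattern, bindings)
        if result is not None:
            return result
    return None


def check_mwe(x_shape: Sequence[int], y_shape: Sequence[int], x_alternatives: List[Pattern]) -> bool:
    """Check x against the Union alternatives, then y against ("bar",) using x's bindings."""
    bindings = match_first(x_shape, x_alternatives, {})
    if bindings is None:
        return False
    return match_pattern(y_shape, ("bar",), bindings) is not None


ONE_FIRST = [("foo", 1), ("foo", "bar")]
BAR_FIRST = [("foo", "bar"), ("foo", 1)]

print(check_mwe((2, 3, 1), (5,), ONE_FIRST))  # True
print(check_mwe((2, 3, 1), (5,), BAR_FIRST))  # False: x bound "bar" to 1
```

In this model, a checker that fell back to the next alternative when a later argument disagreed would not care about order; committing early is what makes "1 first" necessary.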


Incidentally, broadcasting is a common enough operation that I'd be willing to accept a PR making this neater than the Union solution. Essentially, all that's needed is some syntax like TensorType["foo": OrOne], which TensorType.__class_getitem__ expands into a Union of the form given above.
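A minimal sketch of that expansion, assuming a toy stand-in class rather than torchtyping's real TensorType: `OrOne` and `ToyTensorType` are hypothetical names, and this does not attempt to integrate with torchtyping's existing `__class_getitem__` parsing. Each `"name": OrOne` slice becomes two choices (1 listed first, per the caveat above), and every combination becomes one alternative in a `typing.Union`.

```python
import itertools
import typing


class OrOne:
    """Hypothetical marker: `"name": OrOne` means "size `name`, or 1"."""


class ToyTensorType:
    """Stand-in for TensorType; each concrete alternative stores its dims in `dims`."""

    dims: tuple = ()

    def __class_getitem__(cls, item):
        if not isinstance(item, tuple):
            item = (item,)
        # For each `"name": OrOne` slice, offer two choices with 1 listed first.
        choices = []
        for dim in item:
            if isinstance(dim, slice) and dim.stop is OrOne:
                choices.append((1, dim.start))
            else:
                choices.append((dim,))
        # One subclass per combination of choices, preserving order.
        alternatives = tuple(
            type(cls.__name__, (cls,), {"dims": combo})
            for combo in itertools.product(*choices)
        )
        if len(alternatives) == 1:
            return alternatives[0]
        return typing.Union[alternatives]


U = ToyTensorType[..., "foo", "bar": OrOne]
print([alt.dims for alt in typing.get_args(U)])
# [(Ellipsis, 'foo', 1), (Ellipsis, 'foo', 'bar')]
```

Returning a plain class when there is nothing to expand keeps ordinary annotations such as `ToyTensorType["foo"]` unchanged.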

This should be pretty simple so it'd be a good first issue for anyone looking to contribute.

patrick-kidger · Nov 24 '21 15:11