torchtyping

TensorType detail: grad_enabled

Open · SimpleConjugate opened this issue 3 years ago • 1 comment

Is it possible to type-check that a tensor has gradients enabled? I'm not sure which cases would need testing to confirm this, since I don't fully understand how the runtime type checking works.

import torch
from torchtyping import TensorDetail

class _AutoGradTensorDetail(TensorDetail):
    def check(self, tensor: torch.Tensor) -> bool:
        return tensor.requires_grad  # a bool attribute, not a method

SimpleConjugate, May 27 '21 20:05

Ah, that's a nice idea for a tensor detail.

Yes, that should be completely possible. Quick mock-up (untested):

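from torch import Tensor
from torchtyping import TensorDetail

class _RequiresGradDetail(TensorDetail):
    # Passes only if the tensor is tracking gradients.
    def check(self, tensor: Tensor) -> bool:
        return tensor.requires_grad

    # How this detail is written when the annotation is printed.
    def __repr__(self) -> str:
        return "requires_grad"

    # How an actual tensor is described with respect to this detail.
    @classmethod
    def tensor_repr(cls, tensor: Tensor) -> str:
        if tensor.requires_grad:
            return "requires_grad"
        else:
            return ""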

requires_grad = _RequiresGradDetail()

patrick-kidger, May 27 '21 20:05
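To illustrate how a detail like this is used at runtime (the part of the question about how the checking operates), here is a toy model in plain Python that runs without PyTorch or torchtyping. Everything in it is hypothetical: `FakeTensor` stands in for `torch.Tensor`, `Detail` mirrors the three methods from the mock-up, and `check_argument` is an invented helper. It is not torchtyping's implementation. It only assumes that a checker calls `check` to validate an argument and uses `__repr__` and `tensor_repr` to describe the expected and actual values when validation fails.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Sequence


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; carries only the flag we check."""
    requires_grad: bool = False


class Detail(ABC):
    """Toy version of the three-method interface used in the mock-up."""

    @abstractmethod
    def check(self, tensor: FakeTensor) -> bool: ...

    @abstractmethod
    def __repr__(self) -> str: ...

    @classmethod
    @abstractmethod
    def tensor_repr(cls, tensor: FakeTensor) -> str: ...


class RequiresGradDetail(Detail):
    def check(self, tensor: FakeTensor) -> bool:
        return tensor.requires_grad  # plain attribute, so no call parentheses

    def __repr__(self) -> str:
        return "requires_grad"

    @classmethod
    def tensor_repr(cls, tensor: FakeTensor) -> str:
        return "requires_grad" if tensor.requires_grad else ""


def check_argument(name: str, tensor: FakeTensor, details: Sequence[Detail]) -> None:
    """Hypothetical helper: validate one annotated argument against its details."""
    for detail in details:
        if not detail.check(tensor):
            # Describe what was expected (repr of the detail) and what was
            # actually passed (tensor_repr of the argument).
            got = detail.tensor_repr(tensor) or "(none)"
            raise TypeError(f"argument {name!r}: expected {detail!r}, got {got}")


requires_grad = RequiresGradDetail()
check_argument("x", FakeTensor(requires_grad=True), [requires_grad])  # passes silently
try:
    check_argument("x", FakeTensor(requires_grad=False), [requires_grad])
except TypeError as err:
    print(err)  # argument 'x': expected requires_grad, got (none)
```

In torchtyping itself, the checking is triggered by typeguard's `@typechecked` decorator after `patch_typeguard()` has been called; the toy `check_argument` above only stands in for that step.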