returnn (GitHub repository)

PyTorch debug_add_check_numerics_ops

Status: open. Opened by albertz about a year ago; 0 comments.

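We could implement this for PyTorch with code like the following, which uses a `TorchDispatchMode` to check the inputs and outputs of every executed op for NaN/Inf: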

import torch
from torch.utils._pytree import tree_all
from torch.utils._python_dispatch import TorchDispatchMode
 

class NaNInfMode(TorchDispatchMode):
    """Dispatch mode that raises when an op reads or produces NaN/Inf.

    Every op dispatched while the mode is active (``with NaNInfMode(): ...``)
    is checked:

    * inputs: every tensor the op reads must be finite. Arguments the op's
      schema marks as written to (in-place targets, ``out=`` buffers) are
      skipped, since they may legitimately hold garbage before the call; a
      NaN/Inf in such an argument that the op also reads still surfaces in
      the op's output.
    * outputs: every tensor the op returns must be finite, except for ops
      that allocate uninitialized memory (``empty`` and friends).

    Violations raise ``AssertionError`` explicitly (not via ``assert``), so the
    checks keep running under ``python -O``.
    """

    # When False, ops are forwarded without any checking.
    enabled: bool

    # Schema names of ops whose result is uninitialized memory; NaN/Inf in
    # their output is expected garbage, not a numerical error.
    _UNINITIALIZED_OUTPUT_OPS = frozenset(
        {
            "aten::empty",
            "aten::empty_like",
            "aten::empty_strided",
            "aten::empty_permuted",
            "aten::new_empty",
            "aten::new_empty_strided",
        }
    )

    def __init__(self, enabled=True):
        super().__init__()
        self.enabled = enabled

    @staticmethod
    def check_finite(pytree):
        """Return True if every tensor leaf of *pytree* is finite.

        Non-tensor leaves are ignored.
        """
        return tree_all(lambda x: not isinstance(x, torch.Tensor) or torch.isfinite(x).all(), pytree)

    @staticmethod
    def _read_arguments(schema, args, kwargs):
        """Return ``(args, kwargs)`` without the arguments *schema* marks as written.

        An argument counts as written when its schema alias annotation has the
        write flag (``Tensor(a!)``). If *schema* is None (e.g. an operator
        without a schema), the arguments are returned unchanged.
        """
        if schema is None:
            return args, kwargs
        params = schema.arguments

        def is_written(param):
            return param.alias_info is not None and param.alias_info.is_write

        written_names = {p.name for p in params if is_written(p)}
        # Positional args line up with the leading schema arguments; anything
        # beyond the schema's length is kept and checked.
        read_args = tuple(
            value
            for i, value in enumerate(args)
            if i >= len(params) or not is_written(params[i])
        )
        read_kwargs = {k: v for k, v in kwargs.items() if k not in written_names}
        return read_args, read_kwargs

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}
        if not self.enabled:
            return func(*args, **kwargs)
        schema = getattr(func, "_schema", None)
        read_args, read_kwargs = self._read_arguments(schema, args, kwargs)
        if not self.check_finite((read_args, read_kwargs)):
            raise AssertionError(f"input to {func} contains NaN or Inf: {args}, {kwargs}")
        out = func(*args, **kwargs)
        if schema is not None and schema.name in self._UNINITIALIZED_OUTPUT_OPS:
            # Freshly allocated, uninitialized memory: contents are arbitrary.
            return out
        if not self.check_finite(out):
            raise AssertionError(f"output of {func} contains NaN or Inf: {out}")
        return out

Via: two source links (not preserved in this copy).

Posted by albertz on May 31, 2024, 08:05.