returnn
returnn copied to clipboard
PyTorch debug_add_check_numerics_ops
We could use code such as the following:
import torch
from torch.utils._pytree import tree_all
from torch.utils._python_dispatch import TorchDispatchMode
class NaNInfMode(TorchDispatchMode):
    """Dispatch mode that checks the tensors of every dispatched op for NaN/Inf.

    While the mode is active and ``enabled`` is true, the arguments and the
    result of each dispatched op are checked with ``torch.isfinite``. An
    ``AssertionError`` naming the op is raised as soon as a tensor containing
    a non-finite value is seen.
    """

    # Whether the checks are performed; when false, ops are just forwarded.
    enabled: bool

    def __init__(self, enabled: bool = True) -> None:
        """Create the mode.

        :param enabled: whether the NaN/Inf checks should be performed.
        """
        super().__init__()
        self.enabled = enabled

    @staticmethod
    def check_finite(pytree) -> bool:
        """Return whether all tensors in ``pytree`` contain only finite values.

        :param pytree: arbitrarily nested structure; leaves that are not
            ``torch.Tensor`` instances are ignored.
        :return: True if no tensor leaf contains NaN or Inf.
        """
        return tree_all(lambda x: not isinstance(x, torch.Tensor) or torch.isfinite(x).all(), pytree)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and check its inputs and its output for NaN/Inf.

        :param func: the op being dispatched.
        :param types: tensor types involved in the call (unused here).
        :param args: positional arguments for ``func``.
        :param kwargs: keyword arguments for ``func``, or None.
        :return: the result of ``func(*args, **kwargs)``.
        :raises AssertionError: if the mode is enabled and an input or the
            output contains NaN or Inf.
        """
        kwargs = kwargs if kwargs else {}
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would silently disable the checks.
        if self.enabled and not self.check_finite((args, kwargs)):
            raise AssertionError(f"input to {func} contains NaN or Inf: {args}, {kwargs}")
        out = func(*args, **kwargs)
        if self.enabled and not self.check_finite(out):
            raise AssertionError(f"output to {func} contains NaN or Inf: {out}")
        return out