Exact Bessel Volume

Status: Open. Opened by ssdasgupta 3 years ago; 0 comments.

Currently, the volume function in this repo uses an approximation of the Bessel volume. I have tried to implement the exact Bessel volume in the past, but it was numerically unstable. Could you have a look at the code snippet below and see how it could be added to this repo?
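For reference, the per-dimension term in `_log_bessel_volume` corresponds to the expression below, with β = `gumbel_beta`, s = `scale`, γ the Euler-Mascheroni constant, and z, Z the min/max coordinates of a d-dimensional box. The two limits follow from the standard small- and large-argument asymptotics of K0:

```math
\begin{aligned}
\log V(z, Z) &= \sum_{i=1}^{d} \log\big(2\beta\, K_0(x_i)\big) + \log s, \qquad x_i = 2\, e^{(z_i - Z_i)/(2\beta)} \\
2\beta\, K_0(x_i) &\approx Z_i - z_i - 2\gamma\beta \quad \text{for } Z_i - z_i \gg \beta \;(x_i \to 0) \\
2\beta\, K_0(x_i) &\approx 2\beta \sqrt{\tfrac{\pi}{2 x_i}}\, e^{-x_i} \quad \text{for } z_i - Z_i \gg \beta \;(x_i \to \infty)
\end{aligned}
```

The second regime is where an exact implementation gets into numerical trouble: K0 decays like e^{-x}, so in float32 it drops below the smallest normal number (about 1.2e-38) at roughly x ≈ 85. This is presumably what the `clamp_max(100)` and `clamp_min(eps)` in the code guard against.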

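The Bessel function wrapper:

```python
import numpy as np
import torch
from scipy import special  # assumed source of `special` in the original snippet


class Bessel(torch.autograd.Function):
    """K0(x), the modified Bessel function of the second kind of order 0, with autograd support."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # scipy.special operates on NumPy arrays, so evaluate on the CPU and move the result back.
        x = np.asarray(special.k0(input.detach().cpu().numpy()))
        return torch.from_numpy(x).to(device=input.device, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # d/dx K0(x) = -K1(x)
        k1 = np.asarray(special.k1(input.detach().cpu().numpy()))
        return grad_output * -torch.from_numpy(k1).to(device=grad_output.device, dtype=grad_output.dtype)
```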
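A possible direction for the instability (a sketch under assumptions, not code from this repo): `K0(x)` itself underflows for moderately large `x`, but SciPy also provides the exponentially scaled `special.k0e(x) = exp(x) * K0(x)` and `special.k1e`. In log space, log K0(x) = log k0e(x) - x stays finite, and its derivative -K1(x)/K0(x) equals -k1e(x)/k0e(x) because the exp(x) factors cancel. `LogBesselK0` is a hypothetical name, and the sketch assumes x > 0, which holds for `element` in the volume function as long as the exponential does not underflow.

```python
import numpy as np
import torch
from scipy import special


class LogBesselK0(torch.autograd.Function):
    """Hypothetical helper: log K0(x) for x > 0, via the exponentially scaled Bessel functions."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        x_np = x.detach().cpu().numpy()
        # log K0(x) = log(k0e(x)) - x, where k0e(x) = exp(x) * K0(x); avoids the underflow of K0 itself.
        out = np.asarray(np.log(special.k0e(x_np)) - x_np)
        return torch.from_numpy(out).to(device=x.device, dtype=x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x_np = x.detach().cpu().numpy()
        # d/dx log K0(x) = -K1(x) / K0(x) = -k1e(x) / k0e(x); the exp(x) factors cancel.
        ratio = np.asarray(special.k1e(x_np) / special.k0e(x_np))
        return grad_output * -torch.from_numpy(ratio).to(device=grad_output.device, dtype=grad_output.dtype)


x = torch.tensor([0.5, 5.0, 100.0, 800.0], dtype=torch.float64, requires_grad=True)
y = LogBesselK0.apply(x)
y.sum().backward()
print(y)       # finite even at x = 800, where special.k0 underflows to 0
print(x.grad)
```

In `_log_bessel_volume`, the term `torch.log(2*gumbel_beta*Bessel.apply(element).clamp_min(eps))` could then become `math.log(2 * gumbel_beta) + LogBesselK0.apply(element)`, which would make the `clamp_max(100)` unnecessary (untested within this repo).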

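The volume function:

```python
from typing import Union

import torch
from torch import Tensor


def _log_bessel_volume(z: Tensor,
                       Z: Tensor,
                       gumbel_beta: float = 1.,
                       scale: Union[float, Tensor] = 1.) -> Tensor:
    """Log of the exact Bessel volume of a Gumbel box with min corner z and max corner Z.

    Originally a method taking `cls`; written as a standalone function here.
    Requires the `Bessel` autograd function defined above.
    """
    eps = torch.finfo(z.dtype).tiny
    s = torch.as_tensor(scale, dtype=z.dtype, device=z.device)
    # Per-dimension argument of K0, clamped because K0(x) underflows for large x.
    element = (2 * torch.exp((z - Z) / (2 * gumbel_beta))).clamp_max(100)
    # Per-dimension side length 2*beta*K0(element); K0 is clamped away from zero before the log.
    return (torch.sum(
        torch.log(2 * gumbel_beta * Bessel.apply(element).clamp_min(eps)),
        dim=-1) + torch.log(s)
    )
```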
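As a sanity check on what the per-dimension term `2*gumbel_beta*K0(2*exp((z - Z)/(2*gumbel_beta)))` represents, it can be compared against a Monte Carlo estimate of the expected side length of a single Gumbel box dimension. This sketch assumes the GumbelBox parametrization in which the min corner is MaxGumbel(z, β) and the max corner is MinGumbel(Z, β); the function names are hypothetical.

```python
import numpy as np
from scipy import special


def bessel_side_length(z, Z, beta):
    """Per-dimension term of the Bessel volume: 2*beta*K0(2*exp((z - Z)/(2*beta)))."""
    return 2.0 * beta * special.k0(2.0 * np.exp((z - Z) / (2.0 * beta)))


def monte_carlo_side_length(z, Z, beta, n=1_000_000, seed=0):
    """Estimate E[max(0, hi - lo)] by sampling the box corners directly.

    Assumes lo ~ MaxGumbel(z, beta) and hi ~ MinGumbel(Z, beta), i.e. hi = -MaxGumbel(-Z, beta).
    """
    rng = np.random.default_rng(seed)
    lo = rng.gumbel(loc=z, scale=beta, size=n)
    hi = -rng.gumbel(loc=-Z, scale=beta, size=n)
    return np.maximum(0.0, hi - lo).mean()


for z, Z, beta in [(0.0, 1.0, 0.1), (0.0, 0.0, 1.0), (0.5, 0.0, 0.5)]:
    print(f"z={z}, Z={Z}, beta={beta}: "
          f"Bessel {bessel_side_length(z, Z, beta):.4f}, "
          f"Monte Carlo {monte_carlo_side_length(z, Z, beta):.4f}")
```

For well-separated corners (Z - z much larger than β) both numbers approach Z - z - 2γβ, which is about 0.885 for the first case.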

ssdasgupta · Mar 12 '21, 06:03