box-embeddings
box-embeddings copied to clipboard
Exact Bessel Volume
Currently, the volume function in this repo is an approximation of the Bessel volume. In the past I tried to implement an exact version of the Bessel volume, but it was numerically unstable. Could you take a look at the code snippet below and see how it could be added to this repo?
The Bessel function wrapper
class Bessel(torch.autograd.Function):
    """Differentiable modified Bessel function of the second kind, order 0 (``K0``).

    Values are computed on the CPU with ``special.k0`` / ``special.k1``
    (NOTE(review): presumably ``scipy.special`` -- confirm against the file's imports)
    and converted back to a tensor with the input's dtype and device.
    The derivative uses the identity ``d/dx K0(x) = -K1(x)``.

    Caveat: ``K0(x)`` underflows for large ``x`` (around ``x > 85`` in float32) and
    diverges at ``x = 0``. Code that needs ``log K0`` should use an exponentially scaled
    formulation (``log k0e(x) - x``) instead of taking ``log`` of this output.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """Return ``K0(x)`` elementwise, with the same dtype and device as ``x``."""
        ctx.save_for_backward(x)
        # Explicit numpy round trip instead of relying on numpy's __array_wrap__ to
        # turn the ufunc result back into a tensor.
        k0 = special.k0(x.detach().cpu().numpy())
        return torch.as_tensor(k0, dtype=x.dtype, device=x.device)

    @staticmethod
    @torch.autograd.function.once_differentiable  # backward is not itself differentiable
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Chain rule with ``dK0/dx = -K1(x)``."""
        (x,) = ctx.saved_tensors
        k1 = special.k1(x.detach().cpu().numpy())
        k1_t = torch.as_tensor(k1, dtype=grad_output.dtype, device=grad_output.device)
        return -grad_output * k1_t
The volume function
class _LogBesselK0(torch.autograd.Function):
    """Differentiable ``log K0(x)`` that stays finite where ``K0(x)`` itself underflows.

    Uses the exponentially scaled Bessel functions ``k0e(x) = exp(x) * K0(x)`` and
    ``k1e(x) = exp(x) * K1(x)`` from ``scipy.special``:

    * forward:  ``log K0(x) = log(k0e(x)) - x``
    * backward: ``d/dx log K0(x) = -K1(x) / K0(x) = -k1e(x) / k0e(x)`` (scalings cancel)

    Inputs must be strictly positive; ``K0`` diverges at 0.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        from scipy import special as sp_special  # local import: only this path needs scipy

        ctx.save_for_backward(x)
        k0e = sp_special.k0e(x.detach().cpu().numpy())
        return torch.log(torch.as_tensor(k0e, dtype=x.dtype, device=x.device)) - x

    @staticmethod
    @torch.autograd.function.once_differentiable  # backward is not itself differentiable
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        from scipy import special as sp_special

        (x,) = ctx.saved_tensors
        x_np = x.detach().cpu().numpy()
        ratio = sp_special.k1e(x_np) / sp_special.k0e(x_np)
        ratio_t = torch.as_tensor(ratio, dtype=grad_output.dtype, device=grad_output.device)
        return -grad_output * ratio_t


def _log_bessel_volume(cls,
                       z: Tensor,
                       Z: Tensor,
                       gumbel_beta: float = 1.,
                       scale: Union[float, Tensor] = 1.,
                       max_bessel_arg: Optional[float] = 100.) -> Tensor:
    """Return the log of the exact Bessel box volume.

    For each dimension the side term is ``2 * beta * K0(x)`` with
    ``x = 2 * exp((z - Z) / (2 * beta))``. The result is the sum of the per-dimension
    logs over the last axis, plus ``log(scale)``.

    ``log K0`` is evaluated in the log domain (``log k0e(x) - x``). In float32 this keeps
    both the value and the gradient correct in the range where ``K0(x)`` itself would
    underflow.

    Args:
        cls: Unused. Kept for signature compatibility (the snippet appears to come from a
            classmethod).
        z: Lower corners; the volume is reduced over the last dimension.
        Z: Upper corners, broadcastable with ``z``.
        gumbel_beta: Gumbel temperature ``beta``.
        scale: Multiplicative volume scale. Any real scalar (float or int) or a tensor.
        max_bessel_arg: Upper clamp on ``x``, defaulting to the original hard-coded 100.
            The clamp bounds ``log K0(x) ~ -x``, which otherwise decays (and its
            gradient grows) doubly-exponentially in ``z - Z``. ``None`` disables it.

    Returns:
        ``z`` reduced over its last dimension: the log volume.
    """
    # Below `tiny` the exp has underflowed towards 0, where K0 -> inf. Clamping keeps the
    # value finite for extremely large boxes instead of returning inf/NaN.
    x = (2 * torch.exp((z - Z) / (2 * gumbel_beta))).clamp(
        min=torch.finfo(z.dtype).tiny, max=max_bessel_arg
    )
    log_two_beta = torch.log(torch.as_tensor(2 * gumbel_beta, dtype=x.dtype, device=x.device))
    log_sides = log_two_beta + _LogBesselK0.apply(x)

    # Accept any real scalar (an int used to reach torch.log(int) and raise TypeError).
    if isinstance(scale, Tensor):
        s = scale
    else:
        s = torch.tensor(float(scale), dtype=z.dtype, device=z.device)

    return log_sides.sum(dim=-1) + torch.log(s)