geomstats
Implement a safe divide function for torch and tf
The new `gs.assignment` function (thanks @pchauchat!) allows us to:
- clean up code such as: https://github.com/geomstats/geomstats/blob/1e5f7cea94a7884b798f07723285b04592ac09ab/geomstats/geometry/special_orthogonal.py#L183
- translate functions/methods so that they run with tf and torch.
We should use it throughout the codebase.
What about writing a divide function for PyTorch and TensorFlow that behaves like NumPy's `np.divide` with its `where` argument? The effect is similar to using `gs.where` to avoid dividing by 0, but unlike `where`, it does not compute the division at all for the entries where the condition is false. In most cases, we currently use:
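```python
import geomstats.backend as gs
# Replace near-zero angles by gs.atol, then divide each vector by its angle.
theta_safe = gs.where(gs.abs(theta) < gs.atol, gs.atol, theta)
safe_division = gs.einsum('...,...i->...i', 1. / theta_safe, vectors)
```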
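For comparison, here is a minimal sketch of two ways a safe divide could behave, written in plain NumPy so it runs without torch or tf. The helper names `safe_divide_np` and `safe_divide_where`, and their `fill_value`/`atol` parameters, are hypothetical and not part of geomstats. The first relies on `np.divide(..., out=..., where=...)`, which skips the masked entries entirely; the second is the common "double where" pattern, which only needs `where`, `abs` and `ones_like` and so should translate line by line to `torch.where` or `tf.where`.

```python
import numpy as np


def safe_divide_np(numerator, denominator, fill_value=0.0, atol=1e-12):
    """Hypothetical NumPy reference: divide only where |denominator| >= atol.

    np.divide never evaluates the entries where `where` is False; those
    entries keep the value already stored in `out` (here, fill_value).
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    mask = np.abs(denominator) >= atol
    shape = np.broadcast(numerator, denominator).shape
    out = np.full(shape, fill_value, dtype=float)
    return np.divide(numerator, denominator, out=out, where=mask)


def safe_divide_where(numerator, denominator, fill_value=0.0, atol=1e-12):
    """Hypothetical backend-agnostic "double where" version.

    The division is computed everywhere, but on a denominator whose
    near-zero entries were first replaced by 1, so it never produces
    inf or nan; the outer where then puts fill_value at those entries.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    mask = np.abs(denominator) >= atol
    safe_denominator = np.where(mask, denominator, np.ones_like(denominator))
    return np.where(mask, numerator / safe_denominator, fill_value)
```

With `num = np.array([1.0, 2.0, 3.0])` and `den = np.array([2.0, 0.0, 4.0])`, both helpers return `[0.5, 0.0, 0.75]`. The double-where version is the usual choice in autodiff frameworks because the masked entries contribute a zero gradient instead of nan, whereas the single `where` used above still evaluates the raw division on every entry.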