
Implement a safe divide function for torch and tf

Open • ninamiolane opened this issue 5 years ago • 1 comment

The new gs.assignment function (thanks @pchauchat!) lets us:

  • clean up code such as: https://github.com/geomstats/geomstats/blob/1e5f7cea94a7884b798f07723285b04592ac09ab/geomstats/geometry/special_orthogonal.py#L183
  • translate functions/methods so that they run with tf and torch.

We should use it throughout the codebase.
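For context on why such a helper is needed: TensorFlow tensors do not support in-place item assignment (x[i] = v raises a TypeError), so backend-agnostic code has to build a new array instead of mutating one. The sketch below illustrates that out-of-place style with NumPy. functional_assign is a hypothetical name and its signature is an assumption; it is not the actual signature or implementation of gs.assignment.

```python
import numpy as np


def functional_assign(x, values, indices):
    """Return a copy of ``x`` with ``values`` written at ``indices``.

    Hypothetical helper illustrating the out-of-place style that
    immutable tensors (e.g. in TensorFlow) require. It is NOT the
    geomstats gs.assignment API.
    """
    result = np.array(x, copy=True)  # never mutate the caller's array
    result[indices] = values
    return result


x = np.zeros(4)
y = functional_assign(x, 1.0, [0, 2])
print(y)  # [1. 0. 1. 0.]
print(x)  # [0. 0. 0. 0.], the input is unchanged
```

Returning a fresh array costs a copy, but it is the only option that behaves the same on backends whose tensors cannot be modified in place.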

ninamiolane • Apr 27 '20 22:04

What about writing a divide function for PyTorch and TensorFlow that behaves like NumPy's np.divide with its where argument? It is similar to using where to avoid dividing by 0, but unlike where, it does not compute the division at the entries where the condition is false. In most cases, we currently use:

```python
import geomstats.backend as gs

theta_safe = gs.where(gs.abs(theta) < gs.atol, gs.atol, theta)  # avoid 0 denominators
safe_division = gs.einsum('...,...i->...i', 1. / theta_safe, vectors)  # vectors / theta
```
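A minimal sketch of such a divide function, assuming the goal is to reproduce np.divide(numerator, denominator, out=..., where=...) using nothing but a where function. It is written with NumPy so it runs anywhere; torch.where and tf.where take the same (condition, x, y) arguments, so the pattern should carry over to both backends. safe_divide, fill and the default atol value are hypothetical choices. Unlike the gs.where snippet above, which clamps small theta to gs.atol, this version returns fill wherever the denominator is too small.

```python
import numpy as np


def safe_divide(numerator, denominator, fill=0.0, atol=1e-12):
    """Divide where |denominator| >= atol and return ``fill`` elsewhere.

    Hypothetical helper: it only relies on ``where``, which NumPy,
    PyTorch (torch.where) and TensorFlow (tf.where) all provide.
    The default ``atol`` is an assumption, not geomstats' gs.atol.
    """
    numerator = np.asarray(numerator, dtype=float)
    denominator = np.asarray(denominator, dtype=float)
    mask = np.abs(denominator) >= atol
    # Step 1: swap invalid denominators for 1, so the division below
    # never produces inf or nan, not even in entries that get discarded.
    safe_denominator = np.where(mask, denominator, 1.0)
    quotient = numerator / safe_denominator
    # Step 2: keep the quotient only where the denominator was valid.
    return np.where(mask, quotient, fill)


theta = np.array([0.0, 2.0])
vectors = np.array([[1.0, 2.0], [4.0, 6.0]])
# Divide each row of vectors by its theta (broadcast over the last axis).
print(safe_divide(vectors, theta[..., None]))
# [[0. 0.]
#  [2. 3.]]
```

The first where is the important part: with a single where(mask, numerator / denominator, fill), the discarded entries are still computed as inf or nan, and under autodiff in PyTorch or TensorFlow those entries can leak nan into the gradients even though they never appear in the forward result.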

nguigs • Apr 28 '20 11:04