
Slow modular arithmetic

Open fredrik-johansson opened this issue 4 years ago • 14 comments

Here is a simple function (bernsum_powg) taken from David Harvey's bernmm library (using NTL), together with two FLINT versions using the n_precomp and n_preinvert interfaces.

#include <NTL/ZZ.h>
#include "flint/ulong_extras.h"
#include "flint/profiler.h"

NTL_CLIENT;

long PowerMod(long a, long ee, long n, mulmod_t ninv)
{
   long x, y;

   unsigned long e;

   if (ee < 0)
      e = - ((unsigned long) ee);
   else
      e = ee;

   x = 1;
   y = a;
   while (e) {
      if (e & 1) x = MulMod(x, y, n, ninv);
      y = MulMod(y, y, n, ninv);
      e = e >> 1;
   }

   if (ee < 0) x = InvMod(x, n);

   return x;
}

long bernsum_powg(long p, mulmod_t pinv, long k, long g)
{
   long half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   long g_to_jm1 = 1;
   long g_to_km1 = PowerMod(g, k-1, p, pinv);
   long g_to_km1_to_j = g_to_km1;
   long sum = 0;
   muldivrem_t g_pinv = PrepMulDivRem(g, p);
   mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv);

   for (long j = 1; j <= (p-1)/2; j++)
   {
      // at this point,
      //    g_to_jm1 holds g^(j-1) mod p
      //    g_to_km1_to_j holds (g^(k-1))^j mod p

      // update g_to_jm1 and compute q = (g*(g^(j-1) mod p) - (g^j mod p)) / p
      long q;
      g_to_jm1 = MulDivRem(q, g_to_jm1, g, p, g_pinv);

      // compute h = -h_g(g^j) = q - (g-1)/2
      long h = SubMod(q, half_gm1, p);

      // add h_g(g^j) * (g^(k-1))^j to running total
      sum = SubMod(sum, MulMod(h, g_to_km1_to_j, p, pinv), p);

      // update g_to_km1_to_j
      g_to_km1_to_j = MulModPrecon(g_to_km1_to_j, g_to_km1, p, g_to_km1_pinv);
   }

   return sum;
}

long bernsum_powg_flint2(ulong p, double pinv, ulong k, ulong g)
{
   ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   ulong g_to_jm1 = 1;
   ulong g_to_km1 = n_powmod_precomp(g, k-1, p, pinv);
   ulong g_to_km1_to_j = g_to_km1;
   ulong sum = 0;
   ulong g_to_km1_pinv = n_mulmod_precomp_shoup(g_to_km1, p);

   for (long j = 1; j <= (p-1)/2; j++)
   {
      ulong q;

      g_to_jm1 = n_divrem2_precomp(&q, g_to_jm1 * g, p, pinv);
      ulong h = n_submod(q, half_gm1, p);
      sum = n_submod(sum, n_mulmod_precomp(h, g_to_km1_to_j, p, pinv), p);
      g_to_km1_to_j = n_mulmod_shoup(g_to_km1, g_to_km1_to_j, g_to_km1_pinv, p);
   }

   return sum;
}

long bernsum_powg_flint2b(ulong p, ulong pinv, ulong k, ulong g)
{
   ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
   ulong g_to_jm1 = 1;
   ulong g_to_km1 = n_powmod2_preinv(g, k-1, p, pinv);
   ulong g_to_km1_to_j = g_to_km1;
   ulong sum = 0;

   for (long j = 1; j <= (p-1)/2; j++)
   {
      ulong q;

      g_to_jm1 = n_divrem2_preinv(&q, g_to_jm1 * g, p, pinv);
      ulong h = n_submod(q, half_gm1, p);
      sum = n_submod(sum, n_mulmod2_preinv(h, g_to_km1_to_j, p, pinv), p);
      g_to_km1_to_j = n_mulmod2_preinv(g_to_km1_to_j, g_to_km1, p, pinv);
   }

   return sum;
}

int main()
{
    long s, v;

    s = 0;
    TIMEIT_START
    s |= bernsum_powg(10007, PrepMulMod(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2(10007, n_precompute_inverse(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2b(10007, n_preinvert_limb(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);
}

Timings on my machine:

cpu/wall(s): 4.19e-05 4.19e-05
5444
cpu/wall(s): 6.7e-05 6.7e-05
5444
cpu/wall(s): 0.000131 0.000131
5444

So our n_precomp arithmetic is 1.6x slower than NTL, and our n_preinvert arithmetic is 3x slower. I even cheated here -- NTL has a muldivrem function which we don't have, so I put in a plain multiplication which of course will overflow if p is large.

  • Something is missing in the n_precomp arithmetic, making it slower than NTL for small-modulus arithmetic.
  • We currently use n_preinvert methods almost everywhere in Flint since they work up to 64 bits. We should investigate how to cleanly use n_precomp arithmetic (or something better) whenever the modulus is small.
  • n_mulmod_shoup is not very intuitively named, by the way (and takes its arguments in a weird order).
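For context on the last point: Shoup's trick precomputes a scaled inverse for one fixed multiplicand, so each multiplication by it in a loop like the one above costs only two machine multiplies and a conditional subtraction. A minimal generic sketch (hypothetical names, not FLINT's exact interface; assumes a compiler with unsigned __int128 and p < 2^63):

```c
#include <stdint.h>

typedef unsigned __int128 u128;

/* For a fixed multiplicand w < p, precompute wpre = floor(w * 2^64 / p). */
uint64_t shoup_precompute(uint64_t w, uint64_t p)
{
    return (uint64_t) (((u128) w << 64) / p);
}

/* Then w * t mod p for any t needs only the high half of wpre * t
   and one conditional subtraction (the uncorrected r lies in [0, 2p)). */
uint64_t shoup_mulmod(uint64_t w, uint64_t t, uint64_t wpre, uint64_t p)
{
    uint64_t q = (uint64_t) (((u128) wpre * t) >> 64);
    uint64_t r = w * t - q * p;
    return r >= p ? r - p : r;
}
```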

fredrik-johansson avatar Oct 23 '21 10:10 fredrik-johansson

Very surprising. I can't imagine for the life of me what we could have overlooked there. So much care went into making that efficient. Is NTL using doubles?

wbhart avatar Oct 28 '21 12:10 wbhart

We could do the following, where it makes sense (conversion costs and representation need to be considered):

  • use a double mul and a fused multiply-add to get the full 106-bit product of two 53-bit values
  • vectorisation
  • re-optimise our n_mulmod_precomp for modern CPUs
  • use 32 bit integers

wbhart avatar Oct 29 '21 15:10 wbhart

If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication:

double dmod_mul(double a, double b, double n, double ninv)
{
    double magic = 6755399441055744.0;
    double r = a * b;

    return r - ((r * ninv + magic) - magic) * n;
}

This should be good for doing lots of multiplications in parallel with SIMD. How quickly can we add in this representation?
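On the addition question: in the balanced range a sum lands in (-n, n], so a single conditional correction suffices. An obvious (hypothetical, untested) sketch; branchless variants should also be possible:

```c
/* Addition in the balanced representation -n/2 ... n/2:
   the sum lies in (-n, n], so one conditional add/subtract
   of n restores the range. */
double dmod_add(double a, double b, double n)
{
    double s = a + b;
    if (s > 0.5 * n)
        s -= n;
    else if (s <= -0.5 * n)
        s += n;
    return s;
}
```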

fredrik-johansson avatar Oct 29 '21 15:10 fredrik-johansson

  • Montgomery reduction when appropriate
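For reference, the core of 64-bit Montgomery reduction is short; this is a textbook sketch (assuming unsigned __int128 and odd n < 2^63), not a proposed FLINT interface:

```c
#include <stdint.h>

typedef unsigned __int128 u128;

/* npre = -n^{-1} mod 2^64 by Newton iteration; n must be odd
   (x = n is already correct to 3 bits, and each step doubles
   the number of correct bits). */
uint64_t mont_npre(uint64_t n)
{
    uint64_t x = n;
    for (int i = 0; i < 5; i++)
        x *= 2 - n * x;
    return -x;
}

/* REDC: for t < n * 2^64, returns t * 2^{-64} mod n.
   t + m*n is divisible by 2^64 by construction, and the
   shifted result lies in [0, 2n), so one correction suffices. */
uint64_t mont_redc(u128 t, uint64_t n, uint64_t npre)
{
    uint64_t m = (uint64_t) t * npre;
    uint64_t r = (uint64_t) ((t + (u128) m * n) >> 64);
    return r >= n ? r - n : r;
}
```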

fredrik-johansson avatar Oct 29 '21 16:10 fredrik-johansson

Another idea: in multimodular algorithms, we generally use primes of the form 2^n + c or 2^n - c where c is small. For multi-word moduli, this can certainly be exploited, but what about nmods?
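For a single-word prime p = 2^n - c with small c, the identity 2^n ≡ c (mod p) lets one fold the high bits down with a multiply by c instead of a general division. A generic sketch of the reduction (hypothetical, not an nmod proposal; assumes c small enough that the folding terminates without overflow):

```c
#include <stdint.h>

/* Reduce x modulo p = 2^n - c (n < 64, c small) by repeatedly
   folding the high bits via 2^n ≡ c (mod p), then applying one
   final conditional subtraction. */
uint64_t red_2n_minus_c(uint64_t x, unsigned n, uint64_t c)
{
    uint64_t p = (((uint64_t) 1) << n) - c;
    uint64_t mask = (((uint64_t) 1) << n) - 1;
    while (x >> n)
        x = (x & mask) + (x >> n) * c;
    return x >= p ? x - p : x;
}
```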

fredrik-johansson avatar Nov 06 '21 11:11 fredrik-johansson

I have at least part of the answer: n_mulmod_preinv requires the inputs to be reduced and does something fast, n_mulmod2_preinv does not require the inputs to be reduced and does something slow.

Our nmod_mul is stupidly doing the same thing as n_mulmod2_preinv. We should basically just change it to do an n_mulmod_preinv instead; this is 2x faster on my machine.

We can also replace many other uses of n_mulmod2_preinv with n_mulmod_preinv throughout Flint.

Some shifts can be avoided when the modulus has exactly FLINT_BITS bits; maybe this is worth optimizing for in various places.

fredrik-johansson avatar Nov 10 '21 17:11 fredrik-johansson

Ditto for nmod_addmul / NMOD_ADDMUL.

fredrik-johansson avatar Nov 10 '21 17:11 fredrik-johansson

> If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication

how small is small enough? This is assuming no fmadd/fmsub? What about with fmadd/fmsub?

tthsqe12 avatar Nov 10 '21 17:11 tthsqe12

> If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication
>
> how small is small enough? This is assuming no fmadd/fmsub? What about with fmadd/fmsub?

Up to sqrt(2^53) I guess, but I did not check or prove this. You might want to design an entirely different algorithm around fma.

fredrik-johansson avatar Nov 10 '21 17:11 fredrik-johansson

FWIW, flint does better if one uses nmod_mul and n_mulmod_shoup.

After that, the remaining difference seems to be due to n_divrem2_preinv being much slower than MulDivRem.
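NTL's MulDivRem returns quotient and remainder together from one reduction. A Shoup-style sketch of the same idea (hypothetical name and signature, assuming unsigned __int128, n < 2^63, reduced operands, and bpre = floor(b * 2^64 / n) precomputed for the fixed operand b):

```c
#include <stdint.h>

typedef unsigned __int128 u128;

/* Combined mul-div-rem: returns a*b mod n and stores
   floor(a*b / n) in *q.  The uncorrected remainder lies in
   [0, 2n), and correcting it also fixes the quotient. */
uint64_t muldivrem_preinv(uint64_t *q, uint64_t a, uint64_t b,
                          uint64_t bpre, uint64_t n)
{
    uint64_t qq = (uint64_t) (((u128) bpre * a) >> 64);
    uint64_t r = a * b - qq * n;
    if (r >= n) { r -= n; qq++; }
    *q = qq;
    return r;
}
```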

fredrik-johansson avatar Jan 07 '24 00:01 fredrik-johansson

Several functions for modular arithmetic like n_mulmod2_preinv, n_mod2_preinv, n_ll_mod_preinv, n_lll_mod_preinv don't take a norm as input and therefore need a flint_clz operation, which is redundant in situations where we already have an nmod_t containing this data. There are also probably many places (though not all) where these operations should actually be inlined. Replacing them with nmod_mul, NMOD2_RED2 etc. would be an improvement.

We should think about ways to redesign these interfaces so that they are more obvious.

fredrik-johansson avatar May 09 '24 20:05 fredrik-johansson

It would be nice to separate those that need normalization from those that do not. Not sure how to do that in a user-friendly way, though.

albinahlback avatar May 09 '24 22:05 albinahlback

Another point of comparison should be https://math.mit.edu/~drew/ffpoly.html. I think @andrewvsutherland (the author) did a comparison at some point...

edgarcosta avatar May 10 '24 14:05 edgarcosta

Better/simpler to test against b32.h (fast implementation of Barrett modular arithmetic for 32-bit integers) and m64.h (fast implementation of 64-bit Montgomery arithmetic); the latter is what ffpoly uses.

AndrewVSutherland avatar May 10 '24 14:05 AndrewVSutherland