Slow modular arithmetic
Here is a simple function (bernsum_powg) taken from David Harvey's bernmm library (using NTL), and two FLINT versions using the n_precomp and n_preinvert interfaces.
#include <NTL/ZZ.h>
#include "flint/ulong_extras.h"
#include "flint/profiler.h"

NTL_CLIENT

long PowerMod(long a, long ee, long n, mulmod_t ninv)
{
    long x, y;
    unsigned long e;

    if (ee < 0)
        e = - ((unsigned long) ee);
    else
        e = ee;

    x = 1;
    y = a;
    while (e) {
        if (e & 1) x = MulMod(x, y, n, ninv);
        y = MulMod(y, y, n, ninv);
        e = e >> 1;
    }

    if (ee < 0) x = InvMod(x, n);

    return x;
}
long bernsum_powg(long p, mulmod_t pinv, long k, long g)
{
    long half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
    long g_to_jm1 = 1;
    long g_to_km1 = PowerMod(g, k - 1, p, pinv);
    long g_to_km1_to_j = g_to_km1;
    long sum = 0;

    muldivrem_t g_pinv = PrepMulDivRem(g, p);
    mulmod_precon_t g_to_km1_pinv = PrepMulModPrecon(g_to_km1, p, pinv);

    for (long j = 1; j <= (p - 1) / 2; j++)
    {
        // at this point,
        //    g_to_jm1 holds g^(j-1) mod p
        //    g_to_km1_to_j holds (g^(k-1))^j mod p

        // update g_to_jm1 and compute q = (g*(g^(j-1) mod p) - (g^j mod p)) / p
        long q;
        g_to_jm1 = MulDivRem(q, g_to_jm1, g, p, g_pinv);

        // compute h = -h_g(g^j) = q - (g-1)/2
        long h = SubMod(q, half_gm1, p);

        // add h_g(g^j) * (g^(k-1))^j to running total
        sum = SubMod(sum, MulMod(h, g_to_km1_to_j, p, pinv), p);

        // update g_to_km1_to_j
        g_to_km1_to_j = MulModPrecon(g_to_km1_to_j, g_to_km1, p, g_to_km1_pinv);
    }

    return sum;
}
long bernsum_powg_flint2(ulong p, double pinv, ulong k, ulong g)
{
    ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
    ulong g_to_jm1 = 1;
    ulong g_to_km1 = n_powmod_precomp(g, k - 1, p, pinv);
    ulong g_to_km1_to_j = g_to_km1;
    ulong sum = 0;

    ulong g_to_km1_pinv = n_mulmod_precomp_shoup(g_to_km1, p);

    for (long j = 1; j <= (p - 1) / 2; j++)
    {
        ulong q;
        g_to_jm1 = n_divrem2_precomp(&q, g_to_jm1 * g, p, pinv);
        ulong h = n_submod(q, half_gm1, p);
        sum = n_submod(sum, n_mulmod_precomp(h, g_to_km1_to_j, p, pinv), p);
        g_to_km1_to_j = n_mulmod_shoup(g_to_km1, g_to_km1_to_j, g_to_km1_pinv, p);
    }

    return sum;
}
long bernsum_powg_flint2b(ulong p, ulong pinv, ulong k, ulong g)
{
    ulong half_gm1 = (g + ((g & 1) ? 0 : p) - 1) / 2;    // (g-1)/2 mod p
    ulong g_to_jm1 = 1;
    ulong g_to_km1 = n_powmod2_preinv(g, k - 1, p, pinv);
    ulong g_to_km1_to_j = g_to_km1;
    ulong sum = 0;

    for (long j = 1; j <= (p - 1) / 2; j++)
    {
        ulong q;
        g_to_jm1 = n_divrem2_preinv(&q, g_to_jm1 * g, p, pinv);
        ulong h = n_submod(q, half_gm1, p);
        sum = n_submod(sum, n_mulmod2_preinv(h, g_to_km1_to_j, p, pinv), p);
        g_to_km1_to_j = n_mulmod2_preinv(g_to_km1_to_j, g_to_km1, p, pinv);
    }

    return sum;
}
int main()
{
    long s;

    s = 0;
    TIMEIT_START
    s |= bernsum_powg(10007, PrepMulMod(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2(10007, n_precompute_inverse(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    s = 0;
    TIMEIT_START
    s |= bernsum_powg_flint2b(10007, n_preinvert_limb(10007), 9406, 5);
    TIMEIT_STOP
    printf("%ld\n", s);

    return 0;
}
Timings on my machine:
cpu/wall(s): 4.19e-05 4.19e-05
5444
cpu/wall(s): 6.7e-05 6.7e-05
5444
cpu/wall(s): 0.000131 0.000131
5444
So our n_precomp arithmetic is 1.6x slower than NTL's, and our n_preinvert arithmetic is 3x slower. I even cheated here: NTL has a muldivrem function that we lack, so I substituted a plain multiplication, which will of course overflow if p is large.
- Something is missing in the n_precomp arithmetic, making it slower than NTL for small-modulus arithmetic.
- We currently use n_preinvert methods almost everywhere in FLINT since they work up to 64 bits. We should investigate how to cleanly use n_precomp arithmetic (or something better) whenever the modulus is small.
- n_mulmod_shoup is not very intuitively named, by the way (and takes its arguments in a weird order).
Very surprising. I can't for the life of me imagine what we could have overlooked there; so much care went into making that efficient. Is NTL using doubles?
We could do the following, where they make sense (conversion costs/representation need to be considered):
- use a double multiply plus a fused multiply-add (FMA) to get the full 106-bit product of two 53-bit values
- vectorisation
- re-optimise our n_mulmod_precomp for modern CPUs
- use 32 bit integers
If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication:
double dmod_mul(double a, double b, double n, double ninv)
{
    /* 2^52 + 2^51: adding and then subtracting this constant rounds
       r * ninv to the nearest integer (assuming round-to-nearest mode),
       giving the approximate quotient used for the reduction. */
    double magic = 6755399441055744.0;
    double r = a * b;
    return r - ((r * ninv + magic) - magic) * n;
}
This should be good for doing lots of multiplications in parallel with SIMD. How quickly can we add in this representation?
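As for addition in the balanced representation: the sum of two reduced values lies in (-n, n), so one conditional correction moves it back into range. A sketch (dmod_add is a made-up name), assuming both inputs are already in -n/2...n/2:

```c
/* Addition in the balanced representation -n/2..n/2. The sum of two
   reduced values lies in (-n, n); one conditional add or subtract of n
   moves it back into range. With SIMD the branches become a compare
   plus a masked add/subtract. */
static double dmod_add(double a, double b, double n)
{
    double r = a + b;
    if (r > 0.5 * n)
        r -= n;
    else if (r < -0.5 * n)
        r += n;
    return r;
}
```

So an add costs one floating-point add plus one compare-and-correct, which vectorises cleanly.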
- Montgomery reduction when appropriate
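For reference, a minimal word-level Montgomery multiplication looks like the following (mont_ninv and mont_mul are made-up names, not FLINT API), assuming an odd 64-bit modulus and a compiler with __uint128_t:

```c
#include <stdint.h>

/* Newton iteration for -n^{-1} mod 2^64 (n odd). Starting from x = n,
   which is correct to 3 low bits since n^2 == 1 (mod 8), each step
   x -> x*(2 - n*x) doubles the number of correct bits: 3, 6, 12, 24,
   48, 96 >= 64 after five steps. */
static uint64_t mont_ninv(uint64_t n)
{
    uint64_t x = n;
    for (int i = 0; i < 5; i++)
        x *= 2 - n * x;      /* now x == n^{-1} mod 2^64 */
    return -x;
}

/* REDC: for a, b in Montgomery form (aR mod n, bR mod n with R = 2^64),
   returns a*b*R^{-1} mod n, i.e. the Montgomery form of the product.
   t + m*n is divisible by 2^64 by construction of m. */
static uint64_t mont_mul(uint64_t a, uint64_t b, uint64_t n, uint64_t ninv)
{
    __uint128_t t = (__uint128_t) a * b;
    uint64_t m = (uint64_t) t * ninv;                     /* t * (-n^{-1}) mod 2^64 */
    uint64_t u = (uint64_t) ((t + (__uint128_t) m * n) >> 64);
    return (u >= n) ? u - n : u;
}
```

The win is that the reduction uses only multiplies and shifts by 64 (free on the limb level), at the cost of converting in and out of Montgomery form; that makes it attractive exactly when many multiplications share one modulus.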
Another idea: in multimodular algorithms, we generally use primes of the form 2^n + c or 2^n - c where c is small. For multi-word moduli, this can certainly be exploited, but what about nmods?
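For a single-word illustration of what such primes buy: modulo p = 2^61 - 1 (taking c = 1), the high part of a product folds back in with one shift and one add, since 2^61 == 1 (mod p). A sketch (mulmod_m61 is a made-up name), assuming reduced inputs:

```c
#include <stdint.h>

/* Multiplication mod the Mersenne prime p = 2^61 - 1. The 122-bit
   product t = hi*2^61 + lo satisfies t == hi + lo (mod p), and for
   reduced inputs hi + lo < 2^62, so one conditional subtraction
   finishes the reduction -- no general division or preinverse needed. */
static uint64_t mulmod_m61(uint64_t a, uint64_t b)
{
    const uint64_t p = ((uint64_t) 1 << 61) - 1;
    __uint128_t t = (__uint128_t) a * b;
    uint64_t r = ((uint64_t) t & p) + (uint64_t) (t >> 61);
    return (r >= p) ? r - p : r;
}
```

For a general 2^n - c with small c the fold-in becomes a multiply by c instead of a plain add, which is still much cheaper than a full preinverse reduction.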
I have at least part of the answer: n_mulmod_preinv requires the inputs to be reduced and does something fast, n_mulmod2_preinv does not require the inputs to be reduced and does something slow.
Our nmod_mul is stupidly doing the same thing as n_mulmod2_preinv. We should basically just change it to do an n_mulmod_preinv instead; this is 2x faster on my machine.
We can also replace many other uses of n_mulmod2_preinv with n_mulmod_preinv throughout Flint.
Some shifts can be avoided when the modulus has exactly FLINT_BITS bits; maybe this is worth optimizing for in various places.
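The shift-free case is exactly the normalized 2-by-1 division with a precomputed reciprocal in the style of Möller and Granlund, which the preinv functions build on. A self-contained sketch (the names preinv_2by1 and mod_2by1 are made up), assuming a divisor with the top bit set and u1 < d:

```c
#include <stdint.h>

/* Reciprocal v = floor((2^128 - 1)/d) - 2^64 for a normalized divisor d
   (top bit set). The numerator rearranges exactly to
   (2^64 - 1 - d)*2^64 + (2^64 - 1), which is ~d : ~0 as a 128-bit value. */
static uint64_t preinv_2by1(uint64_t d)
{
    return (uint64_t) (((((__uint128_t) ~d) << 64) | ~(uint64_t) 0) / d);
}

/* Remainder of u1*2^64 + u0 mod d, for normalized d and u1 < d.
   The candidate quotient q1 is off by at most one in each direction,
   hence the two conditional corrections; no variable shifts appear. */
static uint64_t mod_2by1(uint64_t u1, uint64_t u0, uint64_t d, uint64_t v)
{
    __uint128_t q = (__uint128_t) v * u1 + ((((__uint128_t) u1) << 64) | u0);
    uint64_t q1 = (uint64_t) (q >> 64) + 1;
    uint64_t q0 = (uint64_t) q;
    uint64_t r = u0 - q1 * d;
    if (r > q0)       /* quotient was one too large */
        r += d;
    if (r >= d)       /* unlikely final correction */
        r -= d;
    return r;
}
```

When the modulus is stored pre-shifted with its norm (as in an nmod_t), the un-normalization shifts in the general-modulus path are pure overhead that this case avoids.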
Ditto for nmod_addmul / NMOD_ADDMUL.
If n is odd and small enough and we use a balanced representation -n/2...n/2, this does a correctly reduced multiplication
how small is small enough? This is assuming no fmadd/fmsub? What about with fmadd/fmsub?
Up to sqrt(2^53) I guess, but I did not check or prove this. You might want to design an entirely different algorithm around fma.
FWIW, FLINT does better if one uses nmod_mul and n_mulmod_shoup.
After that, the remaining difference seems to be due to n_divrem2_preinv being much slower than MulDivRem.
Several modular-arithmetic functions, such as n_mulmod2_preinv, n_mod2_preinv, n_ll_mod_preinv and n_lll_mod_preinv, don't take a norm as input and therefore need an flint_clz operation, which is redundant in situations where we already have an nmod_t containing this data. There are also probably many places (though not all) where these operations should be inlined. Replacing them with nmod_mul, NMOD2_RED2, etc. would be an improvement.
We should think about ways to redesign these interfaces so that they are more obvious.
It would be nice to separate those that need normalization from those that do not. Not sure how to do that in a user-friendly way, though.
Another point of comparison should be https://math.mit.edu/~drew/ffpoly.html. I think @andrewvsutherland (the author) did a comparison at some point...