Pagina's

2010/12/19

Binomial Coefficient, n over k, Choose(n,k)

// Another abbreviation

// BNM(5,2) = 5! / 2! / (5-2)! = 10  binomial coefficient

// A sieve of Eratosthenes is used to get the primes,
// GetBnmPrimes(30,15) = {3,3,5,17,19,23,29}, 7 non-decreasing numbers (odd primes, each repeated by its multiplicity).
// floorLog2(7) = fL2(7) = 2
// 1 << fL2(7) = 4
// BnmSpecialProduct changes the array into
// {17,3*19,3*23,5*29} = {17,57,69,145} , 4 increasing numbers.
// BnmProduct changes it into 
// {17*145,57*69} = {2465,3933}
// BnmProduct finally changes it into
// {2465*3933} = {9694845}
// The prime 2 is not part of the array and is handled separately:
// its exponent is getBnmPower2(30,15) = 4
// 9694845 * 2^4 = 9694845 << 4 = 155117520

// After a few iterations of BnmProduct, 
// the multiplicands are of equal, or almost equal, bitlength,
// a prerequisite to use Karatsuba etc. effectively.
//
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Tasks;
using Xint = System.Numerics.BigInteger;
/// <summary>
/// Computes binomial coefficients C(n, k) exactly by prime factorisation:
/// the odd prime factors (with multiplicity) are collected with a sieve,
/// multiplied together pairwise so that the operands stay of similar size,
/// and the power of two is applied last as a left shift.
/// Large products go through Karatsuba / Toom-Cook style multiplication.
/// </summary>
class Program
{
    private static Stopwatch sw = new Stopwatch();

    /// <summary>
    /// Benchmark driver: times one large binomial coefficient, prints the
    /// elapsed milliseconds and waits for Enter. The computed value is discarded.
    /// </summary>
    static void Main()
    {
        sw.Restart();
        BNM(6400000, 2133333);
        sw.Stop();
        Console.WriteLine(sw.ElapsedMilliseconds);       // BNM.exe 3880 ms
        Console.ReadLine();                              // keep the console window open
    }

    /// <summary>
    /// Returns the binomial coefficient "n over k".
    /// </summary>
    /// <param name="n">Size of the set.</param>
    /// <param name="k">Number of elements chosen.</param>
    /// <returns>
    /// C(n, k); 0 when k is negative or larger than n (this also covers negative n).
    /// </returns>
    public static Xint BNM(int n, int k)
    {
        // Out-of-range arguments: there is no way to choose, so the result is 0.
        // Without the k < 0 guard the prime list comes back empty and the
        // power-of-two rounding below produces a negative array size.
        if (k < 0 || k > n) return 0;
        if (k == 0 || n == k) return 1;
        if (k == 1 || n == k + 1) return n;

        Xint[] P = GetBnmPrimes(n, k);
        int i = P.Length;
        int j = 1 << fL2(i);                 // largest power of two not exceeding i
        if (i != j)
        {
            // Shrink the array to exactly j entries so it can be halved repeatedly.
            P = BnmSpecialProduct(P, i, j);
            i = j;
        }
        while (i > 1)
        {
            P = BnmProduct(P, i);            // each pass halves the number of factors
            i >>= 1;
        }
        return P[0] << getBnmPower2(n, k);   // finally apply the factor 2^e
    }

    /// <summary>
    /// Lists the odd prime factors of C(n, k) in non-decreasing order, each
    /// prime repeated as many times as its exponent in C(n, k).
    /// </summary>
    private static Xint[] GetBnmPrimes(int n, int k)
    {
        // Sieve over odd numbers only: index i stands for the number 2*i + 1.
        // m = ceil(n / 2), written so that it cannot overflow for n near int.MaxValue.
        int m = (n >> 1) + (n & 1);                         // Eratosthenes
        BitArray x = new BitArray(m + 1);
        int i, p = 3, q = 4;
        // p runs over the odd numbers 3, 5, 7, ...; q is the sieve index of p*p.
        // The step (p >> 1) << 2 uses the already advanced p and moves q to the
        // index of the next odd square.
        for (; q < m; q += p >> 1 << 2)
        {
            // p*p is still unmarked exactly when p is prime: a composite p has a
            // smaller prime factor that already crossed out p*p.
            if (!x[q])
            {
                x[q] = true;
                // An index step of p corresponds to a value step of 2*p (odd multiples only).
                for (i = q + p; i < m; i += p) x[i] = true;
            }
            p += 2;
        }

        List<Xint> Primes = new List<Xint>();
        int j;                        // prime power factorization binomial
        for (i = 1; i < m; i++)
        {
            if (x[i]) continue;       // composite
            p = i * 2 | 1;
            // Exponent of p in n! minus the exponents in k! and (n-k)!,
            // each obtained by summing the repeated integer quotients by p.
            q = n / p; j = q; while (q >= p) { q /= p; j += q; }
            q = k / p; j -= q; while (q >= p) { q /= p; j -= q; }
            q = (n - k) / p; j -= q; while (q >= p) { q /= p; j -= q; }
            for (q = 0; q < j; q++) Primes.Add(p);
        }
        return Primes.ToArray();
    }

    /// <summary>
    /// Returns the exponent of 2 in C(n, k): the sum of the repeated halvings
    /// of n, minus the same sums for (n - k) and for k.
    /// </summary>
    private static int getBnmPower2(int n, int k)
    {
        int m = n - k;
        int i = (n >>= 1);
        while (n > 1) i += (n >>= 1);
        while (m > 1) i -= (m >>= 1);
        while (k > 1) i -= (k >>= 1);
        return i;
    }

    /// <summary>
    /// Reduces an array of i factors to j factors (j is a power of two, j below i)
    /// while keeping the overall product unchanged.
    /// </summary>
    /// <param name="P">The i factors, in non-decreasing order.</param>
    /// <param name="i">Number of factors in <paramref name="P"/>.</param>
    /// <param name="j">Requested number of factors in the result.</param>
    private static Xint[] BnmSpecialProduct(Xint[] P, int i, int j)
    {
        Xint[] Q = new Xint[j];
        int m = i - j;          // number of surplus factors that must be merged away
        i = j;
        j -= m;                 // number of factors that are copied unchanged
        int n = 0;
        // The factors right after the m smallest ones are copied as they are.
        for (; n < j; n++)
        {
            Q[n] = P[m];
            m++;
        }
        // Each remaining (larger) factor absorbs one of the m smallest factors.
        j = 0;
        for (; n < i; n++)
        {
            Q[n] = P[m] * P[j];
            j++;
            m++;
        }
        return Q;
    }

    /// <summary>
    /// Halves the number of factors: the j-th smallest is multiplied with the
    /// j-th largest, which keeps the products of comparable size.
    /// </summary>
    /// <param name="P">Array holding at least i factors.</param>
    /// <param name="i">Number of factors to combine (even).</param>
    private static Xint[] BnmProduct(Xint[] P, int i)
    {
        int k = i >> 1;
        Xint[] Q = new Xint[k];
        for (int j = 0; j < k; j++)
        {
            i--;
            Q[j] = MTP(P[j], P[i]);
        }
        return Q;
    }

    /// <summary>
    /// Multiplies U and V, choosing the algorithm from the size of the larger operand.
    /// </summary>
    private static Xint MTP(Xint U, Xint V)
    {
        // Size estimate in bits: byte count of the larger magnitude times 8.
        return MTP(U, V, Xint.Max(Xint.Abs(U), Xint.Abs(V)).ToByteArray().Length << 3);
    }

    /// <summary>
    /// Top-level dispatcher; n is the size estimate in bits. Above 40000 bits the
    /// product is split once and its three sub-products run in parallel.
    /// </summary>
    private static Xint MTP(Xint U, Xint V, int n)
    {
        if (n <= 3000) return U * V;
        if (n <= 6000) return TC2(U, V, n);
        if (n <= 10000) return TC3(U, V, n);
        if (n <= 40000) return TC4(U, V, n);
        return TC2P(U, V, n);
    }

    /// <summary>
    /// Dispatcher used for the recursive sub-products: same thresholds as
    /// <c>MTP</c>, but it never takes the parallel path.
    /// </summary>
    private static Xint MTPr(Xint U, Xint V, int n)
    {
        if (n <= 3000) return U * V;
        if (n <= 6000) return TC2(U, V, n);
        if (n <= 10000) return TC3(U, V, n);
        return TC4(U, V, n);
    }

    /// <summary>
    /// Two-way split (Karatsuba): three sub-products of half the size.
    /// </summary>
    private static Xint TC2(Xint U1, Xint V1, int n)
    {
        n >>= 1;
        Xint Mask = (Xint.One << n) - 1;
        // Split each operand into a low part of n bits and the remaining high part.
        Xint U0 = U1 & Mask; U1 >>= n;
        Xint V0 = V1 & Mask; V1 >>= n;
        Xint P0 = MTPr(U0, V0, n);
        Xint P2 = MTPr(U1, V1, n);
        // Middle coefficient = (U0+U1)(V0+V1) - P0 - P2; recombine Horner-style.
        return ((P2 << n) + (MTPr(U0 + U1, V0 + V1, n) - (P0 + P2)) << n) + P0;
    }

    /// <summary>
    /// Three-way split: five sub-products of a third of the size, taken at the
    /// points 0, infinity, 2, -1 and 1, followed by interpolation.
    /// </summary>
    private static Xint TC3(Xint U2, Xint V2, int n)
    {
        n = (int)((long)(n) * 0x55555556 >> 32); // n /= 3;
        Xint Mask = (Xint.One << n) - 1;
        Xint U0 = U2 & Mask; U2 >>= n;
        Xint U1 = U2 & Mask; U2 >>= n;
        Xint V0 = V2 & Mask; V2 >>= n;
        Xint V1 = V2 & Mask; V2 >>= n;
        Xint W0 = MTPr(U0, V0, n);                               //  0
        Xint W4 = MTPr(U2, V2, n);                               //  inf
        Xint P3 = MTPr((((U2 << 1) + U1) << 1) + U0, (((V2 << 1) + V1 << 1)) + V0, n); //  2
        U2 += U0;
        V2 += V0;
        Xint P2 = MTPr(U2 - U1, V2 - V1, n);                     // -1
        Xint P1 = MTPr(U2 + U1, V2 + V1, n);                     //  1
        // Interpolation: recover the coefficients W1, W2, W3 from the point values.
        Xint W2 = (P1 + P2 >> 1) - (W0 + W4);
        Xint W3 = W0 - P1;
        W3 = ((W3 + P3 - P2 >> 1) + W3) / 3 - (W4 << 1);
        Xint W1 = P1 - (W4 + W3 + W2 + W0);
        // Recombine the five coefficients, each shifted n bits further than the previous one.
        return ((((W4 << n) + W3 << n) + W2 << n) + W1 << n) + W0;
    }

    /// <summary>
    /// Four-way split: seven sub-products of a quarter of the size, taken at the
    /// points 0, 1, -1, 2, -2, 4 and infinity, followed by interpolation.
    /// </summary>
    private static Xint TC4(Xint U3, Xint V3, int n)
    {
        n >>= 2;
        Xint Mask = (Xint.One << n) - 1;
        Xint U0 = U3 & Mask; U3 >>= n;
        Xint U1 = U3 & Mask; U3 >>= n;
        Xint U2 = U3 & Mask; U3 >>= n;
        Xint V0 = V3 & Mask; V3 >>= n;
        Xint V1 = V3 & Mask; V3 >>= n;
        Xint V2 = V3 & Mask; V3 >>= n;

        Xint W0 = MTPr(U0, V0, n);                               //  0
        // From here on U0/V0 accumulate the even-indexed parts and U1/V1 the odd-indexed parts.
        U0 += U2; U1 += U3;
        V0 += V2; V1 += V3;
        Xint P1 = MTPr(U0 + U1, V0 + V1, n);                     //  1
        Xint P2 = MTPr(U0 - U1, V0 - V1, n);                     // -1
        U0 += 3 * U2; U1 += 3 * U3;
        V0 += 3 * V2; V1 += 3 * V3;
        Xint P3 = MTPr(U0 + (U1 << 1), V0 + (V1 << 1), n);       //  2
        Xint P4 = MTPr(U0 - (U1 << 1), V0 - (V1 << 1), n);       // -2
        Xint P5 = MTPr(U0 + 12 * U2 + ((U1 + 12 * U3) << 2),
                       V0 + 12 * V2 + ((V1 + 12 * V3) << 2), n); //  4
        Xint W6 = MTPr(U3, V3, n);                               //  inf

        // Interpolation: recover the coefficients W1..W5 from the point values.
        Xint W1 = P1 + P2;
        Xint W4 = (((((P3 + P4) >> 1) - (W1 << 1)) / 3 + W0) >> 2) - 5 * W6;
        Xint W2 = (W1 >> 1) - (W6 + W4 + W0);
        P1 = P1 - P2;
        P4 = P4 - P3;
        Xint W5 = ((P1 >> 1) + (5 * P4 + P5 - W0 >> 4) - ((((W6 << 4) + W4) << 4) + W2)) / 45;
        W1 = ((P4 >> 2) + (P1 << 1)) / 3 + (W5 << 2);
        Xint W3 = (P1 >> 1) - (W1 + W5);
        // Recombine the seven coefficients, each shifted n bits further than the previous one.
        return ((((((W6 << n) + W5 << n) + W4 << n) + W3 << n) + W2 << n) + W1 << n) + W0;
    }

    /// <summary>
    /// Two-way split like <c>TC2</c>, but the three sub-products are computed
    /// with <c>Parallel.For</c>; each iteration writes its own array slot.
    /// </summary>
    private static Xint TC2P(Xint A, Xint B, int n)
    {
        n >>= 1;
        Xint Mask = (Xint.One << n) - 1;
        // U = { low, low + high, high }, and the same for V.
        Xint[] U = new Xint[3];
        U[0] = A & Mask; A >>= n; U[2] = A; U[1] = U[0] + A;
        Xint[] V = new Xint[3];
        V[0] = B & Mask; B >>= n; V[2] = B; V[1] = V[0] + B;
        Xint[] P = new Xint[3];
        Parallel.For(0, 3, (int i) => P[i] = MTPr(U[i], V[i], n));
        return ((P[2] << n) + P[1] - (P[0] + P[2]) << n) + P[0];
    }

    /// <summary>
    /// Floor of the base-2 logarithm of i, found with a branch tree of comparisons.
    /// Returns -1 for i below 1 and 30 for every i from 2^30 upwards.
    /// </summary>
    private static int fL2(int i)
    {
        return
        i < 1 << 15 ? i < 1 << 07 ? i < 1 << 03 ? i < 1 << 01 ? i < 1 << 00 ? -1 : 00 :
                                                                i < 1 << 02 ? 01 : 02 :
                                                  i < 1 << 05 ? i < 1 << 04 ? 03 : 04 :
                                                                i < 1 << 06 ? 05 : 06 :
                                    i < 1 << 11 ? i < 1 << 09 ? i < 1 << 08 ? 07 : 08 :
                                                                i < 1 << 10 ? 09 : 10 :
                                                  i < 1 << 13 ? i < 1 << 12 ? 11 : 12 :
                                                                i < 1 << 14 ? 13 : 14 :
                      i < 1 << 23 ? i < 1 << 19 ? i < 1 << 17 ? i < 1 << 16 ? 15 : 16 :
                                                                i < 1 << 18 ? 17 : 18 :
                                                  i < 1 << 21 ? i < 1 << 20 ? 19 : 20 :
                                                                i < 1 << 22 ? 21 : 22 :
                                    i < 1 << 27 ? i < 1 << 25 ? i < 1 << 24 ? 23 : 24 :
                                                                i < 1 << 26 ? 25 : 26 :
                                                  i < 1 << 29 ? i < 1 << 28 ? 27 : 28 :
                                                                i < 1 << 30 ? 29 : 30;
    }

    /// <summary>
    /// Bit length of the magnitude of U (0 for zero). Not called by the other
    /// methods of this class.
    /// </summary>
    private static int bL(Xint U)
    {
        byte[] bytes = Xint.Abs(U).ToByteArray();
        int i = bytes.Length - 1;
        // Eight bits for every byte below the top one, plus the bits used in the top byte.
        // Addition (not OR) so that a helper result of 8 cannot collide with bit 3 of i * 8.
        return (i << 3) + bitLengthMostSignificantByte(bytes[i]);
    }

    /// <summary>
    /// Number of significant bits in a byte: 0 for 0, up to 8 for values of 128 and above.
    /// </summary>
    private static int bitLengthMostSignificantByte(byte b)
    {
        if (b >= 128) return 8;
        return b < 08 ? b < 02 ? b < 01 ? 0 : 1 :
                                 b < 04 ? 2 : 3 :
                        b < 32 ? b < 16 ? 4 : 5 :
                                 b < 64 ? 6 : 7;
    }
}

No comments:

Post a Comment