PlaidCTF 2014 - rsa (for450)

For PlaidCTF2014, Eindbazen and fail0verflow joined forces as 0xffa, the Final Fail Alliance.
Don't miss out on other write-ups at Eindbazen's site!
rsa
Forensics (450 pts)
--------------
Our archaeologists recovered a dusty and corrupted old hard drive used by
The Plague in his trips into the past. It contains a private key, but this
has long since been lost to bitrot. Can you recover the full key from the
little information we have recovered?

You can download the recovered information here.

tl;dr

Grab the modulus N from the public key; note that e is the usual 65537. The corrupted file is openssl text output displaying the private key. Remember eprint 2008/510, grab the accompanying code, hack hack hack, get the private key out. The secret is just one plain number, which you can decrypt with that key.

The setting

The paper by Nadia Heninger and Hovav Shacham describes an algorithm to reconstruct an RSA private key, stored in the usual redundant form, from a copy in which some of the bits have been erased. The algorithm is efficient if at least 27% of the bits, erased uniformly at random, survived. In our case the damage is per nybble rather than per bit, but that is close enough.

1024-bit RSA with a public exponent of 65537, as we have here, is about 100x cheaper for computing the public operation than for computing the private operation, if done in the naive way. The private operation can be sped up about 4x by storing, in addition to the private key $(N, d)$, also $p$, $q$, $d_p$, $d_q$ and $q_p$, where $N = pq$, $d_p = d \bmod (p-1)$, similarly $d_q = d \bmod (q-1)$, and $q_p$ is the inverse of $q$ modulo $p$. Instead of computing the private operation $x^d \bmod N$ directly, we compute $x^{d_p} \bmod p$ and the analogue for $q$, and combine these two with the Chinese remainder theorem (using $q_p$). Both these smaller exponentiations are about 8x cheaper (half-length multiplies are 4x cheaper and the exponentiation chain is half as long).

This speedup is hard to ignore even if it opens you up to many more side-channel attacks, so it is used by most implementations. In the context of “cold-boot attacks”, where you read out DRAM after it has had power removed for a while, some bits of a stored private key will have degraded; this paper shows that if you can still recover 27% of the bits (and you know which), then you can reconstruct the private key.

Let’s write code! Or not

This problem is worth 450 points, so it is no surprise that implementing the algorithm is quite a bit of work. A much more serious problem is that debugging it is very hard, and it is even harder to make a useful estimate of how long that debugging would take. It doesn’t help that the paper makes a few not-so-stellar implementation choices (at least from the viewpoint of making it easy to implement); for example, it uses breadth-first search where depth-first search would do fine, and there is this “tau” optimisation. But changing this to be simpler means more unknowns, trading one kind of debugging for another.

Now, although regrettably people do not include actual program code in reports that are (partially) about programs, these days we have the interwebs. Looking around on the author’s (old) webpages reveals a link to code, and that link isn’t even dead. This program (“rsabits”) doesn’t take a degraded key as input; it takes a key and molests it itself. So we added code to input a mask of known bits (and fixed a bug that made it think some invalid results are valid). And that was all it took.

Let’s write code anyway!

After the contest I had to write this writeup, so I decided to try my hand at implementing the (simplified, optimised) algorithm anyway, to see whether my estimate that debugging it would take a long time was right. It was. Eventually I finished; the code is at the end of this post.

Let’s write some math instead

We’ll ignore $q_p$; it’s not convenient to work with. For the other quantities we have these relations (from the definition of RSA):

$$
\begin{aligned}
pq &= N\\
de &\equiv 1 \pmod{(p-1)(q-1)}\\
d_p e &\equiv 1 \pmod{p-1}\\
d_q e &\equiv 1 \pmod{q-1}
\end{aligned}
$$

It is advantageous to introduce new variables $k$, $k_p$, $k_q$, giving the integer relations we’ll work with in the following:

$$
\begin{aligned}
pq &= N\\
de &= 1 + k(p-1)(q-1)\\
d_p e &= 1 + k_p(p-1)\\
d_q e &= 1 + k_q(q-1)
\end{aligned}
$$

Now assume we are given $k$, $k_p$, $k_q$, and we know the low $b$ bits of all of $p$, $q$, $d$, $d_p$, $d_q$. Then if we guess the next bit of $p$, these relations in sequence give the next bit of the other four (this works because $p$, $q$ and $e$ are odd). So we’ll try both values for the bit of $p$, fill in the bit for the other values, and recurse, backtracking as soon as a bit of any of the quantities does not match its known value.

The actual implementation can be sped up a lot by keeping track of the left- and right-hand sides of these equations at every step, and by shifting out the bits that are already known. This does make things more complex of course :-P

Dotting the k’s

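We still need to find $k$, $k_p$, $k_q$, or at least fewer candidates for them than the roughly $2^{48}$ that brute force gives us (each of them lies between 1 and $e - 1 \approx 2^{16}$, because $d < (p-1)(q-1)$, $d_p < p-1$ and $d_q < q-1$). Then for every candidate we run the backtracking algorithm, and presto.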

Writing $b$ for the number of bits of $p$ and $q$, we have $p + q < 2^{b+1}$, so (with the first and second relation) we have

$$0 < kN - de = k(p + q - 1) - 1 < 2^{b+1}k,$$

and dividing by $e$ (using $k < e$) gives $kN/e - 2^{b+1} < d < kN/e$. This says that all but the low $b + 1$ bits of $d$ are equal to those of $\lfloor kN/e \rfloor$, except that a borrow from the low bits can change the lowest of those bits, and further ones too when the lowest of them are zeroes. We could check for those, or do as rsabits does and simply disregard a whole bunch of low bits.

Try all possible $k$ and keep those for which the top bits of $\lfloor kN/e \rfloor$ agree with the known bits of $d$. For each remaining $k$, find $k_p$ and $k_q$.

To do that, consider the second relation modulo $e$. After some shuffling it reads $p + q \equiv N + 1 + k^{-1} \pmod{e}$, so given $p \bmod e$ we find $q \bmod e$. Most of those $p, q$ don’t satisfy $pq \equiv N \pmod{e}$; substituting for $q$ gives a quadratic in $p$ modulo the prime $e$, so at most two do. With those, derive $k_p$ via the third relation, which modulo $e$ reads $k_p(p - 1) \equiv -1$ (and $k_q$ from the fourth); simply try out all possibilities and see if the equation holds. Computing a modular inverse instead is a lot more code that can go wrong and which needs non-trivial testing, and it’s not really faster (64k times essentially nothing is still essentially nothing).

And that’s it. If you want to see all the tricky details, look at the code :-)

[Blaggified by jix. Thanks!]

// Copyright 2014  Segher Boessenkool  <segher@kernel.crashing.org>
// This code is licensed to you under the terms of the GNU GPL, version 2;
// see http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt


//
// eprint 2008/510
//


#include <stdlib.h>
#include <stdio.h>


// The number of bits in p and q; this should be a multiple of 64.
#define pbits 512
#define pwords (pbits/64)

// We'll use 64-bit limbs, and 128-bit intermediates, via GCC's mode
// attribute.  Only fancy machines need apply.
typedef unsigned int u64 __attribute__((mode(DI)));
typedef unsigned int u128 __attribute__((mode(TI)));


// Low word is low, low bit is low.
static u64 n[2*pwords];

static u64 e = 65537;

static struct {
        u64 p[2*pwords];
        u64 q[2*pwords];
        u64 d[2*pwords];
        u64 d_p[2*pwords];
        u64 d_q[2*pwords];
} dat, known, mask;

static u64 k, k_p, k_q;


// We build up p bit by bit.  We then track:
static struct {
        u64 nn[pwords+1];       // p*q div 2**bit
        u64 d1[pwords+1];       // 1+k(p-1)(q-1) div 2**bit
        u64 dp1;                // 1+k_p(p-1) div 2**bit
        u64 dq1;                // 1+k_q(q-1) div 2**bit
        u64 d2;                 // d*e div 2**bit
        u64 dp2;                // d_p*e div 2**bit
        u64 dq2;                // d_q*e div 2**bit
} search;

static u64 bit;


// Some bignum utility things.

static u64 getbit(u64 *x, u64 n)
{
        return (x[n/64] >> (n%64)) & 1;
}

static void setbit(u64 *x, u64 n)
{
        x[n/64] |= (u64)1 << (n%64);
}

static void clrbit(u64 *x, u64 n)
{
        x[n/64] &= ~((u64)1 << (n%64));
}

// d has an extra limb.
static void bn_add(u64 *d, u64 *a, u64 n)
{
        u64 j;

        u128 c = 0;
        for (j = 0; j < n; j++) {
                c = c + d[j] + a[j];
                d[j] = c;
                c >>= 64;
        }

        d[n] += c;
}

// d -= a; d has an extra limb that absorbs the borrow.
static void bn_sub(u64 *d, u64 *a, u64 n)
{
        u64 j;

        u128 c = 1;
        for (j = 0; j < n; j++) {
                c = c + d[j] + ~a[j];
                d[j] = c;
                c >>= 64;
        }

        d[n] += c - 1;
}

// d *= 2 over n limbs; the top bit is dropped (only used to undo bn_halve).
static void bn_double(u64 *d, u64 n)
{
        u64 j;

        u128 c = 0;
        for (j = 0; j < n; j++) {
                c = c + d[j] + d[j];
                d[j] = c;
                c >>= 64;
        }
}

// d /= 2 over n limbs (logical shift right by one bit).
static void bn_halve(u64 *d, u64 n)
{
        u64 j;

        u128 c = 0;
        for (j = n - 1; j < n; j--) {
                c = (c << 64) + d[j];
                d[j] = c >> 1;
        }
}

// d += k*(a-1); a is odd, d has an extra limb.
static void bn_update_de(u64 *d, u64 *a, u64 k, u64 n)
{
        u64 j;

        u128 c = -(u128)k;
        for (j = 0; j < n; j++) {
                c = c + d[j] + (u128)k*a[j];
                d[j] = c;
                c >>= 64;
        }

        d[n] += c;
}

// d -= k*(a-1); a is odd, d has an extra limb.
static void bn_update_de_m(u64 *d, u64 *a, u64 k, u64 n)
{
        u64 j;

        u128 c = 2*k;
        for (j = 0; j < n; j++) {
                c = c + d[j] + (u128)k*~a[j];
                d[j] = c;
                c >>= 64;
        }

        d[n] += c - k;
}

static void bn_print(u64 *d, u64 n)
{
        u64 j, k;

        for (j = n - 1; j < n; j--) {
                u64 x = d[j];

                for (k = 0; k < 16; k++) {
                        fputc("0123456789abcdef"[x >> 60], stdout);
                        x <<= 4;
                }
        }

        fputc('\n', stdout);
}

// This is where the real work happens.
static void try(void)
{
        // Set the bit of d, d_p, d_q if it is needed for its equation
        // to hold.
        if ((search.d1[0] ^ search.d2) & 1) {
                setbit(dat.d, bit);
                search.d2 += e;
        }

        if ((search.dp1 ^ search.dp2) & 1) {
                setbit(dat.d_p, bit);
                search.dp2 += e;
        }

        if ((search.dq1 ^ search.dq2) & 1) {
                setbit(dat.d_q, bit);
                search.dq2 += e;
        }


        // If any bit is prescribed and we don't match that value, prune.
        if (getbit(mask.p, bit)
        && getbit(dat.p, bit) != getbit(known.p, bit))
                goto out;

        if (getbit(mask.q, bit)
        && getbit(dat.q, bit) != getbit(known.q, bit))
                goto out;

        if (getbit(mask.d, bit)
        && getbit(dat.d, bit) != getbit(known.d, bit))
                goto out;

        if (getbit(mask.d_p, bit)
        && getbit(dat.d_p, bit) != getbit(known.d_p, bit))
                goto out;

        if (getbit(mask.d_q, bit)
        && getbit(dat.d_q, bit) != getbit(known.d_q, bit))
                goto out;


        // Save our low bits, we'll need them back when backtracking.
        u64 nn_low = search.nn[0] & 1;
        u64 d_low = search.d2 & 1;
        u64 dp_low = search.dp2 & 1;
        u64 dq_low = search.dq2 & 1;


        bit++;


        // Shift the low bits out.
        bn_halve(search.nn, pwords+1);
        bn_halve(search.d1, pwords+1);
        search.d2 /= 2;
        search.dp1 /= 2;
        search.dp2 /= 2;
        search.dq1 /= 2;
        search.dq2 /= 2;


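        // Oh happy days?  All 2*pbits low bits agree; the relations hold
        // exactly only if whatever is left above those bits agrees too.
        if (bit == 2*pbits) {
                u64 j, rest = (search.d1[0] ^ search.d2)
                            | (search.dp1 ^ search.dp2)
                            | (search.dq1 ^ search.dq2);

                for (j = 0; j < pwords+1; j++)
                        rest |= search.nn[j] | (j ? search.d1[j] : 0);

                if (rest == 0) {
                        printf("=== FOUND IT ===\n");

                        bn_print(n, 2*pwords);
                        bn_print(dat.p, pwords);
                        bn_print(dat.q, pwords);
                        bn_print(dat.d, 2*pwords);
                        bn_print(dat.d_p, pwords);
                        bn_print(dat.d_q, pwords);
                }

                goto found;
        }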


        // Should the current bits of p and q be equal?  This holds exactly
        // then if the bit of the n we build up matches the real n.
        int pqequal = ((search.nn[0] & 1) == getbit(n, bit));

        // p=0 q=0, recurse.
        if (pqequal)
                try();

        // p=1 q=0, recurse.
        setbit(dat.p, bit);
        bn_add(search.nn, dat.q, pwords);
        bn_update_de(search.d1, dat.q, k, pwords);
        search.dp1 += k_p;

        if (!pqequal)
                try();

        // p=1 q=1, recurse.
        setbit(dat.q, bit);
        bn_add(search.nn, dat.p, pwords);
        bn_update_de(search.d1, dat.p, k, pwords);
        search.dq1 += k_q;

        if (pqequal)
                try();

        // p=0 q=1, recurse.
        clrbit(dat.p, bit);
        bn_sub(search.nn, dat.q, pwords);
        bn_update_de_m(search.d1, dat.q, k, pwords);
        search.dp1 -= k_p;

        if (!pqequal)
                try();

        // And finally restore to pristine state.
        clrbit(dat.q, bit);
        bn_sub(search.nn, dat.p, pwords);
        bn_update_de_m(search.d1, dat.p, k, pwords);
        search.dq1 -= k_q;


found: ;
        // Put the low bits back.
        bn_double(search.nn, pwords+1);
        search.nn[0] += nn_low;
        bn_double(search.d1, pwords+1);
        search.d1[0] += d_low;
        search.d2 = 2*search.d2 + d_low;
        search.dp1 = 2*search.dp1 + dp_low;
        search.dp2 = 2*search.dp2 + dp_low;
        search.dq1 = 2*search.dq1 + dq_low;
        search.dq2 = 2*search.dq2 + dq_low;


        bit--;


out: ;
        // Restore d, d_p, d_q.
        if (getbit(dat.d, bit)) {
                clrbit(dat.d, bit);
                search.d2 -= e;
        }

        if (getbit(dat.d_p, bit)) {
                clrbit(dat.d_p, bit);
                search.dp2 -= e;
        }

        if (getbit(dat.d_q, bit)) {
                clrbit(dat.d_q, bit);
                search.dq2 -= e;
        }

        return;
}


// Try everything for a fixed k.
static void try_k(void)
{
        // Compute k/e*n.  Let's use the dumb and simple way, with
        // a big fat division.
        u64 d[2*pwords];
        u64 j;

        u128 c = 0;
        for (j = 0; j < 2*pwords; j++) {
                c += (u128)k*n[j];
                d[j] = c;
                c >>= 64;
        }
        for (j = 2*pwords -  1; j < 2*pwords; j--) {
                c = (c << 64) + d[j];
                d[j] = c / e;
                c %= e;
        }

        // This should agree with d in the top half minus one of the bits,
        // except when carries happen.  Leave a full word of slack, that
        // should be enough for all practical purposes.  We *could* also
        // detect which bits can carry and not compare them.
        for (j = pwords + 1; j < 2*pwords; j++)
                if ((d[j] ^ known.d[j]) & mask.d[j])
                        return;


        // Our k is good.  Now find p, q, k_p, k_q mod e.

        // Find k^{-1} mod e.
        u64 kinv;
        for (kinv = 1; kinv < e; kinv++)
                if (kinv * k % e == 1)
                        break;

        // Find n mod e.  This uses the fact that 2^64 = 1 mod e.
        u64 ne = 0;
        for (j = 0; j < 2*pwords; j++)
                ne += (n[j] % e);
        ne %= e;

        // de = 1 + k(p-1)(q-1)  so  p + q = n + 1 + k^{-1} mod e.
        // Loop over all such pairs p,q mod e.
        u64 pe, qe = (kinv + ne) % e;
        for (pe = 1; pe < e; pe++) {
                // Does that satisfy pq = n?
                if ((pe*qe) % e == ne) {
                        // Find k_p, k_q.
                        // d_p e = 1 + k_p(p-1)  so  k_p p = k_p - 1 mod e.
                        for (k_p = 1; k_p < e; k_p++)
                                if ((k_p*pe) % e == k_p - 1)
                                        break;

                        for (k_q = 1; k_q < e; k_q++)
                                if ((k_q*qe) % e == k_q - 1)
                                        break;

//printf("k=%d k_p=%d k_q=%d\n", (int)k, (int)k_p, (int)k_q);

                        // Both p and q start at 1.
                        dat.p[0] = dat.q[0] = 1;
                        search.nn[0] = search.d1[0] = search.dp1 = search.dq1 = 1;

                        // Work!
                        try();
                }

                qe = (qe + e - 1) % e;
        }
}


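// All I/O is ugly.  Reads exactly 16*nwords lowercase hex digits, most
// significant first, followed by a newline; anything else is fatal.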
static void read_number(u64 *num, u64 nwords)
{
        u64 j, k;
        for (k = nwords - 1; k < nwords; k--) {
                for (j = 0; j < 16; j++) {
                        u64 c = fgetc(stdin);
                        if (c >= '0' && c <= '9')
                                num[k] = 16*num[k] + c - '0';
                        else if (c >= 'a' && c <= 'f')
                                num[k] = 16*num[k] + c - 'a' + 10;
                        else
                                goto fail;
                }
        }
        if (fgetc(stdin) != '\n')
fail:
        {
                fprintf(stderr, "bad input, you lose.\n");
                exit(1);
        }
}

int main(void)
{
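        // Input, one number per line: n; then the known bits of p, q, d,
        // d_p, d_q (erased bits zero); then the masks for those same five
        // (bit set = bit known).
        read_number(n, 2*pwords);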

        read_number(known.p, pwords);
        read_number(known.q, pwords);
        read_number(known.d, 2*pwords);
        read_number(known.d_p, pwords);
        read_number(known.d_q, pwords);

        read_number(mask.p, pwords);
        read_number(mask.q, pwords);
        read_number(mask.d, 2*pwords);
        read_number(mask.d_p, pwords);
        read_number(mask.d_q, pwords);

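        u64 j;

        // p, q, d_p and d_q have only pbits bits: mark everything above
        // that as known, with value zero.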
        for (j = pwords; j < 2*pwords; j++) {
                mask.p[j]--;
                mask.q[j]--;
                mask.d_p[j]--;
                mask.d_q[j]--;
        }


        for (k = 1; k < e; k++)
                try_k();


        return 0;
}