Jay Bosamiya Software Security Researcher

RSA Chained (Dragon CTF Teaser 2019)

In this challenge, we need to recover a message that is encrypted under 4 different RSA keys, while knowing some of the bits of the private keys. In particular, we are given code that generates 4 different RSA keys (of ~2100 bits each), permutes them, encrypts the flag with each of them in succession, and then gives us the encrypted flag. Additionally, we are given the moduli of the keys, as well as the lower half (i.e., the least significant 1050 bits) of each private key. These are given to us in output.txt.
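To make the "encrypt with each key in succession" setup concrete, here is a toy sketch of chained RSA in Python 3.8+. The small primes and the message are made up for illustration. The toy moduli are square-free, unlike the challenge's p*q*q key. We also assume the chain goes from the smallest to the largest modulus, so that each intermediate ciphertext fits under the next modulus. Decryption undoes the encryptions in reverse order.

```python
from math import prod

E = 1667  # same public exponent as the challenge
# Illustrative toy factorizations: distinct small primes, so every modulus is square-free.
FACTORIZATIONS = [(1009, 1013), (1019, 1021), (1031, 1033, 1039), (1049, 1051, 1061)]


def make_key(primes, e):
    """Return (n, d) for the square-free modulus n = prod(primes)."""
    n = prod(primes)
    phi = prod(p - 1 for p in primes)
    return n, pow(e, -1, phi)


def chain_encrypt(m, keys, e):
    for n, _ in keys:  # encrypt under each key in turn
        m = pow(m, e, n)
    return m


def chain_decrypt(c, keys):
    for n, d in reversed(keys):  # undo the last encryption first
        c = pow(c, d, n)
    return c


keys = sorted(make_key(f, E) for f in FACTORIZATIONS)  # ascending moduli (assumption)
message = 424242  # must be smaller than the first (smallest) modulus
ciphertext = chain_encrypt(message, keys, E)
print(chain_decrypt(ciphertext, keys) == message)  # → True
```

Decrypting in any other order would apply the wrong private exponent to each intermediate value, which is why the order of the keys matters later on.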

The only other info that is given to us is:

Keys are generated in the standard way with the default setup...

Let's first sketch out how the keys are generated. We name everything using SSA-style (i.e., give new names each time we reassign a variable to a new value) so that we can better keep track of things:

| RSA modulus | Factors | Size (bits) |
|---|---|---|
| n1 | p1 * q1 | 700 + 1400 |
| n2 | p2 * q2 * r2 | 700 + 700 + 700 |
| n3 | p3 * q3 * r2 | 700 + 700 + 700 |
| n4 | p4 * q4 * q4 | 700 + 700 + 700 |

Notice how we have a repeated r2? This will help us soon.

Another thing to note is that output.txt does not list the moduli in the order n1, n2, n3, n4; it lists them in sorted order.

Let's call the moduli we get ns[0] .. ns[3], to keep them separate from n1, ..., n4. Also, let's call the partial private keys nerf_ds[0] .. nerf_ds[3], since we only have part of each.

import itertools
import gmpy2

with open('./output.txt') as f:
    data = f.read().strip().split('\n')

keys = map(lambda x: map(int, x.split(' ')[1:]), data[0:4])
enc_flag = int(data[4].split(' ')[2])

ns = map(lambda x: x[0], keys)
nerf_ds = map(lambda x: x[1], keys)

Recall that there was a repeated r2. We can recover it, and also figure out part of the permutation, simply by computing the pairwise GCDs of the moduli.

for a, b in itertools.product(ns, ns):
    if a >= b:
        continue
    gcd = gmpy2.gcd(a, b)
    if gcd > 1:
        r2 = gcd
        print str(r2)[:30] + '...'

n2n3 = []
for i, a in enumerate(ns):
    if a % r2 == 0:
        n2n3.append(a)
        print i
326199725504488859529927636344...
0
1

Now we know that ns[0] and ns[1] are n2 and n3 (we can, for now, just arbitrarily pick the order – we'll fix it later if necessary).

This also tells us that ns[2] and ns[3] are n1 and n4, although again we don't know which is which.


Now let's get to the nerfed private keys.

Since we have the lower half of the private keys, we should be able to get something, right?

Turns out, there is a nice attack when you know the LSBs of the private key: Coppersmith Attack!

Let's re-derive it from scratch, first for the p*q case (i.e., n1), since we'll need to understand its derivation to be able to use it for n2, n3 and n4.

$$
\begin{aligned}
ed &\equiv 1 \pmod{\varphi} \\
ed - 1 &= k\varphi \\
ed - 1 &= k(p-1)(q-1) \\
ed - 1 &= k(pq - p - q + 1) \\
ed - 1 &= k(N - p - q + 1) \\
edp - p &= kp(N - p - q + 1) \\
edp - kp(N - p + 1) + kN &= p \\
edp - kp(N - p + 1) + kN &\equiv p \pmod{M} \\
kp^2 + (de - kN - k - 1)p + kN &\equiv 0 \pmod{M}
\end{aligned}
$$

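In the above, M means 2^1050, reflecting the fact that we only know the lower 1050 bits of the private key. Since the final congruence is modulo M, and d agrees with its known lower half modulo M, we can substitute that lower half for d. So we know every coefficient of this quadratic in p once we pick a value for k. Moreover, p is only 700 bits long, so p < M, and the root modulo M is p itself rather than just its lower bits.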
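A quick numeric sanity check of this congruence may help. The sketch below is in Python 3.8+, unlike the Python 2 code in the rest of this writeup. It uses tiny primes chosen purely for illustration, and M = 2^15 in place of 2^1050. It checks two things: p is an exact root when the full d is used, and it is still a root modulo M when only the low bits of d are used.

```python
def pq_poly(x, N, d, e, k):
    """Evaluate k*x^2 + (d*e - k*N - k - 1)*x + k*N."""
    return k * x * x + (d * e - k * N - k - 1) * x + k * N


# Illustrative toy values: p < M, like the 700-bit p1 versus M = 2^1050.
p, q, e = 1009, 1000003, 1667
N = p * q
phi = (p - 1) * (q - 1)
d = pow(e, -1, phi)          # full private exponent
k = (e * d - 1) // phi       # the k in e*d - 1 = k*phi
M = 2 ** 15                  # pretend we only know d mod M
d_low = d % M

exact = pq_poly(p, N, d, e, k)            # zero over the integers
modular = pq_poly(p, N, d_low, e, k) % M  # zero modulo M using only d_low
print(exact, modular, 0 < k < e)          # → 0 0 True
```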

Since we know that 0 < d < phi, we know that 0 < k <= e. Conveniently, e is only 1667, which means we only need to try 1667 such quadratic equations.
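Spelled out, this uses e*d - 1 = k*phi together with 1 <= d < phi (so e*d - 1 > 0 for e > 1):

$$
1 \le k = \frac{ed - 1}{\varphi} < \frac{e\,\varphi}{\varphi} = e
$$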

One small setback: we need to solve the quadratic equation modulo M. We could use Sage's solve_mod, but for some reason I wasn't able to get it to work well (not entirely sure why), so I decided to implement it myself. What we need for this is Hensel's lemma, or more precisely the lifting idea behind it. If we know all the roots of a single-variable polynomial f modulo p^(k-1) for some prime p, then every root modulo p^k has the form r + i*p^(k-1), where r is one of those roots and 0 <= i < p. So we can find all roots modulo the next power by trying just p candidates per known root. Hensel's lemma lets us do more (for instance, a root lifts uniquely when f'(r) is not divisible by p), but that requires more specific conditions that we don't really have, so let's simply use this fact for now. I ended up implementing a fairly simple higher-order function that computes these roots in the most basic way I could think of:

# Compute roots of f modulo p^k
def roots_by_hensel_lift(f, p, k):
    assert k >= 1

    def M(x):
        return x % (p ** k)

    results = set()
    if k == 1:
        for i in xrange(0, p):
            if M(f(i)) == 0:
                results.add(i)
        return results
    assert k > 1
    for r in roots_by_hensel_lift(f, p, k - 1):
        for i in xrange(0, p):
            if M(f(r + ((p ** (k - 1)) * i))) == 0:
                results.add(r + ((p ** (k - 1)) * i))
    return results

Unfortunately, the recursion goes k levels deep and we need k = 1050. That exceeds Python's default recursion limit of 1000, so let's raise it:

from sys import setrecursionlimit
setrecursionlimit(10000)
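If you'd rather not touch the recursion limit, the same brute-force lifting can be written as a loop. This is a sketch in Python 3. roots_mod_prime_power is a hypothetical name, not something from the original solution.

```python
def roots_mod_prime_power(f, p, k):
    """Return the set of all x in [0, p**k) with f(x) == 0 (mod p**k).

    Every root modulo p**j reduces to a root modulo p**(j-1), so trying
    r + i*p**(j-1) for each known root r and each i in range(p) finds them all.
    """
    roots = {x for x in range(p) if f(x) % p == 0}
    modulus = p
    for _ in range(1, k):
        next_modulus = modulus * p
        roots = {
            r + i * modulus
            for r in roots
            for i in range(p)
            if f(r + i * modulus) % next_modulus == 0
        }
        modulus = next_modulus
    return roots


# Usage: the solutions of x^2 = 1 (mod 2^5).
print(sorted(roots_mod_prime_power(lambda x: x * x - 1, 2, 5)))  # → [1, 15, 17, 31]
```

The behaviour matches the recursive version: the set after iteration j holds exactly the roots modulo p**(j+1).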

Now we can easily go over each of the keys that we have, and quickly factor them, as long as they are of the form p * q (i.e., n1). Since this must be either ns[2] or ns[3], let's only run it on these two.

def factor_pq(N, partial_d):
    e = 1667
    for k in range(1, e + 1):
        # Coefficients of k*p^2 + (d*e - k*N - k - 1)*p + k*N = 0 (mod 2^1050)
        A = k
        B = (partial_d * e) - k * N - k - 1
        C = k * N

        def f(x):
            return (A * x * x) + (B * x) + C

        solns = roots_by_hensel_lift(f, 2, 1050)

        for possible_ans in solns:
            if N % possible_ans == 0:
                return possible_ans

p1 = factor_pq(ns[2], nerf_ds[2])
if p1 is not None:
    print "Found n1 in ns[2]"
    n1 = ns[2]
else:
    p1 = factor_pq(ns[3], nerf_ds[3])
    assert p1 is not None
    print "Found n1 in ns[3]"
    n1 = ns[3]
print "p1 =", str(p1)[:30] + '...'
assert n1 % p1 == 0
assert n1 != p1
q1 = n1 / p1
print "q1 =", str(q1)[:30] + '...'
Found n1 in ns[3]
p1 = 188689169745401648234984799686...
q1 = 224327615705283723685555998413...

It takes a minute or so to find the above answer. We have now factored ns[3], which tells us it is n1, and by elimination ns[2] must be n4.
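As a sanity check of the whole p*q attack at toy scale, here is a sketch in Python 3.8+. The primes are illustrative assumptions, and so is M = 2^15 standing in for 2^1050. e = 1667 is reused from the challenge. As in the real challenge, p is smaller than M while q is not. So the only nontrivial divisor of N that can show up as a root below M is p.

```python
def lift_roots(f, bits):
    """All x in [0, 2**bits) with f(x) == 0 (mod 2**bits), lifted one bit at a time."""
    roots = {x for x in (0, 1) if f(x) % 2 == 0}
    for j in range(1, bits):
        step = 2 ** j
        roots = {r + b * step for r in roots for b in (0, 1)
                 if f(r + b * step) % (2 * step) == 0}
    return roots


def toy_factor_pq(N, d_low, e, bits):
    """Try every k in 1..e and return a nontrivial divisor of N found as a root."""
    for k in range(1, e + 1):
        B = d_low * e - k * N - k - 1

        def f(x):
            return k * x * x + B * x + k * N

        for x in lift_roots(f, bits):
            if 1 < x < N and N % x == 0:
                return x
    return None


p, q, e, bits = 1009, 1000003, 1667, 15
N = p * q
d = pow(e, -1, (p - 1) * (q - 1))
d_low = d % 2 ** bits                  # the only part of d the attacker sees
recovered = toy_factor_pq(N, d_low, e, bits)
print(recovered)                       # → 1009
```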

Let's tackle n4 now, shall we?


Since n4 is of the form p*q*q, we need to re-derive the attack using the same technique as before (Coppersmith's attack). The only difference is that we aim for an equation in q, simply because it is easier. This equation turns out to be cubic, but that's fine for us, since the Hensel-style lifting works for polynomials of any degree. And since q4 is only 700 bits long, q4 < M, so once again the root modulo M is q4 itself.

$$
\begin{aligned}
ed &\equiv 1 \pmod{\varphi} \\
ed - 1 &= k\varphi \\
ed - 1 &= k(p-1)(q-1)q \\
ed - 1 &= k(pq^2 - q^2 - pq + q) \\
ed - 1 &= k(N - q^2 - pq + q) \\
edq - q &= kq(N - q^2 - pq + q) \\
edq - kq(N - q^2 + q) + kN &= q \\
edq - kq(N - q^2 + q) + kN &\equiv q \pmod{M} \\
kq^3 - kq^2 + (de - kN - 1)q + kN &\equiv 0 \pmod{M}
\end{aligned}
$$
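The same kind of toy check works for the cubic. This is a Python 3.8+ sketch with illustrative small primes, and M = 2^15 in place of 2^1050 (both assumptions). It confirms that q is a root, exactly with the full d and modulo M with only the low bits of d.

```python
def pqq_poly(x, N, d, e, k):
    """Evaluate k*x^3 - k*x^2 + (d*e - k*N - 1)*x + k*N."""
    return k * x**3 - k * x**2 + (d * e - k * N - 1) * x + k * N


p, q, e, M = 1009, 1013, 1667, 2 ** 15
N = p * q * q
phi = (p - 1) * (q - 1) * q   # Euler's totient of p*q^2
d = pow(e, -1, phi)
k = (e * d - 1) // phi

exact = pqq_poly(q, N, d, e, k)
modular = pqq_poly(q, N, d % M, e, k) % M
print(exact, modular)  # → 0 0
```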

Now all we've got to do is dump this equation into the kind of code we used before, using the handy higher-order function roots_by_hensel_lift that we wrote earlier:

def factor_pqq(N, partial_d):
    e = 1667
    for k in range(1, e + 1):
        # Coefficients of k*q^3 - k*q^2 + (d*e - k*N - 1)*q + k*N = 0 (mod 2^1050)
        A = k
        B = -k
        C = (partial_d * e) - k * N - 1
        D = k * N

        def f(x):
            return (A * x * x * x) + (B * x * x) + (C * x) + D

        solns = roots_by_hensel_lift(f, 2, 1050)

        for possible_ans in solns:
            if N % possible_ans == 0:
                return possible_ans

n4 = ns[2]
q4 = factor_pqq(ns[2], nerf_ds[2])
assert q4 is not None
assert n4 % q4 == 0
assert n4 % (q4 * q4) == 0
p4 = n4 / (q4 * q4)
print "p4 =", str(p4)[:30] + '...'
print "q4 =", str(q4)[:30] + '...'
p4 = 220432465124408995064591975926...
q4 = 267307309343866797026967908679...

This too takes about a minute to compute, after which we know the factorization of n4 as well.


Now all that's left is to do the same for n2 and n3. Recall that these are of the form p * q * r, but we already know r: it is the shared factor r2 we computed earlier.

Let's just re-derive a similar attack as before. We reduce the final equation down to a quadratic equation in p.

$$
\begin{aligned}
ed &\equiv 1 \pmod{\varphi} \\
ed - 1 &= k\varphi \\
ed - 1 &= k(p-1)(q-1)(r-1) \\
ed - 1 &= k(pqr - pq - pr - qr + p + q + r - 1) \\
ed - 1 &= k(N - pq - pr - qr + p + q + r - 1) \\
edpr - pr &= kpr(N - pq - pr - qr + p + q + r - 1) \\
(kr^2 - kr)p^2 + (kN - r + der + kr - kNr - kr^2)p + (kNr - kN) &= 0 \\
(kr^2 - kr)p^2 + (kN - r + der + kr - kNr - kr^2)p + (kNr - kN) &\equiv 0 \pmod{M}
\end{aligned}
$$

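To get the second-to-last line, expand the right-hand side and replace every product p*q*r by N (for example, pr*pq = p*N, pr*qr = r*N and pr*q = N). Then move everything to one side and group by powers of p. Since we know r, it is fine for it to appear in the coefficients, and p is again only 700 bits, so the root modulo M is p itself. Let's feed this into similar code as before, again reusing our handy roots_by_hensel_lift.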
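A toy check of this quadratic, including the p multiplying the middle coefficient, follows. It is a Python 3.8+ sketch with illustrative small primes and M = 2^15 standing in for 2^1050 (both assumptions).

```python
def pqr_poly(x, N, d, e, k, r):
    """Evaluate (k*r^2 - k*r)*x^2 + (k*N - r + d*e*r + k*r - k*N*r - k*r^2)*x + (k*N*r - k*N)."""
    A = k * r * r - k * r
    B = k * N - r + d * e * r + k * r - k * N * r - k * r * r
    C = k * N * r - k * N
    return A * x * x + B * x + C


p, q, r, e, M = 1009, 1013, 1019, 1667, 2 ** 15
N = p * q * r
phi = (p - 1) * (q - 1) * (r - 1)
d = pow(e, -1, phi)
k = (e * d - 1) // phi

exact = pqr_poly(p, N, d, e, k, r)
modular = pqr_poly(p, N, d % M, e, k, r) % M
print(exact, modular)  # → 0 0
```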

Recall that we don't know the order for which one is n2 and which one is n3. Let's just stick to one order and fix it later if it turns out to be an issue.

def factor_pqr(N, partial_d, r):
    e = 1667
    for k in range(1, e + 1):
        # Coefficients of the quadratic in p derived above (mod 2^1050)
        A = k*r*r - k*r
        B = k*N - r + partial_d*e*r + k*r - k*N*r - k*r*r
        C = k*N*r - k*N

        def f(x):
            return (A * x * x) + (B * x) + C

        solns = roots_by_hensel_lift(f, 2, 1050)

        for possible_ans in solns:
            if N % possible_ans == 0:
                return possible_ans


n2 = ns[0]
assert n2 % r2 == 0
p2 = factor_pqr(ns[0], nerf_ds[0], r2)
assert n2 % (p2 * r2) == 0
q2 = n2 / (p2 * r2)

n3 = ns[1]
r3 = r2
assert n3 % r3 == 0
p3 = factor_pqr(ns[1], nerf_ds[1], r3)
assert n3 % (p3 * r3) == 0
q3 = n3 / (p3 * r3)

print "p2 =", str(p2)[:30] + '...'
print "q2 =", str(q2)[:30] + '...'
print "r2 =", str(r2)[:30] + '...'
print "p3 =", str(p3)[:30] + '...'
print "q3 =", str(q3)[:30] + '...'
print "r3 =", str(r3)[:30] + '...'
p2 = 902985578846825776692383207600...
q2 = 291668652611471250039066078554...
r2 = 326199725504488859529927636344...
p3 = 142270506848638924547091203976...
q3 = 282595361018796512312481928903...
r3 = 326199725504488859529927636344...

This takes a bit longer, but it works, and now we have the factorizations of all four moduli. Let's compute all the private keys.

e = 1667
d1 = gmpy2.powmod(e, -1, (p1-1)*(q1-1))
d2 = gmpy2.powmod(e, -1, (p2-1)*(q2-1)*(r2-1))
d3 = gmpy2.powmod(e, -1, (p3-1)*(q3-1)*(r3-1))
d4 = gmpy2.powmod(e, -1, (p4-1)*(q4-1)*q4)
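The only unusual totient here is the one for n4 = p4*q4*q4, namely (p-1)*(q-1)*q. A toy Python 3.8+ check follows, with illustrative primes. It confirms that the inverse of e modulo this totient decrypts correctly. The check assumes a message coprime to N; for p*q^2 moduli, messages divisible by q need not round-trip.

```python
p, q, e = 1009, 1013, 1667
N = p * q * q
d = pow(e, -1, (p - 1) * (q - 1) * q)   # same formula as d4 above

m = 2 ** 20                 # coprime to N because p and q are odd
c = pow(m, e, N)
print(pow(c, d, N) == m)    # → True
```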

We also know that the order of the encryptions is n2, n3, n4, n1 (i.e., ns[0], ns[1], ns[2], ns[3], the same order as in output.txt). This means we should decrypt in the opposite order:

flag = enc_flag
flag = pow(flag, d1, n1)
flag = pow(flag, d4, n4)
flag = pow(flag, d3, n3)
flag = pow(flag, d2, n2)
print hex(flag)[:30] + '...'
0x4472676e537b77335f6669583364...

That looks promising. Let's just dump the flag now.

print hex(flag).replace('0x', '').decode('hex')
DrgnS{w3_fiX3d_that_f0r_y0U}

And there we go!
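The hex/decode idiom above only works on Python 2, since Python 3's str has no decode('hex'). A Python 3 sketch of the same conversion is below. int_to_bytes is a hypothetical helper name. int.to_bytes also sidesteps the odd-length hex strings that decode('hex') rejects.

```python
def int_to_bytes(x):
    """Big-endian byte string of a non-negative integer."""
    return x.to_bytes((x.bit_length() + 7) // 8, 'big')


print(int_to_bytes(0x4472676e53))  # → b'DrgnS'
```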


Overall, this was a super fun challenge to solve. Personally, I learnt quite a lot, and I feel I now have a deeper understanding of both the Coppersmith attack and Hensel's lemma.

For those wondering, during the CTF, we didn't have such a clean solution. n1 was solved by a custom solver written in Python by me. n2 and n3 were solved by a custom solver written in Sage by @ubuntor (since solve_mod didn't work in Sage for some reason). The only part that matches the above code is the solution for n4 which I wrote last after understanding Hensel's lemma much better. Now that I had the clean version for n4 though, I couldn't resist cleaning the rest of the solution up to reuse the nice roots_by_hensel_lift function, and thus this writeup :)