본문 바로가기

CTF

[*CTF Writeup] Crypto - MyEnc

문제 설명

from Crypto.Util.number import getPrime,bytes_to_long
import time,urandom
from flag import flag
iv=bytes_to_long(urandom(256))
assert len(flag)==15
keystream=bin(int(flag.encode('hex'),16))[2:].rjust(8*len(flag),'0')
p=getPrime(1024)
q=getPrime(1024)
n=p*q
print "n:",n
cnt=0
while True:
    try:
        print 'give me a number:'
        m=int(raw_input())
    except:
        break
    ct=iv
    for i in range(1,8):
        if keystream[cnt]=='1':
            ct+=pow(m^q,i**i**i,n)
            ct%=n
        cnt=(cnt+1)%len(keystream)
    print "done:",ct

n=p*q가 있고, 120bit flag에 stream 형식으로 쿼리를 날릴 수 있다. 구체적으로는 쿼리로 수 m을 주면, 프로그램 시작할때 정해지는 initial vector에 7비트씩 끊어서 i번째 비트가 1인 경우에만 m xor q의 i ^ i ^ i 제곱을 더한다(mod n). 실험해보니 쿼리는 대강 몇백번정도 날릴 수 있었다.

문제 풀이

핵심 아이디어는 i ^ i ^ i가 i가 작은 경우에는 꽤나 작다는 것이다. m=0이라고 하자. 그러면 n이 p * q이고 p와 q가 각각 1024bit prime이기 때문에 i = 1, 2 정도에서만 비트가 1인 경우에는 결과값이 모듈러보다 작아서 계산 결과가 정확히 iv + q^t 일 것이라고 추측해볼 수 있다. 따라서 100번 정도 이를 반복하면 keystream중에 낮은 자리의 비트만 켜져있는 경우의 결과를 수집할 수 있고, 이 경우에는 확률적으로 보았을 때 모듈러를 취한 결과값(2048bit에 비례)보다 모듈러를 취하지 않은 결과값(1024bit + q^t 가량)이 작을 것이므로 수집한 모든 결과 중 최솟값을 min이라고 하면 두 번째 최솟값과 세 번째 최솟값에서 각각 min을 뺀 다음 gcd를 취할 수 있다. 그러면 두 번째 최솟값과 세 번째 최솟값은 둘 다 q^a - q^b 꼴일 것이니 gcd값이 q인 것을 알 수 있다. (사실 p와 q 간에 WLOG가 없기 때문에 p일수도 있으나, n을 알려주기 때문에 n / q로 한 번 더 아래 과정을 시행해보면 된다.) RSA gcd attack의 응용이라고 생각해도 무방하다.

이제 q를 아니, keystream을 복원해야 한다. 쿼리를 날린 횟수를 알면 현재 keystream의 어느 위치를 보고 있는지도 알고 있으니 거기서부터 시작하자. 먼저 iv를 알아내야 하는데, 이는 m=q로 두면 m xor q = 0이 되어 결과값이 반드시 iv가 나오게 된다. 그 후 m=q xor 2로 두면(사실 이는 불필요하며, 0으로 두어도 똑같이 풀린다) 서버는 pow(2, i ** i ** i, n)로 기존 과정을 반복하게 된다. 가능한 모든 7bit 후보에 대해 결과값을 계산해놓은 다음 일치하는 것을 확인하면 7bit씩 끊어서 원래 flag를 복호화할 수 있다.

주의할 점: gcd attack은 확률적으로 실패한다. 예외처리를 해 두자. pow 함수의 시간복잡도가 sublinear하긴 하지만 지수가 테트레이션이라 나이브하게 계산하려면 상상을 초월하게 오래 걸린다. 미리 전처리해두자. (모 한국의 강력한 크립토 해커는 오일러 정리를 사용하라고 했지만, 전처리해도 사실 reasonable한 시간에 해결 가능하다.)

소스 코드

rm = remote('52.163.228.53', 8081)

idx = 0
n = 0


def bypass():
    dump = string.ascii_lowercase + string.ascii_uppercase + string.digits
    r = rm.recvline()
    y = r[12:12 + 16]
    ans = r[33:-1].decode()

    rm.recvline()
    H = SHA256.new()

    query = None
    for x in tqdm(product(dump, repeat=4)):
        sx = ""
        for t in x:
            sx += t
        sx = sx.encode()
        tr = sx + y

        hash = hashlib.sha256(tr).hexdigest()
        if hash == ans:
            print("Passed.")
            break
    rm.sendline(sx)

def query(x):
    global idx
    rm.recvline()
    rm.sendline(x)
    r = rm.recvline()[6:-1]
    idx = (idx + 7) % 120
    return int(r.decode())

def subexploit(q):
    print("subexploit begin with", q)
    global n
    global idx
    iv = query(str(q).encode())
    cx = 0
    l = [-1 for _ in range(120)]

    D = {}
    debug_cnt = 0
    powcal = []
    for i in range(1, 8):
        powcal.append(pow(2, i ** i ** i, n))
    for it in product([0, 1], repeat=7):
        debug_cnt += 1
        # print(debug_cnt)
        v = iv
        for i in range(len(it)):
            if it[i] == 1:
                v += powcal[i]
                v %= n
        D[v] = it

    try:
        while True:
            st = idx
            r = query(str(q ^ 2).encode())
            r = D[r]
            for i in range(7):
                l[(st + i) % 120] = r[i]
                cx += 1
                if cx == 120: break
            if cx == 120: break

        ans_str = ""
        for i in range(120):
            ans_str += str(l[i])
        ans = int(ans_str, 2)
        print(long_to_bytes(ans))
    except:
        print("Exception!!")
        return

def exploit():
    global n
    rm.recvuntil(b"n: ")
    ret = rm.recvline()
    n = int(ret[:-1])

    S = []
    for _ in range(100):
        S.append(query(b'0'))
    S = list(set(S))
    S = sorted(S)

    mi = S[0]
    S = [a - mi for a in S]
    g = gcd(S[1], S[2])

    # checks g is real p or q
    if len(long_to_bytes(g)) < 127 or len(long_to_bytes(g)) > 128:
        print("Something Strange...")
        return
    if len(long_to_bytes(n // g)) < 127 or len(long_to_bytes(n // g)) > 128:
        print("Something is also Strange...")
        return

    subexploit(g)
    subexploit(n // g)


if __name__ == '__main__':
    bypass()
    exploit()