मुख्य कंटेंट तक स्किप करें

Least Common Genominator - GoogleCTF 2023

Someone used this program to send me an encrypted message but I can't read it! It uses something called an LCG, do you know what it is? I dumped the first six consecutive values generated from it but what do I do with it?!

challenge.py
from secret import config
from Crypto.PublicKey import RSA
from Crypto.Util.number import bytes_to_long, isPrime

class LCG:
    """Linear congruential generator with parameters taken from ``config``.

    Each step computes ``state = (state * lcg_m + lcg_c) % lcg_n``.
    """

    # Multiplier, increment and modulus of the recurrence used in next().
    # They are read from the secret config once, when the class is defined.
    lcg_m = config.m
    lcg_c = config.c
    lcg_n = config.n

    def __init__(self, lcg_s):
        """Start the generator at state ``lcg_s`` (the seed)."""
        self.state = lcg_s

    def next(self):
        """Advance the generator by one step and return the new state."""
        self.state = (self.state * self.lcg_m + self.lcg_c) % self.lcg_n
        return self.state

if __name__ == '__main__':

    assert 4096 % config.it == 0
    assert config.it == 8
    assert 4096 % config.bits == 0
    assert config.bits == 512

    # Find prime value of specified bits a specified amount of times
    seed = 211286818345627549183608678726370412218029639873054513839005340650674982169404937862395980568550063504804783328450267566224937880641772833325018028629959635
    lcg = LCG(seed)
    primes_arr = []

    # Number of raw LCG outputs written to dump.txt so far (capped at six).
    items = 0

    primes_n = 1
    # The context manager guarantees dump.txt is flushed and closed even if
    # prime generation raises before six values have been written.
    with open("dump.txt", "w") as dump_file:
        while True:
            for i in range(config.it):
                while True:
                    prime_candidate = lcg.next()
                    # Leak only the first six consecutive outputs, prime or not.
                    if items < 6:
                        dump_file.write(str(prime_candidate) + '\n')
                        items += 1
                    if not isPrime(prime_candidate):
                        continue
                    elif prime_candidate.bit_length() != config.bits:
                        continue
                    else:
                        primes_n *= prime_candidate
                        primes_arr.append(prime_candidate)
                        break

            # Check bit length: start over with fresh primes if the product
            # is wider than 4096 bits. The LCG itself is NOT reset, so the
            # retry continues from the current generator state.
            if primes_n.bit_length() > 4096:
                print("bit length", primes_n.bit_length())
                primes_arr.clear()
                primes_n = 1
                continue
            else:
                break

    # Create public key 'n'
    n = 1
    for j in primes_arr:
        n *= j
    print("[+] Public Key: ", n)
    print("[+] size: ", n.bit_length(), "bits")

    # Calculate totient 'Phi(n)'
    phi = 1
    for k in primes_arr:
        phi *= (k - 1)

    # Calculate private key 'd'. It is not used below; pow() raises
    # ValueError if config.e has no inverse modulo phi.
    d = pow(config.e, -1, phi)

    # Generate Flag
    assert config.flag.startswith(b"CTF{")
    assert config.flag.endswith(b"}")
    enc_flag = bytes_to_long(config.flag)
    assert enc_flag < n

    # Encrypt Flag
    _enc = pow(enc_flag, config.e, n)

    # NOTE: the length passed to to_bytes() is n.bit_length() (a bit count
    # used as a byte count), so the ciphertext is written little-endian and
    # zero-padded far beyond the size of n. Kept unchanged because it
    # defines the format of flag.txt that the solver reads back.
    with open("flag.txt", "wb") as flag_file:
        flag_file.write(_enc.to_bytes(n.bit_length(), "little"))

    # Export RSA Key
    rsa = RSA.construct((n, config.e))
    with open("public.pem", "w") as pub_file:
        pub_file.write(rsa.exportKey().decode())

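Solution

An LCG produces s_{i+1} = (m * s_i + c) mod n. The seed is hard-coded in challenge.py and dump.txt leaks the next six states, so all three parameters can be recovered. The differences t_i = s_{i+1} - s_i satisfy t_{i+1} = m * t_i (mod n), so every value t_{i+2} * t_i - t_{i+1}^2 is a multiple of n, and the GCD of those values gives n. Then m = t_1 * t_0^{-1} mod n and c = (s_1 - m * s_0) mod n. With the LCG rebuilt, we replay the challenge's prime generation to obtain the eight 512-bit prime factors of the public modulus, compute phi(n) and d, and decrypt the flag.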

solve.py
from Crypto.PublicKey import RSA
from Crypto.Util.number import long_to_bytes,isPrime,inverse,bytes_to_long
from math import gcd

# LCG increment, multiplier and modulus. They start as zero placeholders and
# are reassigned at module level once the real values have been derived from
# the leaked outputs; LCG.next() looks them up at call time.
lcg_c, lcg_m, lcg_n = 0, 0, 0


class LCG:
    """Linear congruential generator driven by the module-level parameters.

    Each step computes ``state = (state * lcg_m + lcg_c) % lcg_n`` using the
    current values of the globals ``lcg_m``, ``lcg_c`` and ``lcg_n``.
    """

    def __init__(self, lcg_s):
        """Start the generator at state ``lcg_s`` (the seed)."""
        self.state = lcg_s

    def next(self):
        """Advance the generator by one step and return the new state.

        Raises:
            ZeroDivisionError: if ``lcg_n`` is still the placeholder ``0``.
        """
        self.state = (self.state * lcg_m + lcg_c) % lcg_n
        return self.state


def recover_lcg_params(states):
    """Recover ``(m, c, n)`` of an LCG from consecutive outputs.

    For ``s[i+1] = (m * s[i] + c) % n`` the differences
    ``t[i] = s[i+1] - s[i]`` satisfy ``t[i+1] = m * t[i] (mod n)``, so every
    ``t[i+2] * t[i] - t[i+1]**2`` is a multiple of ``n``. Their GCD is taken
    as the modulus; the multiplier and increment then follow from the first
    three states.

    Args:
        states: consecutive generator states, seed first. At least four are
            needed; more states make it likelier that the GCD is exactly
            ``n`` rather than a multiple of it.

    Returns:
        A tuple ``(multiplier, increment, modulus)``.

    Raises:
        ValueError: if fewer than four states are given, if the first
            difference has no inverse modulo the recovered modulus, or if
            the recovered parameters do not reproduce ``states``.
    """
    if len(states) < 4:
        raise ValueError("need at least 4 consecutive LCG states")

    # Subtracting neighbours cancels the unknown increment.
    diffs = [nxt - cur for cur, nxt in zip(states, states[1:])]
    multiples = [
        abs(diffs[i + 2] * diffs[i] - diffs[i + 1] * diffs[i + 1])
        for i in range(len(diffs) - 2)
    ]

    modulus = gcd(*multiples)
    if modulus == 0:
        raise ValueError("degenerate states: cannot determine the modulus")

    multiplier = (diffs[1] * pow(diffs[0] % modulus, -1, modulus)) % modulus
    increment = (states[1] - states[0] * multiplier) % modulus

    # Sanity check: the recovered generator must replay the known states.
    for cur, nxt in zip(states, states[1:]):
        if (cur * multiplier + increment) % modulus != nxt:
            raise ValueError("recovered LCG parameters do not reproduce the states")

    return multiplier, increment, modulus


def collect_primes(gen, count=8, bits=512, max_bits=4096):
    """Replay the challenge's prime selection on ``gen``.

    Draws values from ``gen.next()`` and keeps those that are prime and
    exactly ``bits`` bits long until ``count`` primes are collected. If their
    product is wider than ``max_bits`` bits the batch is discarded and a new
    one is drawn from the generator's current state, as the challenge does.

    Returns:
        A tuple ``(primes, product)``.
    """
    while True:
        chosen = []
        product = 1
        while len(chosen) < count:
            candidate = gen.next()
            # Cheap width test first; the primality test is the costly one.
            if candidate.bit_length() == bits and isPrime(candidate):
                chosen.append(candidate)
                product *= candidate
        if product.bit_length() <= max_bits:
            return chosen, product
        print("bit length", product.bit_length())


if __name__ == "__main__":
    with open("public.pem", "r") as pub_file:
        key = RSA.importKey(pub_file.read())

    with open("flag.txt", "rb") as flag_file:
        enc_flag = flag_file.read()

    # Key parameters
    n = key.n
    e = key.e

    # Known generator states: the seed hard-coded in challenge.py followed by
    # the six consecutive outputs leaked in dump.txt.
    seed = 211286818345627549183608678726370412218029639873054513839005340650674982169404937862395980568550063504804783328450267566224937880641772833325018028629959635
    states = [seed] + [
        2166771675595184069339107365908377157701164485820981409993925279512199123418374034275465590004848135946671454084220731645099286746251308323653144363063385,
        6729272950467625456298454678219613090467254824679318993052294587570153424935267364971827277137521929202783621553421958533761123653824135472378133765236115,
        2230396903302352921484704122705539403201050490164649102182798059926343096511158288867301614648471516723052092761312105117735046752506523136197227936190287,
        4578847787736143756850823407168519112175260092601476810539830792656568747136604250146858111418705054138266193348169239751046779010474924367072989895377792,
        7578332979479086546637469036948482551151240099803812235949997147892871097982293017256475189504447955147399405791875395450814297264039908361472603256921612,
        2550420443270381003007873520763042837493244197616666667768397146110589301602119884836605418664463550865399026934848289084292975494312467018767881691302197,
    ]

    # Assigned at module level on purpose: LCG.next() reads these globals.
    lcg_m, lcg_c, lcg_n = recover_lcg_params(states)

    # Regenerate the RSA primes exactly as the challenge did.
    primes_arr, primes_n = collect_primes(LCG(seed))
    if primes_n != n:
        raise ValueError("regenerated primes do not multiply to the public modulus")

    phi = 1
    for k in primes_arr:
        phi *= (k - 1)
    # Calculate private key 'd'
    d = pow(e, -1, phi)

    # flag.txt holds the ciphertext as a zero-padded little-endian integer.
    flag = pow(int.from_bytes(enc_flag, "little"), d, n)
    print(long_to_bytes(flag))