# rsa.py -- textbook RSA key generation, encryption and decryption.
# Author: Leonard Marschke
# pylint: disable=invalid-name,too-many-return-statements
import argparse
import secrets
import random
import time

import gmpy2

# Keyword arguments shared by every int <-> bytes conversion in this module.
INT_BYTES_ARGS = dict(byteorder='little', signed=False)
# Size in bytes of each generated prime and of each plaintext block (1024 bits).
KEY_BYTES = 1024 // 8


class Config:
    """RSA key material persisted in a small text file.

    The file holds one decimal integer per line, in the order prime1, prime2,
    publicKey, privateKey. Values the file does not provide stay ``None``.
    """

    # Class-level defaults; __init__ overrides them per instance with the
    # values read from the config file.
    prime1 = None
    prime2 = None
    privateKey = None
    publicKey = None
    fPath = None

    # Attribute names in the order they are stored in the config file.
    _FIELDS = ('prime1', 'prime2', 'publicKey', 'privateKey')

    def __init__(self, path):
        """Load the config from *path*.

        If the file cannot be opened, every value stays ``None``. Blank lines
        and the literal ``None`` (which ``store`` writes for unset values) are
        read back as ``None`` instead of making ``int()`` raise ValueError.
        Lines beyond the fourth are ignored.
        """
        self.fPath = path
        try:
            with open(path) as f:
                lines = f.readlines()
        except IOError:
            return

        for name, raw in zip(self._FIELDS, lines):
            value = raw.strip()
            if value and value != 'None':
                setattr(self, name, int(value))

    def valid(self):
        """Return True when both primes and both key exponents are set."""
        return self.has_primes() and self.has_pub_key() and isinstance(self.privateKey, int)

    def has_primes(self):
        """Return True when both primes are set."""
        return isinstance(self.prime1, int) and isinstance(self.prime2, int)

    def has_pub_key(self):
        """Return True when the public exponent is set."""
        return isinstance(self.publicKey, int)

    def get_prime_product(self):
        """Return the RSA modulus prime1 * prime2."""
        return self.prime1 * self.prime2

    def store(self):
        """Overwrite ``fPath`` with the four values, one per line.

        ``None`` values are written as the literal ``None``; no trailing
        newline is written after the last value.
        """
        values = (self.prime1, self.prime2, self.publicKey, self.privateKey)
        with open(self.fPath, 'w') as target:  # mode 'w' already truncates
            target.write('\n'.join(str(value) for value in values))


def is_prime(num, runs=40):
    """
    Probabilistic Miller-Rabin primality test.

    Every round picks a random witness and tries to prove num composite; 40
    rounds are used by default, see
    https://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes

    :param num: Integer to test
    :param runs: Number of Miller-Rabin rounds (random witnesses)
    :return: False if num is certainly composite, True if it is (probably) prime.
    """

    # Answer tiny, even and multiple-of-3 inputs without any modular powers.
    if num < 2:
        return False
    if num < 4:
        return True
    if num % 2 == 0:
        return False
    if num < 9:
        return True
    if num % 3 == 0:
        return False

    # Write num - 1 as odd_part * 2**twos with odd_part odd.
    odd_part, twos = num - 1, 0
    while not odd_part & 1:
        odd_part >>= 1
        twos += 1

    for _ in range(runs):
        witness = random.randrange(2, num - 1)
        x = pow(witness, odd_part, num)

        # witness^odd_part == +-1 (mod num) gives no evidence against primality.
        if x == 1 or x == num - 1:
            continue

        # Square up to twos - 1 times looking for -1; never reaching it proves
        # num composite.
        for _ in range(twos - 1):
            x = pow(x, 2, num)
            if x == num - 1:
                break
        else:
            return False
    return True


def _random_prime(num_bytes):
    """Return a random probable prime with exactly ``num_bytes * 8`` bits."""
    while True:
        candidate = bytearray(secrets.token_bytes(num_bytes))
        # INT_BYTES_ARGS is little-endian, so the last byte is the most
        # significant one: setting its top bit fixes the bit length.
        candidate[-1] |= 0x80
        # Even numbers of this size are never prime, so only test odd ones.
        candidate[0] |= 0x01
        number = int.from_bytes(candidate, **INT_BYTES_ARGS)
        if is_prime(number):
            return number


def gen_primes():
    """Generate the two RSA primes.

    :return: A tuple of two distinct, independently drawn probable primes of
        ``KEY_BYTES * 8`` bits each (1024 bits with KEY_BYTES = 128).
    """
    # The two primes must not be close to each other. For a twin-prime pair
    # (p, p + 2) the modulus satisfies n + 1 == (p + 1) ** 2, so it factors
    # with one integer square root. Independent draws keep them far apart.
    first = _random_prime(KEY_BYTES)
    second = _random_prime(KEY_BYTES)
    while second == first:  # astronomically unlikely, but p == q breaks RSA
        second = _random_prime(KEY_BYTES)
    return first, second


def gcd(a, b):
    """Return the greatest common divisor of *a* and *b* (Euclid's algorithm).

    Implemented iteratively: RSA-sized operands can need thousands of steps,
    more than Python's default recursion limit.
    """
    x, y = a, b
    while y:
        x, y = y, x % y
    return x


def modinv(a, m):
    """Return the multiplicative inverse of *a* modulo *m* as a Python int.

    The work is delegated to ``gmpy2.invert``, and its ``mpz`` result is
    converted to a builtin int.
    NOTE(review): when gcd(a, m) != 1 the outcome depends on the gmpy2
    version (older releases return 0, newer ones raise ZeroDivisionError) --
    callers should make sure an inverse exists.
    """
    inverse = gmpy2.invert(a, m)  # pylint: disable=c-extension-no-member
    return int(inverse)


def key_gen(config):
    """Create or complete the RSA key pair held by *config* and persist it.

    Primes already stored in the config are reused; otherwise a new pair is
    produced by ``gen_primes``. A public exponent from the config is reused as
    well; otherwise a random exponent coprime to
    phi = (prime1 - 1) * (prime2 - 1) is drawn. The private exponent is then
    computed as the inverse of the public exponent modulo phi, and the config
    is written back with ``config.store()``.

    :param config: Config instance to fill in (modified in place).
    :raises ValueError: if the public exponent from the config shares a
        factor with phi, so no private exponent exists.
    """
    if config.has_primes():
        print('Found primes in config file!')
    else:
        print('Generating new primes...')
        print('This takes some time due to the "slowness" of Python...')
        config.prime1, config.prime2 = gen_primes()

    print('Generating public key...')
    phi = (config.prime1 - 1) * (config.prime2 - 1)

    if config.has_pub_key():
        e = config.publicKey
        # Check explicitly instead of relying on modinv's failure mode.
        if gcd(e, phi) != 1:
            raise ValueError('Public key from config file is not coprime to '
                             '(prime1 - 1) * (prime2 - 1); no private key exists')
    else:
        # e must be coprime to phi for the inverse to exist. Drawing from
        # [2, phi) excludes e = 1, which would turn encryption into the identity.
        e = random.randrange(2, phi)
        while gcd(e, phi) != 1:
            e = random.randrange(2, phi)

    print('Generating private key...')
    config.privateKey = modinv(e, phi)
    config.publicKey = e

    config.store()


def encrypt(config, message):
    """Encrypt *message* with the public key and print the ciphertext.

    The UTF-8 bytes of the message are cut into KEY_BYTES-sized blocks. Each
    block is read as an integer (INT_BYTES_ARGS byte order) and raised to the
    public exponent modulo prime1 * prime2. The decimal results are printed
    joined by '%'. This is textbook RSA: no randomized padding is applied.
    """
    data = message.encode('utf-8')
    blocks = (data[offset:offset + KEY_BYTES] for offset in range(0, len(data), KEY_BYTES))
    encrypted = [
        str(pow(int.from_bytes(block, **INT_BYTES_ARGS), config.publicKey, config.get_prime_product()))
        for block in blocks
    ]
    print('%'.join(encrypted))


def decrypt(config, message):
    """Decrypt a ciphertext produced by ``encrypt`` and print the plaintext.

    *message* is a '%'-separated list of decimal integers. Each one is raised
    to the private exponent modulo prime1 * prime2 and converted back into a
    KEY_BYTES-long block (INT_BYTES_ARGS byte order). The zero bytes that pad
    the final block are stripped before the result is decoded as UTF-8. A
    plaintext that itself ends in NUL characters loses those as well.
    """
    # encrypt prints an empty string for an empty message; ''.split('%')
    # would yield [''] and int('') would raise, so treat it as zero blocks.
    chunks = message.split('%') if message else []

    blocks = []
    for chunk in chunks:
        decrypted = pow(int(chunk), config.privateKey, config.get_prime_product())
        blocks.append(int.to_bytes(decrypted, KEY_BYTES, **INT_BYTES_ARGS))

    # rstrip examines every position down to index 0, so a single-byte
    # plaintext loses its 127 padding bytes as well.
    plaintext = b''.join(blocks).rstrip(b'\x00')
    print(plaintext.decode('utf-8'))


def main():
    """Command-line entry point: parse the arguments and run one action.

    Actions: ``key-gen`` (create or complete the keys in the config file),
    ``encrypt <plain>`` and ``decrypt <encrypted>`` (print the result).
    """
    # Do our argument parsing for different call types
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='Config file location', default='./crypto.conf')
    subparsers = parser.add_subparsers(dest='action', help='Action to perform')
    subparsers.required = True  # Throw error if no action is passed
    subparsers.add_parser('key-gen')
    parser_encrypt = subparsers.add_parser('encrypt')
    parser_encrypt.add_argument('plain', help='String to encrypt')
    parser_decrypt = subparsers.add_parser('decrypt')
    parser_decrypt.add_argument('encrypted', help='String to decrypt')

    args = parser.parse_args()

    # load config file
    config = Config(args.config)

    # key_gen writes the config file itself. encrypt and decrypt only read the
    # keys, so the file is not rewritten for them. Rewriting could fail on a
    # read-only config after the result was printed, or replace missing
    # values with literal "None" lines.
    if args.action == 'key-gen':
        key_gen(config)
    elif args.action == 'encrypt':
        encrypt(config, args.plain)
    elif args.action == 'decrypt':
        decrypt(config, args.encrypted)


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script, not on import.
    main()