You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

530 lines
15 KiB

13 years ago
"""RSA module

Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.

WARNING: this implementation does not use random padding, compression of the
cleartext input to prevent repetitions, or other common security improvements.
Use with care.

Messages are cut into blocks ("chops") that fit below the modulus n. Each
block is turned into an integer, a fixed "safe bit" is added, and the result
is raised to the key exponent modulo n. The resulting integers are written
out in a custom base-64 alphabet ('0'-'9', 'A'-'Z', 'a'-'z', '-', '_') and
separated by commas.
"""
__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
__date__ = "2010-02-08"
__version__ = '2.0'
import math
import numbers
import os
import random
import sys
import types

from rsa._compat import byte
# Importing this legacy implementation is discouraged, so announce it to the
# importer through the standard warnings machinery (default category).
import warnings
warnings.warn('Insecure version of the RSA module is imported as %s' % (__name__,),
              UserWarning)
def bit_size(number):
    """Returns the number of bits required to hold a specific long number.

    >>> bit_size(255)
    8
    >>> bit_size(256)
    9

    Raises ValueError if number is smaller than 1.
    """
    if number < 1:
        raise ValueError('bit_size() requires a positive number, got %r'
                         % (number,))
    # Use exact integer arithmetic. ceil(log(number, 2)) is one bit short
    # for exact powers of two (256 needs 9 bits, not 8). It can also be off
    # by one for huge numbers because of floating point rounding.
    try:
        return number.bit_length()  # available on int/long from Python 2.7
    except AttributeError:
        # Older interpreters: count how many shifts it takes to reach zero.
        bits = 0
        while number:
            number >>= 1
            bits += 1
        return bits
def gcd(p, q):
    """Returns the greatest common divisor of p and q.

    The result is never negative; gcd(0, 0) is 0.

    >>> gcd(48, 180)
    12
    >>> gcd(0, 5)
    5
    """
    # Iterative Euclid: faster than recursion and uses constant stack space.
    # No explicit swap is needed. When p < q, the first step (q, p % q)
    # already exchanges the operands. Swapping by hand made gcd(0, q)
    # divide by zero and never terminated for negative operands.
    while q != 0:
        p, q = q, p % q
    return abs(p)
def bytes2int(bytes):
    """Converts a list of bytes or a string to an integer (big-endian).

    Accepts a list of ints or 1-character strings, a str, or a byte string /
    bytearray (e.g. the output of os.urandom on Python 3).

    >>> (((128 * 256) + 64) * 256) + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l) #same as bytes2int('\\x80@\\x0f')
    8405007

    Raises TypeError for any other argument type.
    """
    # The parameter name shadows the builtin ``bytes``, so the byte-string
    # type is spelled type(b'') here (str on Python 2, bytes on Python 3).
    if not isinstance(bytes, (list, str, type(b''), bytearray)):
        raise TypeError("You must pass a string or a list")
    # Convert byte stream to integer, most significant byte first.
    integer = 0
    for item in bytes:
        # Iterating a str yields 1-character strings; bytes, bytearray and
        # lists of ints yield ints directly.
        if isinstance(item, str):
            item = ord(item)
        integer = integer * 256 + item
    return integer
def int2bytes(number):
    """
    Converts a number to a string of bytes (big-endian, no leading zero
    bytes). Zero and negative numbers yield an empty string.

    >>> int2bytes(8405007) == b'\\x80@\\x0f'
    True

    Raises TypeError if number is not an integer.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")
    # Collect bytes least-significant first and reverse once at the end.
    # This avoids quadratic string prepending, and the shift (unlike ``/=``)
    # stays an exact integer operation on Python 3.
    buf = bytearray()
    while number > 0:
        buf.append(number & 0xFF)
        number >>= 8
    buf.reverse()
    # bytes is an alias of str on Python 2, so the return type is unchanged
    # there; on Python 3 a real byte string is returned.
    return bytes(buf)
def to64(number):
    """Converts a number in the range of 0 to 63 into base 64 digit
    character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.

    The result is always a text (str) character.

    >>> to64(10)
    'A'

    Raises TypeError for non-integers and ValueError outside 0-63.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")
    # 0-9 -> '0'-'9', 10-35 -> 'A'-'Z', 36-61 -> 'a'-'z', 62 -> '-', 63 -> '_'
    alphabet = ('0123456789'
                'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
                'abcdefghijklmnopqrstuvwxyz'
                '-_')
    # Range check first: a negative index would otherwise wrap around.
    if 0 <= number <= 63:
        return alphabet[number]
    raise ValueError('Invalid Base64 value: %i' % number)
def from64(number):
    """Converts an ordinal character value in the range of
    0-9,A-Z,a-z,-,_ to a number in the range of 0-63.

    >>> from64(49)
    1

    Raises TypeError for non-integers and ValueError for ordinals outside
    the base-64 alphabet.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")
    if 48 <= number <= 57:      # ord('0') - ord('9') translates to 0-9
        return number - 48
    if 65 <= number <= 90:      # ord('A') - ord('Z') translates to 10-35
        return number - 55
    if 97 <= number <= 122:     # ord('a') - ord('z') translates to 36-61
        return number - 61
    if number == 45:            # ord('-') translates to 62
        return 62
    if number == 95:            # ord('_') translates to 63
        return 63
    raise ValueError('Invalid Base64 value: %i' % number)
def int2str64(number):
    """Converts a number to a string of base64 encoded characters in
    the range of '0'-'9','A'-'Z','a'-'z','-','_' (most significant first).

    Zero and negative numbers are encoded as the empty string.

    >>> int2str64(123456789)
    '7MyqL'

    Raises TypeError if number is not an integer.
    """
    if not isinstance(number, numbers.Integral):
        raise TypeError("You must pass a long or an int")
    digits = []
    while number > 0:
        digits.append(to64(number & 0x3F))
        # Shift instead of ``/= 64``: exact on both Python 2 and Python 3.
        number >>= 6
    # Digits were produced least-significant first.
    digits.reverse()
    return ''.join(digits)
def str642int(string):
    """Converts a base64 encoded string into an integer.
    The chars of this string are in the range '0'-'9','A'-'Z','a'-'z','-','_'.

    Accepts a str, a list of characters or ordinals, or a byte string /
    bytearray.

    >>> str642int('7MyqL')
    123456789

    Raises TypeError for any other argument type.
    """
    if not isinstance(string, (list, str, bytes, bytearray)):
        raise TypeError("You must pass a string or a list")
    integer = 0
    for char in string:
        # from64 works on ordinals; characters from a str are converted.
        if isinstance(char, str):
            char = ord(char)
        integer = integer * 64 + from64(char)
    return integer
def read_random_int(nbits):
    """Return a random non-negative integer built from os.urandom bytes.

    nbits is rounded up to a whole number of bytes, so the result is below
    2 ** (8 * ceil(nbits / 8)).
    """
    byte_count = int(math.ceil(nbits / 8.0))
    return bytes2int(os.urandom(byte_count))
def randint(minvalue, maxvalue):
    """Returns a random integer x with minvalue <= x <= maxvalue.

    The value is drawn from os.urandom (via read_random_int) and is uniform
    to within a statistical distance of 2**-64.

    Raises ValueError when maxvalue < minvalue.
    """
    # The number of distinct values we can return.
    span = (maxvalue - minvalue) + 1
    if span < 1:
        raise ValueError('maxvalue (%r) must not be smaller than minvalue (%r)'
                         % (maxvalue, minvalue))
    # Always draw 64 bits more than the span needs, so that reducing modulo
    # span is close to uniform. Picking a random *number of bits* instead
    # would make small results far more likely than large ones whenever
    # the range is big.
    nbits = bit_size(span) + 64
    return (read_random_int(nbits) % span) + minvalue
def jacobi(a, b):
    """Return the Jacobi symbol (a/b) as -1, 0 or 1.

    Both a and b must be positive integers and b must be odd.

    >>> jacobi(2, 7)
    1
    """
    if a == 0:
        return 0
    sign = 1
    while a > 1:
        if a % 2 == 0:
            # Pull out a factor 2: (2/b) is -1 exactly when (b*b - 1)/8 is odd.
            if ((b * b - 1) >> 3) & 1:
                sign = -sign
            a >>= 1
        else:
            # Quadratic reciprocity for odd a, then reduce b modulo a.
            if (((a - 1) * (b - 1)) >> 2) & 1:
                sign = -sign
            a, b = b % a, a
    # a ended at 0: the arguments shared a factor.
    return 0 if a == 0 else sign
def jacobi_witness(x, n):
    """Returns False if n is an Euler pseudo-prime with base x, and
    True otherwise.

    True means x violates Euler's criterion x**((n-1)/2) == (x/n) (mod n),
    which proves n composite. n is expected to be odd, as jacobi requires.
    """
    # jacobi() yields -1 for non-residues; reduce it into [0, n) so it can be
    # compared with the modular power below.
    j = jacobi(x, n) % n
    # Floor division keeps the exponent an integer on Python 3 as well;
    # three-argument pow() rejects float exponents.
    f = pow(x, (n - 1) // 2, n)
    return j != f
def randomized_primality_testing(n, k):
    """Run k rounds of a Jacobi-symbol (Solovay-Strassen style) test on n.

    A False result means n is certainly composite. A True result means n is
    probably prime; the chance that a composite survives every round is at
    most 2**-k.
    """
    # Each round picks a fresh random base in [1, n-1]. For an odd composite
    # n at least half of the bases are witnesses, hence the 2**-k bound.
    for _ in range(k):
        if jacobi_witness(randint(1, n - 1), n):
            return False
    return True
def is_prime(number):
    """Returns True if the number is prime, and False otherwise.

    Numbers below 4 and even numbers are decided directly. Odd numbers from
    5 upwards go through randomized_primality_testing with 6 rounds, so a
    composite is reported as prime with probability at most 2**-6.

    >>> is_prime(42)
    False
    >>> is_prime(41)
    True
    """
    if number < 2:
        # 0, 1 and negative numbers are not prime; the randomized test
        # cannot handle them (it would need a base in [1, number - 1]).
        return False
    if number < 4:
        # 2 and 3
        return True
    if number % 2 == 0:
        # The Jacobi-symbol test is only valid for odd numbers. Without this
        # check, small even numbers such as 4 were reported prime by chance.
        return False
    return randomized_primality_testing(number, 6)
def getprime(nbits):
    """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
    other words: nbits is rounded up to whole bytes.

    Candidates are random odd numbers from read_random_int; the top bit is
    not forced, so the prime may be shorter than nbits bits.

    >>> p = getprime(32)
    >>> is_prime(p-1)
    False
    >>> is_prime(p)
    True
    >>> is_prime(p+1)
    False
    """
    while True:
        # Only odd candidates are tried (so 2 is never returned).
        integer = read_random_int(nbits) | 1
        # 1 is not prime; skip it instead of handing it to is_prime.
        if integer > 1 and is_prime(integer):
            return integer
        # Otherwise retry with a fresh random candidate.
def are_relatively_prime(a, b):
    """Returns True if a and b are relatively prime, and False if they
    are not.

    >>> are_relatively_prime(2, 3)
    True
    >>> are_relatively_prime(2, 4)
    False
    """
    # Coprime means the only common divisor is 1.
    return gcd(a, b) == 1
def find_p_q(nbits):
    """Returns a tuple (p, q) of two different primes for an RSA modulus.

    p is requested with nbits + nbits//16 bits and q with nbits - nbits//16
    bits (getprime rounds both up to whole bytes), so the two primes have
    deliberately different sizes.
    """
    # Make sure that p and q aren't too close, or factoring programs
    # (e.g. Fermat's method) can factor n.
    offset = nbits // 16
    pbits = nbits + offset
    qbits = nbits - offset
    p = getprime(pbits)
    while True:
        q = getprime(qbits)
        # Make sure p and q are different.
        if q != p:
            return (p, q)
def extended_gcd(a, b):
    """Returns a tuple (r, i, j) where r = gcd(a, b).

    The coefficients satisfy i*a == r (mod b) and j*b == r (mod a). When
    r == 1, i is therefore the multiplicative inverse of a modulo b and j
    the inverse of b modulo a. Negative coefficients are made positive by
    adding the original b (for i) or a (for j), so r == i*a + j*b only
    holds modulo a and b.

    >>> extended_gcd(3, 11)
    (1, 4, 2)
    """
    # Iterative version: faster than recursion and uses constant stack space.
    x, lx = 0, 1
    y, ly = 1, 0
    oa, ob = a, b   # remember the originals to wrap negative results
    while b != 0:
        # Exact integer floor division. long(a/b) depended on Python 2
        # classic division and would go through a lossy float on Python 3.
        q = a // b
        a, b = b, a % b
        x, lx = lx - q * x, x
        y, ly = ly - q * y, y
    if lx < 0:
        lx += ob    # if negative, wrap modulo original b
    if ly < 0:
        ly += oa    # if negative, wrap modulo original a
    return (a, lx, ly)
# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
    """Calculates an encryption and a decryption key for p and q, and
    returns them as a tuple (e, d).

    e is a prime of at least 65537 that is coprime to both n = p*q and
    phi_n = (p-1)*(q-1). d is the inverse of e modulo phi_n.

    Raises Exception if the computed keys fail their consistency checks.
    """
    n = p * q
    phi_n = (p - 1) * (q - 1)
    while True:
        # Make sure e has enough bits so we ensure "wrapping" through
        # modulo n.
        # NOTE(review): when nbits // 4 <= 16, getprime stays below 65537,
        # so e is always 65537. If 65537 ever divided phi_n, this loop would
        # not terminate.
        e = max(65537, getprime(nbits // 4))
        if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n):
            break
    # extended_gcd returns (gcd, inverse of e mod phi_n, inverse of phi_n mod e);
    # only the first two values are needed.
    (divisor, d, _) = extended_gcd(e, phi_n)
    if not divisor == 1:
        raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
    if d < 0:
        raise Exception("New extended_gcd shouldn't return negative values")
    if not (e * d) % phi_n == 1:
        raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, d, phi_n))
    return (e, d)
def gen_keys(nbits):
    """Produce raw RSA key material for primes of roughly nbits bits.

    Returns the tuple (p, q, e, d). Expect this to be slow for large nbits,
    because it searches for random primes.
    """
    p, q = find_p_q(nbits)
    # calculate_keys returns the 2-tuple (e, d); append it to (p, q).
    return (p, q) + calculate_keys(p, q, nbits)
def newkeys(nbits):
    """Create a fresh key pair and return it as (pub, priv).

    pub is a dict with the keys 'e' and 'n' (n = p*q); priv is a dict with
    the keys 'd', 'p' and 'q'. nbits values below 9 are raised to 9.
    """
    # Keys smaller than 9 bits are not supported; clamp silently.
    nbits = max(9, nbits)
    p, q, e, d = gen_keys(nbits)
    public_key = {'e': e, 'n': p * q}
    private_key = {'d': d, 'p': p, 'q': q}
    return (public_key, private_key)
def encrypt_int(message, ekey, n):
    """Encrypts a message using encryption key 'ekey', working modulo n.

    Before exponentiation a "safe bit" (bit number bit_size(n) - 2) is added
    to the message so that even small messages wrap around the modulus.
    decrypt_int subtracts the same bit again.

    Raises TypeError if message is not an integer. Raises OverflowError if
    it is negative, or so large that message + 2**safebit would not stay
    below n (the value could then not be recovered on decryption).
    """
    if not isinstance(message, numbers.Integral):
        raise TypeError("You must pass a long or int")
    if message < 0:
        raise OverflowError("The message is too long")
    # Note: bit exponents start at zero (bit counts start at 1), so this is
    # the bit just below the most significant bit of n.
    safebit = bit_size(n) - 2
    message += (1 << safebit)  # add safebit to ensure folding
    # The padded value must stay below n; otherwise the modular reduction
    # destroys information and decryption silently returns garbage.
    if message >= n:
        raise OverflowError("The message is too long")
    return pow(message, ekey, n)
def decrypt_int(cyphertext, dkey, n):
    """Decrypt an integer cypher text with key 'dkey' modulo n.

    Strips the safe bit (bit number bit_size(n) - 2) that encrypt_int added
    before encrypting.
    """
    decrypted = pow(cyphertext, dkey, n)
    # Same safe-bit position as on the encryption side: one below n's MSB.
    return decrypted - (1 << (bit_size(n) - 2))
def encode64chops(chops):
    """Serialise integer chops as base-64 strings separated by commas.

    Each integer is converted with int2str64, and the pieces are joined with
    ',' so that decode64chops can split them again.
    """
    return ','.join([int2str64(chop) for chop in chops])
def decode64chops(string):
    """Split a ',' delimited string and decode every piece with str642int.

    Returns the list of integer chops, in their original order.
    """
    return [str642int(chip) for chip in string.split(',')]
def chopstring(message, key, n, funcref):
    """Chops the 'message' into integers that fit into n,
    leaving room for a safebit to be added to ensure that all
    messages fold during exponentiation. The MSB of the number n
    is not independent modulo n (setting it could cause overflow), so
    use the next lower bit for the safebit. Therefore reserve 2-bits
    in the number n for non-data bits. Calls specified encryption
    function for each chop.
    Used by 'encrypt' and 'sign'.

    Every block except possibly the last is exactly (bit_size(n) - 2) // 8
    bytes long. Returns the encrypted blocks as a base-64, comma separated
    string.

    Raises ValueError if n is too small to hold even one byte per block.
    """
    msglen = len(message)
    # Set aside 2 bits so setting of the safebit won't overflow modulo n.
    nbits = bit_size(n) - 2  # leave room for safebit
    nbytes = nbits // 8
    if nbytes < 1:
        raise ValueError("The modulus n (%d) is too small to encrypt whole bytes" % n)
    # One chop per nbytes-sized slice; the last slice may be shorter.
    cypher = [funcref(bytes2int(message[offset:offset + nbytes]), key, n)
              for offset in range(0, msglen, nbytes)]
    return encode64chops(cypher)  # Encode encrypted ints to base64 strings
def gluechops(string, key, n, funcref):
    """Glues chops back together into a string. calls
    funcref(integer, key, n) for each chop.
    Used by 'decrypt' and 'verify'.

    chopstring cuts every block except the last to exactly
    (bit_size(n) - 2) // 8 bytes. Those blocks are therefore left-padded
    with zero bytes back to that length, because int2bytes alone drops
    leading zero bytes. Leading zero bytes of the final block cannot be
    recovered from its integer value.
    """
    chops = decode64chops(string)  # Decode base64 strings into integer chops
    nbytes = (bit_size(n) - 2) // 8
    last = len(chops) - 1
    parts = []
    for index, cpart in enumerate(chops):
        mpart = int2bytes(funcref(cpart, key, n))  # Decrypt each chop
        if index < last:
            # rjust never truncates, so oversized (e.g. wrong-key) output
            # passes through unchanged.
            mpart = mpart.rjust(nbytes, b'\x00')
        parts.append(mpart)
    # b'' is str on Python 2, so the result type is unchanged there.
    return b''.join(parts)  # Combine decrypted strings into a msg
def encrypt(message, key):
    """Encrypt the string 'message' with a public key.

    'key' must be a public key dict with 'e' and 'n' entries, as produced
    by newkeys; a dict without 'n' is rejected with an Exception.
    """
    if 'n' in key:
        return chopstring(message, key['e'], key['n'], encrypt_int)
    raise Exception("You must use the public key with encrypt")
def sign(message, key):
    """Sign the string 'message' with a private key.

    'key' must be a private key dict with 'd', 'p' and 'q' entries, as
    produced by newkeys; a dict without 'p' is rejected with an Exception.
    """
    if 'p' in key:
        return chopstring(message, key['d'], key['p'] * key['q'], encrypt_int)
    raise Exception("You must use the private key with sign")
def decrypt(cypher, key):
    """Decrypt the string 'cypher' with a private key.

    'key' must be a private key dict with 'd', 'p' and 'q' entries; the
    modulus is rebuilt as p*q. A dict without 'p' is rejected.
    """
    if 'p' in key:
        return gluechops(cypher, key['d'], key['p'] * key['q'], decrypt_int)
    raise Exception("You must use the private key with decrypt")
def verify(cypher, key):
    """Recover the signed message from 'cypher' using a public key.

    'key' must be a public key dict with 'e' and 'n' entries; a dict
    without 'n' is rejected with an Exception.
    """
    if 'n' in key:
        return gluechops(cypher, key['e'], key['n'], decrypt_int)
    raise Exception("You must use the public key with verify")
# Do doctest if we're not imported: running this file as a script executes
# the examples embedded in this module's docstrings.
if __name__ == "__main__":
    import doctest
    doctest.testmod()
# Public API: the names exported by a star import of this module.
__all__ = [
    "newkeys",
    "encrypt",
    "decrypt",
    "sign",
    "verify",
]