import hashlib
import hmac
import json
from .compat import constant_time_compare, string_types
from .exceptions import InvalidKeyError
from .utils import (
base64url_decode, base64url_encode, der_to_raw_signature,
force_bytes, force_unicode, from_base64url_uint, raw_to_der_signature,
to_base64url_uint
)
try:
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key, load_pem_public_key, load_ssh_public_key
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
RSAPrivateKey, RSAPublicKey, RSAPrivateNumbers, RSAPublicNumbers,
rsa_recover_prime_factors, rsa_crt_dmp1, rsa_crt_dmq1, rsa_crt_iqmp
)
from cryptography.hazmat.primitives.asymmetric.ec import (
EllipticCurvePrivateKey, EllipticCurvePublicKey
)
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.backends import default_backend
from cryptography.exceptions import InvalidSignature
has_crypto = True
except ImportError:
has_crypto = False
# Names of the algorithms that get_default_algorithms() only registers when
# the optional ``cryptography`` imports above succeeded (``has_crypto``).
requires_cryptography = {
    'RS256', 'RS384', 'RS512',
    'ES256', 'ES384', 'ES521', 'ES512',
    'PS256', 'PS384', 'PS512',
}
def get_default_algorithms():
    """
    Build a new mapping of algorithm name -> Algorithm instance for every
    algorithm implemented by the library.

    'none' and the HMAC family are always present. The RSA, ECDSA and
    RSA-PSS families are only added when ``has_crypto`` is true.
    """
    algorithms = {'none': NoneAlgorithm()}

    hmac_specs = (
        ('HS256', HMACAlgorithm.SHA256),
        ('HS384', HMACAlgorithm.SHA384),
        ('HS512', HMACAlgorithm.SHA512),
    )
    for name, hash_alg in hmac_specs:
        algorithms[name] = HMACAlgorithm(hash_alg)

    if not has_crypto:
        return algorithms

    crypto_specs = (
        ('RS256', RSAAlgorithm, RSAAlgorithm.SHA256),
        ('RS384', RSAAlgorithm, RSAAlgorithm.SHA384),
        ('RS512', RSAAlgorithm, RSAAlgorithm.SHA512),
        ('ES256', ECAlgorithm, ECAlgorithm.SHA256),
        ('ES384', ECAlgorithm, ECAlgorithm.SHA384),
        ('ES521', ECAlgorithm, ECAlgorithm.SHA512),
        # Backward compat for #219 fix: 'ES512' stays as an alias of ES521.
        ('ES512', ECAlgorithm, ECAlgorithm.SHA512),
        ('PS256', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA256),
        ('PS384', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA384),
        ('PS512', RSAPSSAlgorithm, RSAPSSAlgorithm.SHA512),
    )
    for name, alg_cls, hash_alg in crypto_specs:
        algorithms[name] = alg_cls(hash_alg)

    return algorithms
class Algorithm(object):
    """
    The interface for an algorithm used to sign and verify tokens.

    Concrete algorithms override the methods below. Every method on this
    base class raises ``NotImplementedError``.
    """

    def prepare_key(self, key):
        """
        Performs necessary validation and conversions on the key and returns
        the key value in the proper format for sign() and verify().
        """
        raise NotImplementedError

    def sign(self, msg, key):
        """
        Returns a digital signature for the specified message
        using the specified key value.
        """
        raise NotImplementedError

    def verify(self, msg, key, sig):
        """
        Verifies that the specified digital signature is valid
        for the specified message and key values.
        """
        raise NotImplementedError

    @staticmethod
    def to_jwk(key_obj):
        """
        Serializes a key object for this algorithm into a JWK
        (JSON Web Key) string.
        """
        raise NotImplementedError

    @staticmethod
    def from_jwk(jwk):
        """
        Deserializes a JWK (JSON Web Key) string back into a key value
        usable by this algorithm (for example a public or private key
        object, or a raw HMAC secret).
        """
        raise NotImplementedError
class NoneAlgorithm(Algorithm):
    """
    Algorithm used for ``alg = "none"``. It accepts no key, signing yields
    an empty signature, and verification always reports failure.
    """

    def prepare_key(self, key):
        # An empty string is accepted as a synonym for "no key".
        if key == '' or key is None:
            return None
        raise InvalidKeyError('When alg = "none", key value must be None.')

    def sign(self, msg, key):
        # Unsigned tokens carry an empty signature.
        return b''

    def verify(self, msg, key, sig):
        # No signature is ever considered valid for "none".
        return False
class HMACAlgorithm(Algorithm):
    """
    Performs signing and verification operations using HMAC
    and the specified hash function.
    """
    SHA256 = hashlib.sha256
    SHA384 = hashlib.sha384
    SHA512 = hashlib.sha512

    def __init__(self, hash_alg):
        # A hashlib constructor (e.g. hashlib.sha256), passed to hmac.new().
        self.hash_alg = hash_alg

    def prepare_key(self, key):
        """
        Return ``key`` as bytes suitable for use as an HMAC secret.

        Keys containing a PEM public-key or certificate header, or an
        OpenSSH public key type prefix, are rejected. Using a published
        asymmetric public key as an HMAC secret would let anyone who knows
        that public key forge HS* signatures (algorithm-confusion attack).

        :raises InvalidKeyError: if ``key`` looks like an asymmetric public
            key or an x509 certificate.
        """
        key = force_bytes(key)

        invalid_strings = (
            b'-----BEGIN PUBLIC KEY-----',
            b'-----BEGIN CERTIFICATE-----',
            b'-----BEGIN RSA PUBLIC KEY-----',
            # OpenSSH public key types. 'ssh-rsa' and 'ecdsa-sha2-' are the
            # SSH formats the RSA/EC prepare_key() implementations in this
            # module load; the remaining types are blocked for the same
            # reason.
            b'ssh-rsa',
            b'ecdsa-sha2-',
            b'ssh-ed25519',
            b'ssh-dss',
        )

        if any(marker in key for marker in invalid_strings):
            raise InvalidKeyError(
                'The specified key is an asymmetric key or x509 certificate and'
                ' should not be used as an HMAC secret.')

        return key

    @staticmethod
    def to_jwk(key_obj):
        """
        Serialize an HMAC secret into an ``oct`` JWK JSON string, with the
        base64url-encoded secret stored under ``k``.
        """
        return json.dumps({
            'k': force_unicode(base64url_encode(force_bytes(key_obj))),
            'kty': 'oct'
        })

    @staticmethod
    def from_jwk(jwk):
        """
        Parse an ``oct`` JWK JSON string and return the decoded ``k`` secret.

        :raises InvalidKeyError: if the JWK's ``kty`` is not ``oct``.
        """
        obj = json.loads(jwk)

        if obj.get('kty') != 'oct':
            raise InvalidKeyError('Not an HMAC key')

        return base64url_decode(obj['k'])

    def sign(self, msg, key):
        """Return the raw HMAC digest of ``msg`` under ``key``."""
        return hmac.new(key, msg, self.hash_alg).digest()

    def verify(self, msg, key, sig):
        """
        Recompute the HMAC of ``msg`` and compare it with ``sig`` using the
        constant_time_compare helper rather than a plain ``==``.
        """
        return constant_time_compare(sig, self.sign(msg, key))
if has_crypto:
    class RSAAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        RSASSA-PKCS-v1_5 and the specified hash function.
        """
        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            # A cryptography hash class (e.g. hashes.SHA256); a new instance
            # is created on every sign/verify call.
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            """
            Return ``key`` as an RSA key object usable by sign()/verify().

            RSA key objects are returned unchanged. Text or bytes input is
            loaded as an OpenSSH public key when it starts with ``ssh-rsa``,
            and otherwise as an unencrypted PEM private key. If either load
            raises ``ValueError``, the input is retried as a PEM public key.

            :raises TypeError: if ``key`` is neither an RSA key object nor a
                string.
            """
            if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
                return key

            if not isinstance(key, string_types):
                raise TypeError('Expecting a PEM-formatted key.')

            key = force_bytes(key)
            try:
                if key.startswith(b'ssh-rsa'):
                    return load_ssh_public_key(key, backend=default_backend())
                return load_pem_private_key(key, password=None,
                                            backend=default_backend())
            except ValueError:
                return load_pem_public_key(key, backend=default_backend())

        @staticmethod
        def to_jwk(key_obj):
            """
            Serialize an RSA key object into a JWK JSON string.

            Objects exposing ``private_numbers`` are exported as private keys
            (``key_ops: ['sign']``, including the CRT parameters). Otherwise,
            objects exposing ``verify`` are exported as public keys
            (``key_ops: ['verify']``, ``n`` and ``e`` only).

            :raises InvalidKeyError: if ``key_obj`` is neither.
            """
            def encode(value):
                # JWK integer members are base64url text (to_base64url_uint).
                return force_unicode(to_base64url_uint(value))

            if getattr(key_obj, 'private_numbers', None):
                # Private key
                numbers = key_obj.private_numbers()
                obj = {
                    'kty': 'RSA',
                    'key_ops': ['sign'],
                    'n': encode(numbers.public_numbers.n),
                    'e': encode(numbers.public_numbers.e),
                    'd': encode(numbers.d),
                    'p': encode(numbers.p),
                    'q': encode(numbers.q),
                    'dp': encode(numbers.dmp1),
                    'dq': encode(numbers.dmq1),
                    'qi': encode(numbers.iqmp),
                }
            elif getattr(key_obj, 'verify', None):
                # Public key
                numbers = key_obj.public_numbers()
                obj = {
                    'kty': 'RSA',
                    'key_ops': ['verify'],
                    'n': encode(numbers.n),
                    'e': encode(numbers.e),
                }
            else:
                raise InvalidKeyError('Not a public or private key')

            return json.dumps(obj)

        @staticmethod
        def from_jwk(jwk):
            """
            Deserialize a JWK JSON string into an RSA private or public key.

            A JWK containing ``d``, ``e`` and ``n`` is a private key. Its CRT
            members (``p``, ``q``, ``dp``, ``dq``, ``qi``) must be either all
            present or all absent; when absent they are recomputed from
            ``n``, ``e`` and ``d``. A JWK with only ``n`` and ``e`` is a
            public key.

            :raises InvalidKeyError: on invalid JSON, a non-RSA ``kty``, a
                multi-prime (``oth``) key, partial CRT members, or a JWK that
                is neither a public nor a private key.
            """
            try:
                obj = json.loads(jwk)
            except ValueError:
                raise InvalidKeyError('Key is not valid JSON')

            if obj.get('kty') != 'RSA':
                raise InvalidKeyError('Not an RSA key')

            if 'd' in obj and 'e' in obj and 'n' in obj:
                # Private key
                if 'oth' in obj:
                    raise InvalidKeyError('Unsupported RSA private key: > 2 primes not supported')

                other_props = ['p', 'q', 'dp', 'dq', 'qi']
                props_found = [prop in obj for prop in other_props]
                any_props_found = any(props_found)

                if any_props_found and not all(props_found):
                    raise InvalidKeyError('RSA key must include all parameters if any are present besides d')

                public_numbers = RSAPublicNumbers(
                    from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
                )

                if any_props_found:
                    numbers = RSAPrivateNumbers(
                        d=from_base64url_uint(obj['d']),
                        p=from_base64url_uint(obj['p']),
                        q=from_base64url_uint(obj['q']),
                        dmp1=from_base64url_uint(obj['dp']),
                        dmq1=from_base64url_uint(obj['dq']),
                        iqmp=from_base64url_uint(obj['qi']),
                        public_numbers=public_numbers
                    )
                else:
                    # Only d was supplied: recover p and q, then derive the
                    # CRT exponents and coefficient from them.
                    d = from_base64url_uint(obj['d'])
                    p, q = rsa_recover_prime_factors(
                        public_numbers.n, d, public_numbers.e
                    )
                    numbers = RSAPrivateNumbers(
                        d=d,
                        p=p,
                        q=q,
                        dmp1=rsa_crt_dmp1(d, p),
                        dmq1=rsa_crt_dmq1(d, q),
                        iqmp=rsa_crt_iqmp(p, q),
                        public_numbers=public_numbers
                    )

                return numbers.private_key(default_backend())

            if 'n' in obj and 'e' in obj:
                # Public key
                numbers = RSAPublicNumbers(
                    from_base64url_uint(obj['e']), from_base64url_uint(obj['n'])
                )
                return numbers.public_key(default_backend())

            raise InvalidKeyError('Not a public or private key')

        def sign(self, msg, key):
            """Sign ``msg`` with RSASSA-PKCS1-v1_5 using ``self.hash_alg``."""
            return key.sign(msg, padding.PKCS1v15(), self.hash_alg())

        def verify(self, msg, key, sig):
            """Return True if ``sig`` is a valid PKCS1-v1_5 signature of ``msg``."""
            try:
                key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
                return True
            except InvalidSignature:
                return False

    class ECAlgorithm(Algorithm):
        """
        Performs signing and verification operations using
        ECDSA and the specified hash function
        """
        SHA256 = hashes.SHA256
        SHA384 = hashes.SHA384
        SHA512 = hashes.SHA512

        def __init__(self, hash_alg):
            # A cryptography hash class (e.g. hashes.SHA256); a new instance
            # is created on every sign/verify call.
            self.hash_alg = hash_alg

        def prepare_key(self, key):
            """
            Return ``key`` as an elliptic-curve key object.

            EC key objects are returned unchanged. Text or bytes input is
            loaded as an OpenSSH public key when it starts with
            ``ecdsa-sha2-``, and otherwise as a PEM public key. If either
            load raises ``ValueError``, the input is retried as an
            unencrypted PEM private key.

            :raises TypeError: if ``key`` is neither an EC key object nor a
                string.
            """
            if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
                return key

            if not isinstance(key, string_types):
                raise TypeError('Expecting a PEM-formatted key.')

            key = force_bytes(key)

            # Attempt to load key. We don't know if it's
            # a Signing Key or a Verifying Key, so we try
            # the Verifying Key first.
            try:
                if key.startswith(b'ecdsa-sha2-'):
                    return load_ssh_public_key(key, backend=default_backend())
                return load_pem_public_key(key, backend=default_backend())
            except ValueError:
                return load_pem_private_key(key, password=None,
                                            backend=default_backend())

        def sign(self, msg, key):
            """
            Sign ``msg`` with ECDSA. The DER signature produced by
            ``key.sign`` is converted with der_to_raw_signature for the key's
            curve before being returned.
            """
            der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
            return der_to_raw_signature(der_sig, key.curve)

        def verify(self, msg, key, sig):
            """
            Return True if ``sig`` is a valid ECDSA signature of ``msg``.

            ``sig`` is first converted to DER with raw_to_der_signature; a
            ValueError from that conversion is reported as an invalid
            signature (False).
            """
            try:
                der_sig = raw_to_der_signature(sig, key.curve)
            except ValueError:
                return False

            try:
                key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
                return True
            except InvalidSignature:
                return False

    class RSAPSSAlgorithm(RSAAlgorithm):
        """
        Performs a signature using RSASSA-PSS with MGF1
        """

        def _pss_padding(self):
            # PSS with MGF1 over the same hash; salt length = digest size.
            return padding.PSS(
                mgf=padding.MGF1(self.hash_alg()),
                salt_length=self.hash_alg.digest_size
            )

        def sign(self, msg, key):
            """Sign ``msg`` with RSASSA-PSS using ``self.hash_alg``."""
            return key.sign(msg, self._pss_padding(), self.hash_alg())

        def verify(self, msg, key, sig):
            """Return True if ``sig`` is a valid RSASSA-PSS signature of ``msg``."""
            try:
                key.verify(sig, msg, self._pss_padding(), self.hash_alg())
                return True
            except InvalidSignature:
                return False