diff --git a/RNS/Cryptography/AES.py b/RNS/Cryptography/AES.py index 2689509..f3130b4 100644 --- a/RNS/Cryptography/AES.py +++ b/RNS/Cryptography/AES.py @@ -36,15 +36,48 @@ if cp.PROVIDER == cp.PROVIDER_INTERNAL: elif cp.PROVIDER == cp.PROVIDER_PYCA: from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - - if pu.cryptography_old_api(): - from cryptography.hazmat.backends import default_backend + if pu.cryptography_old_api(): from cryptography.hazmat.backends import default_backend class AES_128_CBC: - @staticmethod def encrypt(plaintext, key, iv): + if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}") + if cp.PROVIDER == cp.PROVIDER_INTERNAL: + cipher = AES(key) + return cipher.encrypt(plaintext, iv) + + elif cp.PROVIDER == cp.PROVIDER_PYCA: + if not pu.cryptography_old_api(): + cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + else: + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + + encryptor = cipher.encryptor() + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return ciphertext + + @staticmethod + def decrypt(ciphertext, key, iv): + if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}") + if cp.PROVIDER == cp.PROVIDER_INTERNAL: + cipher = AES(key) + return cipher.decrypt(ciphertext, iv) + + elif cp.PROVIDER == cp.PROVIDER_PYCA: + if not pu.cryptography_old_api(): + cipher = Cipher(algorithms.AES(key), modes.CBC(iv)) + else: + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) + + decryptor = cipher.decryptor() + plaintext = decryptor.update(ciphertext) + decryptor.finalize() + return plaintext + +class AES_256_CBC: + @staticmethod + def encrypt(plaintext, key, iv): + if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}") if cp.PROVIDER == cp.PROVIDER_INTERNAL: cipher = AES(key) return cipher.encrypt(plaintext, iv) @@ -61,6 +94,7 @@ class AES_128_CBC: @staticmethod def decrypt(ciphertext, key, iv): + if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}") if cp.PROVIDER == cp.PROVIDER_INTERNAL: cipher = AES(key) return cipher.decrypt(ciphertext, iv) diff --git a/RNS/Cryptography/Token.py b/RNS/Cryptography/Token.py index 6a84a3c..9434de5 100644 --- a/RNS/Cryptography/Token.py +++ b/RNS/Cryptography/Token.py @@ -33,7 +33,9 @@ import time from RNS.Cryptography import HMAC from RNS.Cryptography import PKCS7 +from RNS.Cryptography import AES from RNS.Cryptography.AES import AES_128_CBC +from RNS.Cryptography.AES import AES_256_CBC class Token(): """ @@ -48,45 +50,50 @@ class Token(): TOKEN_OVERHEAD = 48 # Bytes @staticmethod - def generate_key(): - return os.urandom(32) + def generate_key(mode=AES_128_CBC): + if mode == AES_128_CBC: return os.urandom(32) + elif mode == AES_256_CBC: return os.urandom(64) + else: raise TypeError(f"Invalid token mode: {mode}") - def __init__(self, key = None): - if key == None: - raise ValueError("Token key cannot be None") + def __init__(self, key=None, mode=AES): + if key == None: raise ValueError("Token key cannot be None") - if len(key) != 32: - raise ValueError("Token key must be 32 bytes, not "+str(len(key))) - - self._signing_key = key[:16] - self._encryption_key = key[16:] + if mode == AES: + if len(key) == 32: + self.mode = AES_128_CBC + self._signing_key = key[:16] + self._encryption_key = key[16:] + + elif len(key) == 64: + self.mode = AES_256_CBC + self._signing_key = key[:32] + self._encryption_key = key[32:] + + else: raise ValueError("Token key must be 128 or 256 bits, not "+str(len(key)*8)) + + else: raise TypeError(f"Invalid token mode: {mode}") def verify_hmac(self, token): - if len(token) <= 32: - raise ValueError("Cannot verify HMAC on token of only "+str(len(token))+" bytes") + if len(token) <= 32: raise ValueError("Cannot verify HMAC on token of only "+str(len(token))+" bytes") else: received_hmac = token[-32:] expected_hmac = HMAC.new(self._signing_key, token[:-32]).digest() - if received_hmac == expected_hmac: - return True - else: - return False + if received_hmac == expected_hmac: return True + else: return False def encrypt(self, data = None): iv = os.urandom(16) current_time = int(time.time()) - if not isinstance(data, bytes): - raise TypeError("Token plaintext input must be bytes") + if not isinstance(data, bytes): raise TypeError("Token plaintext input must be bytes") - ciphertext = AES_128_CBC.encrypt( + ciphertext = self.mode.encrypt( plaintext = PKCS7.pad(data), key = self._encryption_key, - iv = iv, - ) + iv = iv) signed_parts = iv+ciphertext @@ -94,25 +101,19 @@ class Token(): def decrypt(self, token = None): - if not isinstance(token, bytes): - raise TypeError("Token must be bytes") - - if not self.verify_hmac(token): - raise ValueError("Token HMAC was invalid") + if not isinstance(token, bytes): raise TypeError("Token must be bytes") + if not self.verify_hmac(token): raise ValueError("Token HMAC was invalid") iv = token[:16] ciphertext = token[16:-32] try: plaintext = PKCS7.unpad( - AES_128_CBC.decrypt( - ciphertext, - self._encryption_key, - iv, - ) - ) + self.mode.decrypt( + ciphertext = ciphertext, + key = self._encryption_key, + iv = iv)) return plaintext - except Exception as e: - raise ValueError("Could not decrypt token") \ No newline at end of file + except Exception as e: raise ValueError("Could not decrypt token") \ No newline at end of file diff --git a/RNS/Cryptography/aes/aes.py b/RNS/Cryptography/aes/aes.py index eabb20b..5244b20 100644 --- a/RNS/Cryptography/aes/aes.py +++ b/RNS/Cryptography/aes/aes.py @@ -22,6 +22,7 @@ from .utils import * +# TODO: Add AES-256 support to pure-python implementation class AES: # AES-128 block size diff --git a/RNS/Link.py b/RNS/Link.py index 864840e..69f7d52 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -100,39 +100,68 @@ class Link: and will be torn down. """ - PENDING = 0x00 - HANDSHAKE = 0x01 - ACTIVE = 0x02 - STALE = 0x03 - CLOSED = 0x04 + PENDING = 0x00 + HANDSHAKE = 0x01 + ACTIVE = 0x02 + STALE = 0x03 + CLOSED = 0x04 - TIMEOUT = 0x01 - INITIATOR_CLOSED = 0x02 - DESTINATION_CLOSED = 0x03 + TIMEOUT = 0x01 + INITIATOR_CLOSED = 0x02 + DESTINATION_CLOSED = 0x03 - ACCEPT_NONE = 0x00 - ACCEPT_APP = 0x01 - ACCEPT_ALL = 0x02 + ACCEPT_NONE = 0x00 + ACCEPT_APP = 0x01 + ACCEPT_ALL = 0x02 resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL] + MODE_AES128_CBC = 0x00 + MODE_AES256_CBC = 0x01 + MODE_AES256_GCM = 0x02 + MODE_OTP_RESERVED = 0x03 + MODE_PQ_RESERVED_1 = 0x04 + MODE_PQ_RESERVED_2 = 0x05 + MODE_PQ_RESERVED_3 = 0x06 + MODE_PQ_RESERVED_4 = 0x07 + enabled_modes = [MODE_AES128_CBC] + + MTU_BYTEMASK = 0x1FFFFF + MODE_BYTEMASK = 0xE0 + @staticmethod def mtu_bytes(mtu): - return struct.pack(">I", mtu & 0xFFFFFF)[1:] + return struct.pack(">I", mtu & Link.MTU_BYTEMASK)[1:] @staticmethod def mtu_from_lr_packet(packet): if len(packet.data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE: - return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) - else: - return None + return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) & Link.MTU_BYTEMASK + else: return None @staticmethod def mtu_from_lp_packet(packet): if len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE: mtu_bytes = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE] - return (mtu_bytes[0] << 16) + (mtu_bytes[1] << 8) + (mtu_bytes[2]) - else: - return None + return (mtu_bytes[0] << 16) + (mtu_bytes[1] << 8) + (mtu_bytes[2]) & Link.MTU_BYTEMASK + else: return None + + @staticmethod + def mode_byte(mode): + if mode in Link.enabled_modes: return (mode << 5) & Link.MODE_BYTEMASK + else: raise TypeError(f"Requested link mode {mode} not enabled") + + @staticmethod + def mode_from_lr_packet(packet): + if len(packet.data) > Link.ECPUBSIZE: + return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) & Link.MODE_BYTEMASK + else: return None + + @staticmethod + def mode_from_lp_packet(packet): + if len(packet.data) > RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2: + mode_byte = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+1] + return mode_byte & Link.MODE_BYTEMASK + else: return None @staticmethod def validate_request(owner, data, packet): @@ -177,9 +206,9 @@ class Link: return None - def __init__(self, destination=None, established_callback = None, closed_callback = None, owner=None, peer_pub_bytes = None, peer_sig_pub_bytes = None): - if destination != None and destination.type != RNS.Destination.SINGLE: - raise TypeError("Links can only be established to the \"single\" destination type") + def __init__(self, destination=None, established_callback=None, closed_callback=None, owner=None, peer_pub_bytes=None, peer_sig_pub_bytes=None, mode=MODE_AES128_CBC): + if destination != None and destination.type != RNS.Destination.SINGLE: raise TypeError("Links can only be established to the \"single\" destination type") + self.mode = mode self.rtt = None self.mtu = RNS.Reticulum.MTU self.establishment_cost = 0 @@ -299,14 +328,17 @@ class Link: self.status = Link.HANDSHAKE self.shared_key = self.prv.exchange(self.peer_pub) + if self.mode == Link.MODE_AES128_CBC: derived_key_length = 32 + elif self.mode == Link.MODE_AES256_CBC: derived_key_length = 64 + else: raise TypeError(f"Invalid link mode {self.mode} on {self}") + self.derived_key = RNS.Cryptography.hkdf( - length=32, + length=derived_key_length, derive_from=self.shared_key, salt=self.get_salt(), - context=self.get_context(), - ) - else: - RNS.log("Handshake attempt on "+str(self)+" with invalid state "+str(self.status), RNS.LOG_ERROR) + context=self.get_context()) + + else: RNS.log("Handshake attempt on "+str(self)+" with invalid state "+str(self.status), RNS.LOG_ERROR) def prove(self): @@ -1093,8 +1125,7 @@ class Link: def encrypt(self, plaintext): try: if not self.token: - try: - self.token = Token(self.derived_key) + try: self.token = Token(self.derived_key) except Exception as e: RNS.log("Could not instantiate token while performing encryption on link "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR) raise e @@ -1108,9 +1139,7 @@ class Link: def decrypt(self, ciphertext): try: - if not self.token: - self.token = Token(self.derived_key) - + if not self.token: self.token = Token(self.derived_key) return self.token.decrypt(ciphertext) except Exception as e: