Merge branch 'linkmodes'
Some checks are pending
Build Reticulum / test (push) Waiting to run
Build Reticulum / package (push) Blocked by required conditions
Build Reticulum / release (push) Blocked by required conditions

This commit is contained in:
Mark Qvist 2025-04-16 14:11:14 +02:00
commit 1dbb1a6a35
4 changed files with 134 additions and 69 deletions

View file

@ -36,15 +36,48 @@ if cp.PROVIDER == cp.PROVIDER_INTERNAL:
elif cp.PROVIDER == cp.PROVIDER_PYCA: elif cp.PROVIDER == cp.PROVIDER_PYCA:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
if pu.cryptography_old_api(): from cryptography.hazmat.backends import default_backend
if pu.cryptography_old_api():
from cryptography.hazmat.backends import default_backend
class AES_128_CBC: class AES_128_CBC:
@staticmethod @staticmethod
def encrypt(plaintext, key, iv): def encrypt(plaintext, key, iv):
if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}")
if cp.PROVIDER == cp.PROVIDER_INTERNAL:
cipher = AES(key)
return cipher.encrypt(plaintext, iv)
elif cp.PROVIDER == cp.PROVIDER_PYCA:
if not pu.cryptography_old_api():
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
else:
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return ciphertext
@staticmethod
def decrypt(ciphertext, key, iv):
if len(key) != 16: raise ValueError(f"Invalid key length {len(key)*8} for {self}")
if cp.PROVIDER == cp.PROVIDER_INTERNAL:
cipher = AES(key)
return cipher.decrypt(ciphertext, iv)
elif cp.PROVIDER == cp.PROVIDER_PYCA:
if not pu.cryptography_old_api():
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
else:
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
return plaintext
class AES_256_CBC:
@staticmethod
def encrypt(plaintext, key, iv):
if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}")
if cp.PROVIDER == cp.PROVIDER_INTERNAL: if cp.PROVIDER == cp.PROVIDER_INTERNAL:
cipher = AES(key) cipher = AES(key)
return cipher.encrypt(plaintext, iv) return cipher.encrypt(plaintext, iv)
@ -61,6 +94,7 @@ class AES_128_CBC:
@staticmethod @staticmethod
def decrypt(ciphertext, key, iv): def decrypt(ciphertext, key, iv):
if len(key) != 32: raise ValueError(f"Invalid key length {len(key)*8} for {self}")
if cp.PROVIDER == cp.PROVIDER_INTERNAL: if cp.PROVIDER == cp.PROVIDER_INTERNAL:
cipher = AES(key) cipher = AES(key)
return cipher.decrypt(ciphertext, iv) return cipher.decrypt(ciphertext, iv)

View file

@ -33,7 +33,9 @@ import time
from RNS.Cryptography import HMAC from RNS.Cryptography import HMAC
from RNS.Cryptography import PKCS7 from RNS.Cryptography import PKCS7
from RNS.Cryptography import AES
from RNS.Cryptography.AES import AES_128_CBC from RNS.Cryptography.AES import AES_128_CBC
from RNS.Cryptography.AES import AES_256_CBC
class Token(): class Token():
""" """
@ -48,45 +50,50 @@ class Token():
TOKEN_OVERHEAD = 48 # Bytes TOKEN_OVERHEAD = 48 # Bytes
@staticmethod @staticmethod
def generate_key(): def generate_key(mode=AES_128_CBC):
return os.urandom(32) if mode == AES_128_CBC: return os.urandom(32)
elif mode == AES_256_CBC: return os.urandom(64)
else: raise TypeError(f"Invalid token mode: {mode}")
def __init__(self, key = None): def __init__(self, key=None, mode=AES):
if key == None: if key == None: raise ValueError("Token key cannot be None")
raise ValueError("Token key cannot be None")
if len(key) != 32: if mode == AES:
raise ValueError("Token key must be 32 bytes, not "+str(len(key))) if len(key) == 32:
self.mode = AES_128_CBC
self._signing_key = key[:16] self._signing_key = key[:16]
self._encryption_key = key[16:] self._encryption_key = key[16:]
elif len(key) == 64:
self.mode = AES_256_CBC
self._signing_key = key[:32]
self._encryption_key = key[32:]
else: raise ValueError("Token key must be 128 or 256 bits, not "+str(len(key)*8))
else: raise TypeError(f"Invalid token mode: {mode}")
def verify_hmac(self, token): def verify_hmac(self, token):
if len(token) <= 32: if len(token) <= 32: raise ValueError("Cannot verify HMAC on token of only "+str(len(token))+" bytes")
raise ValueError("Cannot verify HMAC on token of only "+str(len(token))+" bytes")
else: else:
received_hmac = token[-32:] received_hmac = token[-32:]
expected_hmac = HMAC.new(self._signing_key, token[:-32]).digest() expected_hmac = HMAC.new(self._signing_key, token[:-32]).digest()
if received_hmac == expected_hmac: if received_hmac == expected_hmac: return True
return True else: return False
else:
return False
def encrypt(self, data = None): def encrypt(self, data = None):
iv = os.urandom(16) iv = os.urandom(16)
current_time = int(time.time()) current_time = int(time.time())
if not isinstance(data, bytes): if not isinstance(data, bytes): raise TypeError("Token plaintext input must be bytes")
raise TypeError("Token plaintext input must be bytes")
ciphertext = AES_128_CBC.encrypt( ciphertext = self.mode.encrypt(
plaintext = PKCS7.pad(data), plaintext = PKCS7.pad(data),
key = self._encryption_key, key = self._encryption_key,
iv = iv, iv = iv)
)
signed_parts = iv+ciphertext signed_parts = iv+ciphertext
@ -94,25 +101,19 @@ class Token():
def decrypt(self, token = None): def decrypt(self, token = None):
if not isinstance(token, bytes): if not isinstance(token, bytes): raise TypeError("Token must be bytes")
raise TypeError("Token must be bytes") if not self.verify_hmac(token): raise ValueError("Token HMAC was invalid")
if not self.verify_hmac(token):
raise ValueError("Token HMAC was invalid")
iv = token[:16] iv = token[:16]
ciphertext = token[16:-32] ciphertext = token[16:-32]
try: try:
plaintext = PKCS7.unpad( plaintext = PKCS7.unpad(
AES_128_CBC.decrypt( self.mode.decrypt(
ciphertext, ciphertext = ciphertext,
self._encryption_key, key = self._encryption_key,
iv, iv = iv))
)
)
return plaintext return plaintext
except Exception as e: except Exception as e: raise ValueError("Could not decrypt token")
raise ValueError("Could not decrypt token")

View file

@ -22,6 +22,7 @@
from .utils import * from .utils import *
# TODO: Add AES-256 support to pure-python implementation
class AES: class AES:
# AES-128 block size # AES-128 block size

View file

@ -100,39 +100,68 @@ class Link:
and will be torn down. and will be torn down.
""" """
PENDING = 0x00 PENDING = 0x00
HANDSHAKE = 0x01 HANDSHAKE = 0x01
ACTIVE = 0x02 ACTIVE = 0x02
STALE = 0x03 STALE = 0x03
CLOSED = 0x04 CLOSED = 0x04
TIMEOUT = 0x01 TIMEOUT = 0x01
INITIATOR_CLOSED = 0x02 INITIATOR_CLOSED = 0x02
DESTINATION_CLOSED = 0x03 DESTINATION_CLOSED = 0x03
ACCEPT_NONE = 0x00 ACCEPT_NONE = 0x00
ACCEPT_APP = 0x01 ACCEPT_APP = 0x01
ACCEPT_ALL = 0x02 ACCEPT_ALL = 0x02
resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL] resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL]
MODE_AES128_CBC = 0x00
MODE_AES256_CBC = 0x01
MODE_AES256_GCM = 0x02
MODE_OTP_RESERVED = 0x03
MODE_PQ_RESERVED_1 = 0x04
MODE_PQ_RESERVED_2 = 0x05
MODE_PQ_RESERVED_3 = 0x06
MODE_PQ_RESERVED_4 = 0x07
enabled_modes = [MODE_AES128_CBC]
MTU_BYTEMASK = 0x1FFFFF
MODE_BYTEMASK = 0xE0
@staticmethod @staticmethod
def mtu_bytes(mtu): def mtu_bytes(mtu):
return struct.pack(">I", mtu & 0xFFFFFF)[1:] return struct.pack(">I", mtu & Link.MTU_BYTEMASK)[1:]
@staticmethod @staticmethod
def mtu_from_lr_packet(packet): def mtu_from_lr_packet(packet):
if len(packet.data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE: if len(packet.data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE:
return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) & Link.MTU_BYTEMASK
else: else: return None
return None
@staticmethod @staticmethod
def mtu_from_lp_packet(packet): def mtu_from_lp_packet(packet):
if len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE: if len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE:
mtu_bytes = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE] mtu_bytes = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE]
return (mtu_bytes[0] << 16) + (mtu_bytes[1] << 8) + (mtu_bytes[2]) return (mtu_bytes[0] << 16) + (mtu_bytes[1] << 8) + (mtu_bytes[2]) & Link.MTU_BYTEMASK
else: else: return None
return None
@staticmethod
def mode_byte(mode):
if mode in Link.enabled_modes: return (mode << 5) & Link.MODE_BYTEMASK
else: raise TypeError(f"Requested link mode {mode} not enabled")
@staticmethod
def mode_from_lr_packet(packet):
if len(packet.data) > Link.ECPUBSIZE:
return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) & Link.MODE_BYTEMASK
else: return None
@staticmethod
def mode_from_lp_packet(packet):
if len(packet.data) > RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:
mode_byte = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+1]
return mode_byte & Link.MODE_BYTEMASK
else: return None
@staticmethod @staticmethod
def validate_request(owner, data, packet): def validate_request(owner, data, packet):
@ -177,9 +206,9 @@ class Link:
return None return None
def __init__(self, destination=None, established_callback = None, closed_callback = None, owner=None, peer_pub_bytes = None, peer_sig_pub_bytes = None): def __init__(self, destination=None, established_callback=None, closed_callback=None, owner=None, peer_pub_bytes=None, peer_sig_pub_bytes=None, mode=MODE_AES128_CBC):
if destination != None and destination.type != RNS.Destination.SINGLE: if destination != None and destination.type != RNS.Destination.SINGLE: raise TypeError("Links can only be established to the \"single\" destination type")
raise TypeError("Links can only be established to the \"single\" destination type") self.mode = mode
self.rtt = None self.rtt = None
self.mtu = RNS.Reticulum.MTU self.mtu = RNS.Reticulum.MTU
self.establishment_cost = 0 self.establishment_cost = 0
@ -299,14 +328,17 @@ class Link:
self.status = Link.HANDSHAKE self.status = Link.HANDSHAKE
self.shared_key = self.prv.exchange(self.peer_pub) self.shared_key = self.prv.exchange(self.peer_pub)
if self.mode == Link.MODE_AES128_CBC: derived_key_length = 32
elif self.mode == Link.MODE_AES256_CBC: derived_key_length = 64
else: raise TypeError(f"Invalid link mode {self.mode} on {self}")
self.derived_key = RNS.Cryptography.hkdf( self.derived_key = RNS.Cryptography.hkdf(
length=32, length=derived_key_length,
derive_from=self.shared_key, derive_from=self.shared_key,
salt=self.get_salt(), salt=self.get_salt(),
context=self.get_context(), context=self.get_context())
)
else: else: RNS.log("Handshake attempt on "+str(self)+" with invalid state "+str(self.status), RNS.LOG_ERROR)
RNS.log("Handshake attempt on "+str(self)+" with invalid state "+str(self.status), RNS.LOG_ERROR)
def prove(self): def prove(self):
@ -1093,8 +1125,7 @@ class Link:
def encrypt(self, plaintext): def encrypt(self, plaintext):
try: try:
if not self.token: if not self.token:
try: try: self.token = Token(self.derived_key)
self.token = Token(self.derived_key)
except Exception as e: except Exception as e:
RNS.log("Could not instantiate token while performing encryption on link "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR) RNS.log("Could not instantiate token while performing encryption on link "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR)
raise e raise e
@ -1108,9 +1139,7 @@ class Link:
def decrypt(self, ciphertext): def decrypt(self, ciphertext):
try: try:
if not self.token: if not self.token: self.token = Token(self.derived_key)
self.token = Token(self.derived_key)
return self.token.decrypt(ciphertext) return self.token.decrypt(ciphertext)
except Exception as e: except Exception as e: