Implemented ratchets

This commit is contained in:
Mark Qvist 2024-09-04 17:37:18 +02:00
parent c32086c6f1
commit a11f14e75f
6 changed files with 170 additions and 48 deletions

View file

@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2016-2023 Mark Qvist / unsigned.io and contributors.
# Copyright (c) 2016-2024 Mark Qvist / unsigned.io and contributors.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@ -50,7 +50,10 @@ class Identity:
KEYSIZE = 256*2
"""
X25519 key size in bits. A complete key is the concatenation of a 256 bit encryption key, and a 256 bit signing key.
"""
"""
RATCHETSIZE = 256
RATCHET_EXPIRY = 60*60*24*30
# Non-configurable constants
FERNET_OVERHEAD = RNS.Cryptography.Fernet.FERNET_OVERHEAD
@ -67,6 +70,7 @@ class Identity:
# Storage
known_destinations = {}
known_ratchets = {}
@staticmethod
def remember(packet_hash, destination_hash, public_key, app_data = None):
@ -222,29 +226,102 @@ class Identity:
return Identity.truncated_hash(os.urandom(Identity.TRUNCATED_HASHLENGTH//8))
@staticmethod
def generate_ratchet():
def _ratchet_public_bytes(ratchet):
return X25519PrivateKey.from_private_bytes(ratchet).public_key().public_bytes()
@staticmethod
def _generate_ratchet():
ratchet_prv = X25519PrivateKey.generate()
ratchet_pub = ratchet_prv.public_key()
return ratchet_prv.private_bytes()
@staticmethod
def ratchet_public_bytes(ratchet):
return X25519PrivateKey.from_private_bytes(ratchet).public_key().public_bytes()
def _remember_ratchet(destination_hash, ratchet):
RNS.log(f"Remembering ratchet {RNS.hexrep(ratchet)} for {RNS.prettyhexrep(destination_hash)}", RNS.LOG_DEBUG) # TODO: Remove
try:
Identity.known_ratchets[destination_hash] = ratchet
hexhash = RNS.hexrep(destination_hash, delimit=False)
ratchet_data = {"ratchet": ratchet, "received": time.time()}
ratchetdir = RNS.Reticulum.storagepath+"/ratchets"
if not os.path.isdir(ratchetdir):
os.makedirs(ratchetdir)
outpath = f"{ratchetdir}/{hexhash}.out"
finalpath = f"{ratchetdir}/{hexhash}"
ratchet_file = open(outpath, "wb")
ratchet_file.write(umsgpack.packb(ratchet_data))
ratchet_file.close()
os.rename(outpath, finalpath)
except Exception as e:
RNS.log(f"Could not persist ratchet for {RNS.prettyhexrep(destination_hash)} to storage.", RNS.LOG_ERROR)
RNS.log(f"The contained exception was: {e}")
RNS.trace_exception(e)
@staticmethod
def get_ratchet(destination_hash):
if not destination_hash in Identity.known_ratchets:
RNS.log(f"Trying to load ratchet for {RNS.prettyhexrep(destination_hash)} from storage") # TODO: Remove
ratchetdir = RNS.Reticulum.storagepath+"/ratchets"
hexhash = RNS.hexrep(destination_hash, delimit=False)
ratchet_path = f"{ratchetdir}/hexhash"
if os.path.isfile(ratchet_path):
try:
ratchet_file = open(ratchet_path, "rb")
ratchet_data = umsgpack.unpackb(ratchets_file.read())
if time.time() < ratchet_data["received"]+Identity.RATCHET_EXPIRY and len(ratchet_data["ratchet"]) == Identity.RATCHETSIZE//8:
Identity.known_ratchets[destination_hash] = ratchet_data["ratchet"]
else:
return None
except Exception as e:
RNS.log(f"An error occurred while loading ratchet data for {RNS.prettyhexrep(destination_hash)} from storage.", RNS.LOG_ERROR)
RNS.log(f"The contained exception was: {e}", RNS.LOG_ERROR)
return None
if destination_hash in Identity.known_ratchets:
return Identity.known_ratchets[destination_hash]
else:
return None
@staticmethod
def validate_announce(packet, only_validate_signature=False):
try:
if packet.packet_type == RNS.Packet.ANNOUNCE:
keysize = Identity.KEYSIZE//8
ratchetsize = Identity.RATCHETSIZE//8
name_hash_len = Identity.NAME_HASH_LENGTH//8
sig_len = Identity.SIGLENGTH//8
destination_hash = packet.destination_hash
public_key = packet.data[:Identity.KEYSIZE//8]
name_hash = packet.data[Identity.KEYSIZE//8:Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8]
random_hash = packet.data[Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8:Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10]
signature = packet.data[Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10:Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10+Identity.SIGLENGTH//8]
app_data = b""
if len(packet.data) > Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10+Identity.SIGLENGTH//8:
app_data = packet.data[Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10+Identity.SIGLENGTH//8:]
signed_data = destination_hash+public_key+name_hash+random_hash+app_data
# Get public key bytes from announce
public_key = packet.data[:keysize]
# If the packet context flag is set,
# this announce contains a new ratchet
if packet.context_flag == RNS.Packet.FLAG_SET:
name_hash = packet.data[keysize:keysize+name_hash_len ]
random_hash = packet.data[keysize+name_hash_len:keysize+name_hash_len+10]
ratchet = packet.data[keysize+name_hash_len+10:keysize+name_hash_len+10+ratchetsize]
signature = packet.data[keysize+name_hash_len+10+ratchetsize:keysize+name_hash_len+10+ratchetsize+sig_len]
app_data = b""
if len(packet.data) > keysize+name_hash_len+10+sig_len+ratchetsize:
app_data = packet.data[keysize+name_hash_len+10+sig_len+ratchetsize:]
# If the packet context flag is not set,
# this announce does not contain a ratchet
else:
ratchet = b""
name_hash = packet.data[keysize:keysize+name_hash_len]
random_hash = packet.data[keysize+name_hash_len:keysize+name_hash_len+10]
signature = packet.data[keysize+name_hash_len+10:keysize+name_hash_len+10+sig_len]
app_data = b""
if len(packet.data) > keysize+name_hash_len+10+sig_len:
app_data = packet.data[keysize+name_hash_len+10+sig_len:]
signed_data = destination_hash+public_key+name_hash+random_hash+ratchet+app_data
if not len(packet.data) > Identity.KEYSIZE//8+Identity.NAME_HASH_LENGTH//8+10+Identity.SIGLENGTH//8:
app_data = None
@ -291,6 +368,9 @@ class Identity:
else:
RNS.log("Valid announce for "+RNS.prettyhexrep(destination_hash)+" "+str(packet.hops)+" hops away, received on "+str(packet.receiving_interface)+signal_str, RNS.LOG_EXTREME)
if ratchet:
Identity._remember_ratchet(destination_hash, ratchet)
return True
else:
@ -479,7 +559,7 @@ class Identity:
def get_context(self):
return None
def encrypt(self, plaintext):
def encrypt(self, plaintext, ratchet=None):
"""
Encrypts information for the identity.
@ -491,7 +571,13 @@ class Identity:
ephemeral_key = X25519PrivateKey.generate()
ephemeral_pub_bytes = ephemeral_key.public_key().public_bytes()
shared_key = ephemeral_key.exchange(self.pub)
if ratchet != None:
RNS.log(f"Encrypting with ratchet {RNS.hexrep(ratchet)}", RNS.LOG_DEBUG) # TODO: Remove
target_public_key = X25519PublicKey.from_public_bytes(ratchet)
else:
target_public_key = self.pub
shared_key = ephemeral_key.exchange(target_public_key)
derived_key = RNS.Cryptography.hkdf(
length=32,
@ -509,7 +595,7 @@ class Identity:
raise KeyError("Encryption failed because identity does not hold a public key")
def decrypt(self, ciphertext_token):
def decrypt(self, ciphertext_token, ratchets=None):
"""
Decrypts information for the identity.
@ -523,19 +609,40 @@ class Identity:
try:
peer_pub_bytes = ciphertext_token[:Identity.KEYSIZE//8//2]
peer_pub = X25519PublicKey.from_public_bytes(peer_pub_bytes)
shared_key = self.prv.exchange(peer_pub)
derived_key = RNS.Cryptography.hkdf(
length=32,
derive_from=shared_key,
salt=self.get_salt(),
context=self.get_context(),
)
fernet = Fernet(derived_key)
ciphertext = ciphertext_token[Identity.KEYSIZE//8//2:]
plaintext = fernet.decrypt(ciphertext)
if ratchets:
for ratchet in ratchets:
try:
ratchet_prv = X25519PrivateKey.from_private_bytes(ratchet)
shared_key = ratchet_prv.exchange(peer_pub)
derived_key = RNS.Cryptography.hkdf(
length=32,
derive_from=shared_key,
salt=self.get_salt(),
context=self.get_context(),
)
fernet = Fernet(derived_key)
plaintext = fernet.decrypt(ciphertext)
RNS.log(f"Decrypted with ratchet {RNS.hexrep(ratchet_prv.public_key().public_bytes())}", RNS.LOG_DEBUG) # TODO: Remove
break
except Exception as e:
pass
# RNS.log("Decryption using this ratchet failed", RNS.LOG_DEBUG) # TODO: Remove
if plaintext == None:
shared_key = self.prv.exchange(peer_pub)
derived_key = RNS.Cryptography.hkdf(
length=32,
derive_from=shared_key,
salt=self.get_salt(),
context=self.get_context(),
)
fernet = Fernet(derived_key)
plaintext = fernet.decrypt(ciphertext)
except Exception as e:
RNS.log("Decryption by "+RNS.prettyhexrep(self.hash)+" failed: "+str(e), RNS.LOG_DEBUG)