Added base code for stamp generation and validation

This commit is contained in:
Mark Qvist 2024-09-06 16:49:01 +02:00
parent 4dca031441
commit fe14f8744d
2 changed files with 88 additions and 13 deletions

View File

@ -4,6 +4,8 @@ import RNS.vendor.umsgpack as msgpack
import os import os
import time import time
import base64 import base64
import signal
import multiprocessing
from .LXMF import APP_NAME from .LXMF import APP_NAME
@ -224,15 +226,40 @@ class LXMessage:
def register_failed_callback(self, callback): def register_failed_callback(self, callback):
self.failed_callback = callback self.failed_callback = callback
@staticmethod
def stamp_workblock(message_id):
wb_st = time.time()
expand_rounds = 3000
workblock = b""
for n in range(expand_rounds):
workblock += RNS.Cryptography.hkdf(
length=256,
derive_from=message_id,
salt=RNS.Identity.full_hash(message_id+msgpack.packb(n)),
context=None,
)
wb_time = time.time() - wb_st
RNS.log(f"Stamp workblock size {RNS.prettysize(len(workblock))}, generated in {round(wb_time*1000,2)}ms", RNS.LOG_DEBUG)
return workblock
@staticmethod
def stamp_valid(stamp, target_cost, workblock):
target = 0b1 << 256-target_cost
result = RNS.Identity.full_hash(workblock+stamp)
if int.from_bytes(result, byteorder="big") > target:
return False
else:
return True
def validate_stamp(self, target_cost): def validate_stamp(self, target_cost):
if self.stamp == None: if self.stamp == None:
return False return False
else: else:
target = 0b1 << 256-target_cost if LXMessage.stamp_valid(self.stamp, target_cost, LXMessage.stamp_workblock(self.message_id)):
if int.from_bytes(RNS.Identity.full_hash(self.message_id+self.stamp)) > target:
return False
else:
return True return True
else:
return False
def get_stamp(self, timeout=None): def get_stamp(self, timeout=None):
if self.stamp_cost == None: if self.stamp_cost == None:
@ -244,19 +271,60 @@ class LXMessage:
return self.stamp return self.stamp
else: else:
RNS.log(f"Generating stamp for {self}...", RNS.LOG_DEBUG) RNS.log(f"Generating stamp with cost {self.stamp_cost} for {self}...", RNS.LOG_DEBUG)
workblock = LXMessage.stamp_workblock(self.message_id)
start_time = time.time() start_time = time.time()
stamp = os.urandom(256//8); target = 0b1 << 256-self.stamp_cost; rounds = 1 total_rounds = 0
while int.from_bytes(RNS.Identity.full_hash(self.message_id+stamp)) > target:
if timeout != None and rounds % 10000 == 0: stop_event = multiprocessing.Event()
if time.time() > start_time + timeout: result_queue = multiprocessing.Queue(maxsize=1)
RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR) rounds_queue = multiprocessing.Queue()
return None def job(stop_event):
terminated = False
rounds = 0
stamp = os.urandom(256//8) stamp = os.urandom(256//8)
rounds += 1 while not LXMessage.stamp_valid(stamp, self.stamp_cost, workblock):
if stop_event.is_set():
break
if timeout != None and rounds % 10000 == 0:
if time.time() > start_time + timeout:
RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR)
return None
stamp = os.urandom(256//8)
rounds += 1
rounds_queue.put(rounds)
if not stop_event.is_set():
result_queue.put(stamp)
job_procs = []
jobs = multiprocessing.cpu_count()
for _ in range(jobs):
process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event},)
job_procs.append(process)
process.start()
stamp = result_queue.get()
stop_event.set()
for j in range(jobs):
process = job_procs[j]
process.join()
total_rounds += rounds_queue.get()
duration = time.time() - start_time
rounds = total_rounds
# TODO: Remove stats output
RNS.log(f"Stamp generated in {RNS.prettytime(duration)} / {rounds} rounds", RNS.LOG_DEBUG)
RNS.log(f"Rounds per second {int(rounds/duration)}", RNS.LOG_DEBUG)
RNS.log(f"Stamp: {RNS.hexrep(stamp)}", RNS.LOG_DEBUG)
RNS.log(f"Resulting hash: {RNS.hexrep(RNS.Identity.full_hash(workblock+stamp))}", RNS.LOG_DEBUG)
###########################
RNS.log(f"Stamp generated in {RNS.prettytime(time.time() - start_time)}", RNS.LOG_DEBUG)
return stamp return stamp
def pack(self): def pack(self):

View File

@ -13,6 +13,12 @@ def delivery_callback(message):
if message.unverified_reason == LXMF.LXMessage.SOURCE_UNKNOWN: if message.unverified_reason == LXMF.LXMessage.SOURCE_UNKNOWN:
signature_string = "Cannot verify, source is unknown" signature_string = "Cannot verify, source is unknown"
stamp_cost = 12
if message.validate_stamp(stamp_cost):
stamp_string = "Valid"
else:
stamp_string = "Not valid"
RNS.log("\t+--- LXMF Delivery ---------------------------------------------") RNS.log("\t+--- LXMF Delivery ---------------------------------------------")
RNS.log("\t| Source hash : "+RNS.prettyhexrep(message.source_hash)) RNS.log("\t| Source hash : "+RNS.prettyhexrep(message.source_hash))
RNS.log("\t| Source instance : "+str(message.get_source())) RNS.log("\t| Source instance : "+str(message.get_source()))
@ -24,6 +30,7 @@ def delivery_callback(message):
RNS.log("\t| Content : "+message.content_as_string()) RNS.log("\t| Content : "+message.content_as_string())
RNS.log("\t| Fields : "+str(message.fields)) RNS.log("\t| Fields : "+str(message.fields))
RNS.log("\t| Message signature : "+signature_string) RNS.log("\t| Message signature : "+signature_string)
RNS.log("\t| Stamp : "+stamp_string)
RNS.log("\t+---------------------------------------------------------------") RNS.log("\t+---------------------------------------------------------------")
r = RNS.Reticulum() r = RNS.Reticulum()