diff --git a/LXMF/LXMessage.py b/LXMF/LXMessage.py index c6eeb3f..a79f600 100644 --- a/LXMF/LXMessage.py +++ b/LXMF/LXMessage.py @@ -4,6 +4,8 @@ import RNS.vendor.umsgpack as msgpack import os import time import base64 +import signal +import multiprocessing from .LXMF import APP_NAME @@ -224,15 +226,40 @@ class LXMessage: def register_failed_callback(self, callback): self.failed_callback = callback + @staticmethod + def stamp_workblock(message_id): + wb_st = time.time() + expand_rounds = 3000 + workblock = b"" + for n in range(expand_rounds): + workblock += RNS.Cryptography.hkdf( + length=256, + derive_from=message_id, + salt=RNS.Identity.full_hash(message_id+msgpack.packb(n)), + context=None, + ) + wb_time = time.time() - wb_st + RNS.log(f"Stamp workblock size {RNS.prettysize(len(workblock))}, generated in {round(wb_time*1000,2)}ms", RNS.LOG_DEBUG) + + return workblock + + @staticmethod + def stamp_valid(stamp, target_cost, workblock): + target = 0b1 << 256-target_cost + result = RNS.Identity.full_hash(workblock+stamp) + if int.from_bytes(result, byteorder="big") > target: + return False + else: + return True + def validate_stamp(self, target_cost): if self.stamp == None: return False else: - target = 0b1 << 256-target_cost - if int.from_bytes(RNS.Identity.full_hash(self.message_id+self.stamp)) > target: - return False - else: + if LXMessage.stamp_valid(self.stamp, target_cost, LXMessage.stamp_workblock(self.message_id)): return True + else: + return False def get_stamp(self, timeout=None): if self.stamp_cost == None: @@ -244,19 +271,60 @@ class LXMessage: return self.stamp else: - RNS.log(f"Generating stamp for {self}...", RNS.LOG_DEBUG) + RNS.log(f"Generating stamp with cost {self.stamp_cost} for {self}...", RNS.LOG_DEBUG) + workblock = LXMessage.stamp_workblock(self.message_id) start_time = time.time() - stamp = os.urandom(256//8); target = 0b1 << 256-self.stamp_cost; rounds = 1 - while int.from_bytes(RNS.Identity.full_hash(self.message_id+stamp)) > target: - if timeout != None and rounds % 10000 == 0: - if time.time() > start_time + timeout: - RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR) - return None + total_rounds = 0 + + stop_event = multiprocessing.Event() + result_queue = multiprocessing.Queue(maxsize=1) + rounds_queue = multiprocessing.Queue() + def job(stop_event): + terminated = False + rounds = 0 stamp = os.urandom(256//8) - rounds += 1 + while not LXMessage.stamp_valid(stamp, self.stamp_cost, workblock): + if stop_event.is_set(): + break + + if timeout != None and rounds % 10000 == 0: + if time.time() > start_time + timeout: + RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR) + return None + + stamp = os.urandom(256//8) + rounds += 1 + + rounds_queue.put(rounds) + if not stop_event.is_set(): + result_queue.put(stamp) + + job_procs = [] + jobs = multiprocessing.cpu_count() + for _ in range(jobs): + process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event},) + job_procs.append(process) + process.start() + + stamp = result_queue.get() + stop_event.set() + + for j in range(jobs): + process = job_procs[j] + process.join() + total_rounds += rounds_queue.get() + + duration = time.time() - start_time + rounds = total_rounds + + # TODO: Remove stats output + RNS.log(f"Stamp generated in {RNS.prettytime(duration)} / {rounds} rounds", RNS.LOG_DEBUG) + RNS.log(f"Rounds per second {int(rounds/duration)}", RNS.LOG_DEBUG) + RNS.log(f"Stamp: {RNS.hexrep(stamp)}", RNS.LOG_DEBUG) + RNS.log(f"Resulting hash: {RNS.hexrep(RNS.Identity.full_hash(workblock+stamp))}", RNS.LOG_DEBUG) + ########################### - RNS.log(f"Stamp generated in {RNS.prettytime(time.time() - start_time)}", RNS.LOG_DEBUG) return stamp def pack(self): diff --git a/docs/example_receiver.py b/docs/example_receiver.py index 75c628e..e6ea117 100644 --- a/docs/example_receiver.py +++ b/docs/example_receiver.py @@ -13,6 +13,12 @@ def delivery_callback(message): if message.unverified_reason == LXMF.LXMessage.SOURCE_UNKNOWN: signature_string = "Cannot verify, source is unknown" + stamp_cost = 12 + if message.validate_stamp(stamp_cost): + stamp_string = "Valid" + else: + stamp_string = "Not valid" + RNS.log("\t+--- LXMF Delivery ---------------------------------------------") RNS.log("\t| Source hash : "+RNS.prettyhexrep(message.source_hash)) RNS.log("\t| Source instance : "+str(message.get_source())) @@ -24,6 +30,7 @@ def delivery_callback(message): RNS.log("\t| Content : "+message.content_as_string()) RNS.log("\t| Fields : "+str(message.fields)) RNS.log("\t| Message signature : "+signature_string) + RNS.log("\t| Stamp : "+stamp_string) RNS.log("\t+---------------------------------------------------------------") r = RNS.Reticulum()