From 40eb014c917c33927384c59ddd9b66533d839580 Mon Sep 17 00:00:00 2001 From: Mark Qvist Date: Sat, 7 Sep 2024 22:40:13 +0200 Subject: [PATCH] Implemented deferred multiprocessor stamp generation in the background --- LXMF/LXMRouter.py | 101 ++++++++++++++++++++---- LXMF/LXMessage.py | 190 +++++++++++++++++++++++++++------------------- 2 files changed, 200 insertions(+), 91 deletions(-) diff --git a/LXMF/LXMRouter.py b/LXMF/LXMRouter.py index 2ff5d75..9a82a13 100644 --- a/LXMF/LXMRouter.py +++ b/LXMF/LXMRouter.py @@ -95,6 +95,7 @@ class LXMRouter: self.delivery_per_transfer_limit = delivery_limit self.enforce_ratchets = enforce_ratchets self._enforce_stamps = enforce_stamps + self.pending_deferred_stamps = {} self.wants_download_on_path_available_from = None self.wants_download_on_path_available_to = None @@ -110,6 +111,7 @@ class LXMRouter: self.cost_file_lock = threading.Lock() self.ticket_file_lock = threading.Lock() + self.stamp_gen_lock = threading.Lock() if identity == None: identity = RNS.Identity() @@ -237,8 +239,7 @@ class LXMRouter: if display_name != None: def get_app_data(): - return self.get_announce_app_data(delivery_destination) - + return self.get_announce_app_data(delivery_destination.hash) delivery_destination.set_default_app_data(get_app_data) self.delivery_destinations[delivery_destination.hash] = delivery_destination @@ -540,6 +541,7 @@ class LXMRouter: ####################################################### JOB_OUTBOUND_INTERVAL = 1 + JOB_STAMPS_INTERVAL = 1 JOB_LINKS_INTERVAL = 1 JOB_TRANSIENT_INTERVAL = 60 JOB_STORE_INTERVAL = 120 @@ -550,6 +552,9 @@ class LXMRouter: if self.processing_count % LXMRouter.JOB_OUTBOUND_INTERVAL == 0: self.process_outbound() + if self.processing_count % LXMRouter.JOB_STAMPS_INTERVAL == 0: + threading.Thread(target=self.process_deferred_stamps, daemon=True).start() + if self.processing_count % LXMRouter.JOB_LINKS_INTERVAL == 0: self.clean_links() @@ -721,6 +726,14 @@ class LXMRouter: return None + def get_outbound_ticket_expiry(self, destination_hash): + if destination_hash in self.available_tickets["outbound"]: + entry = self.available_tickets["outbound"][destination_hash] + if entry[0] > time.time(): + return entry[0] + + return None + def get_inbound_tickets(self, destination_hash): now = time.time() available_tickets = [] @@ -916,6 +929,29 @@ class LXMRouter: except Exception as e: RNS.log("Could not save available tickets to storage. The contained exception was: "+str(e), RNS.LOG_ERROR) + def reload_available_tickets(self): + RNS.log("Reloading available tickets from storage", RNS.LOG_DEBUG) + try: + with self.ticket_file_lock: + with open(self.storagepath+"/available_tickets", "rb") as available_tickets_file: + data = available_tickets_file.read() + self.available_tickets = msgpack.unpackb(data) + if not type(self.available_tickets) == dict: + RNS.log("Invalid data format for loaded available tickets, recreating...", RNS.LOG_ERROR) + self.available_tickets = {"outbound": {}, "inbound": {}, "last_deliveries": {}} + if not "outbound" in self.available_tickets: + RNS.log("Missing outbound entry in loaded available tickets, recreating...", RNS.LOG_ERROR) + self.available_tickets["outbound"] = {} + if not "inbound" in self.available_tickets: + RNS.log("Missing inbound entry in loaded available tickets, recreating...", RNS.LOG_ERROR) + self.available_tickets["inbound"] = {} + if not "last_deliveries" in self.available_tickets: + RNS.log("Missing local_deliveries entry in loaded available tickets, recreating...", RNS.LOG_ERROR) + self.available_tickets["last_deliveries"] = {} + + except Exception as e: + RNS.log(f"An error occurred while reloading available tickets from storage: {e}", RNS.LOG_ERROR) + def exit_handler(self): if self.propagation_node: try: @@ -1188,19 +1224,24 @@ class LXMRouter: while self.processing_outbound: time.sleep(0.1) - self.pending_outbound.append(lxmessage) - if lxmessage.defer_stamp and lxmessage.stamp_cost == None: RNS.log(f"Deferred stamp generation was requested for {lxmessage}, but no stamp is required, processing immediately", RNS.LOG_DEBUG) lxmessage.defer_stamp = False if not lxmessage.defer_stamp: + self.pending_outbound.append(lxmessage) self.process_outbound() + else: + self.pending_deferred_stamps[lxmessage.message_id] = lxmessage def get_outbound_progress(self, lxm_hash): for lxm in self.pending_outbound: if lxm.hash == lxm_hash: return lxm.progress + + for lxm_id in self.pending_deferred_stamps: + if self.pending_deferred_stamps[lxm_id].hash == lxm_hash: + return self.pending_deferred_stamps[lxm_id].progress return None @@ -1208,6 +1249,10 @@ class LXMRouter: for lxm in self.pending_outbound: if lxm.hash == lxm_hash: return lxm.stamp_cost + + for lxm_id in self.pending_deferred_stamps: + if self.pending_deferred_stamps[lxm_id].hash == lxm_hash: + return self.pending_deferred_stamps[lxm_id].stamp_cost return None @@ -1616,13 +1661,51 @@ class LXMRouter: def fail_message(self, lxmessage): RNS.log(str(lxmessage)+" failed to send", RNS.LOG_DEBUG) - self.pending_outbound.remove(lxmessage) + if lxmessage in self.pending_outbound: + self.pending_outbound.remove(lxmessage) + self.failed_outbound.append(lxmessage) lxmessage.state = LXMessage.FAILED if lxmessage.failed_callback != None and callable(lxmessage.failed_callback): lxmessage.failed_callback(lxmessage) + def process_deferred_stamps(self): + if len(self.pending_deferred_stamps) > 0: + RNS.log(f"Processing deferred stamps...", RNS.LOG_DEBUG) # TODO: Remove + + if self.stamp_gen_lock.locked(): + RNS.log(f"A stamp is already generating, returning...", RNS.LOG_DEBUG) # TODO: Remove + return + + else: + with self.stamp_gen_lock: + selected_lxm = None + selected_message_id = None + for message_id in self.pending_deferred_stamps: + lxmessage = self.pending_deferred_stamps[message_id] + if selected_lxm == None: + selected_lxm = lxmessage + selected_message_id = message_id + + if selected_lxm != None: + RNS.log(f"Starting stamp generation for {selected_lxm}...", RNS.LOG_DEBUG) + generated_stamp = selected_lxm.get_stamp() + if generated_stamp: + selected_lxm.stamp = generated_stamp + selected_lxm.defer_stamp = False + selected_lxm.packed = None + selected_lxm.pack() + self.pending_deferred_stamps.pop(selected_message_id) + self.pending_outbound.append(selected_lxm) + RNS.log(f"Stamp generation completed for {selected_lxm}", RNS.LOG_DEBUG) + else: + RNS.log(f"Deferred stamp generation did not succeed. Failing {selected_lxm}.", RNS.LOG_ERROR) + selected_lxm.stamp_generation_failed = True + self.pending_deferred_stamps.pop(selected_message_id) + self.fail_message(selected_lxm) + + def process_outbound(self, sender = None): if self.processing_outbound: return @@ -1641,14 +1724,6 @@ class LXMRouter: self.pending_outbound.remove(lxmessage) else: RNS.log("Starting outbound processing for "+str(lxmessage)+" to "+RNS.prettyhexrep(lxmessage.get_destination().hash), RNS.LOG_DEBUG) - - # Handle potentially deferred stamp generation - if lxmessage.defer_stamp and lxmessage.stamp == None: - RNS.log(f"Generating deferred stamp for {lxmessage} now", RNS.LOG_DEBUG) - lxmessage.stamp = lxmessage.get_stamp() - lxmessage.defer_stamp = False - lxmessage.packed = None - lxmessage.pack() if lxmessage.progress == None or lxmessage.progress < 0.01: lxmessage.progress = 0.01 diff --git a/LXMF/LXMessage.py b/LXMF/LXMessage.py index 14ff299..86b637d 100644 --- a/LXMF/LXMessage.py +++ b/LXMF/LXMessage.py @@ -45,7 +45,7 @@ class LXMessage: TICKET_EXPIRY = 21*24*60*60 TICKET_GRACE = 5*24*60*60 TICKET_RENEW = 14*24*60*60 - TICKET_INTERVAL = 3*24*60*60 + TICKET_INTERVAL = 1*24*60*60 # LXMF overhead is 111 bytes per message: # 16 bytes for destination hash @@ -131,24 +131,24 @@ class LXMessage: self.set_content_from_string(content) self.set_fields(fields) - self.payload = None - self.timestamp = None - self.signature = None - self.hash = None - self.packed = None - self.state = LXMessage.GENERATING - self.method = LXMessage.UNKNOWN - self.progress = 0.0 - self.rssi = None - self.snr = None - self.q = None + self.payload = None + self.timestamp = None + self.signature = None + self.hash = None + self.packed = None + self.state = LXMessage.GENERATING + self.method = LXMessage.UNKNOWN + self.progress = 0.0 + self.rssi = None + self.snr = None + self.q = None - self.stamp = None - self.stamp_cost = stamp_cost - self.stamp_valid = False - self.defer_stamp = False - self.outbound_ticket = None - self.include_ticket = include_ticket + self.stamp = None + self.stamp_cost = stamp_cost + self.stamp_valid = False + self.defer_stamp = True + self.outbound_ticket = None + self.include_ticket = include_ticket self.propagation_packed = None self.paper_packed = None @@ -166,7 +166,9 @@ class LXMessage: self.resource_representation = None self.__delivery_destination = None self.__delivery_callback = None - self.failed_callback = None + self.failed_callback = None + + self.deferred_stamp_generating = False def set_title_from_string(self, title_string): self.title = title_string.encode("utf-8") @@ -312,50 +314,79 @@ class LXMessage: total_rounds = 0 if not RNS.vendor.platformutils.is_android(): - RNS.log("Preparing IPC semaphores", RNS.LOG_DEBUG) # TODO: Remove + mp_debug = True + + jobs = multiprocessing.cpu_count() stop_event = multiprocessing.Event() - result_queue = multiprocessing.Queue(maxsize=1) + result_queue = multiprocessing.Queue(1) rounds_queue = multiprocessing.Queue() - def job(stop_event): + + def job(stop_event, pn, sc, wb): terminated = False rounds = 0 + pstamp = os.urandom(256//8) - stamp = os.urandom(256//8) - while not LXMessage.stamp_valid(stamp, self.stamp_cost, workblock): - if stop_event.is_set(): - break + def sv(s, c, w): + target = 0b1<<256-c; m = w+s + result = RNS.Identity.full_hash(m) + if int.from_bytes(result, byteorder="big") > target: + return False + else: + return True - if timeout != None and rounds % 10000 == 0: - if time.time() > start_time + timeout: - RNS.log(f"Stamp generation for {self} timed out", RNS.LOG_ERROR) - return None + while not stop_event.is_set() and not sv(pstamp, sc, wb): + pstamp = os.urandom(256//8); rounds += 1 - stamp = os.urandom(256//8) - rounds += 1 - - rounds_queue.put(rounds) if not stop_event.is_set(): - result_queue.put(stamp) - + stop_event.set() + result_queue.put(pstamp) + rounds_queue.put(rounds) + job_procs = [] - jobs = multiprocessing.cpu_count() - RNS.log("Starting workers", RNS.LOG_DEBUG) # TODO: Remove - for _ in range(jobs): - process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event},) + RNS.log(f"Starting {jobs} workers", RNS.LOG_DEBUG) # TODO: Remove + for jpn in range(jobs): + process = multiprocessing.Process(target=job, kwargs={"stop_event": stop_event, "pn": jpn, "sc": self.stamp_cost, "wb": workblock},) job_procs.append(process) process.start() - RNS.log("Awaiting results on queue", RNS.LOG_DEBUG) # TODO: Remove stamp = result_queue.get() - stop_event.set() - - RNS.log("Joining worker processes", RNS.LOG_DEBUG) # TODO: Remove - for j in range(jobs): - process = job_procs[j] - process.join() - total_rounds += rounds_queue.get() - + RNS.log("Got stamp result from worker", RNS.LOG_DEBUG) # TODO: Remove duration = time.time() - start_time + + spurious_results = 0 + try: + while True: + result_queue.get_nowait() + spurious_results += 1 + except: + pass + + for j in range(jobs): + nrounds = 0 + try: + nrounds = rounds_queue.get(timeout=2) + except Exception as e: + RNS.log(f"Failed to get round stats part {j}: {e}", RNS.LOG_ERROR) # TODO: Remove + total_rounds += nrounds + + all_exited = False + exit_timeout = time.time() + 5 + while time.time() < exit_timeout: + if not any(p.is_alive() for p in job_procs): + all_exited = True + break + time.sleep(0.1) + + if not all_exited: + RNS.log("Stamp generation IPC timeout, possible worker deadlock", RNS.LOG_ERROR) + return None + + else: + for j in range(jobs): + process = job_procs[j] + process.join() + # RNS.log(f"Joined {j} / {process}", RNS.LOG_DEBUG) # TODO: Remove + rounds = total_rounds else: @@ -365,17 +396,21 @@ class LXMessage: # checking in on the progress. use_nacl = False - try: - import nacl.encoding - import nacl.hash - use_nacl = True - except: - pass + rounds_per_worker = 1000 + if RNS.vendor.platformutils.is_android(): + rounds_per_worker = 500 + try: + import nacl.encoding + import nacl.hash + use_nacl = True + except: + pass - def full_hash(m): - if use_nacl: + if use_nacl: + def full_hash(m): return nacl.hash.sha256(m, encoder=nacl.encoding.RawEncoder) - else: + else: + def full_hash(m): return RNS.Identity.full_hash(m) def sv(s, c, w): @@ -391,30 +426,35 @@ class LXMessage: wm = multiprocessing.Manager() jobs = multiprocessing.cpu_count() - # RNS.log(f"Dispatching {jobs} workers for stamp generation...") # TODO: Remove + RNS.log(f"Dispatching {jobs} workers for stamp generation...", RNS.LOG_DEBUG) # TODO: Remove results_dict = wm.dict() while stamp == None: job_procs = [] - def job(procnum=None, results_dict=None, wb=None): - # RNS.log(f"Worker {procnum} starting...") # TODO: Remove + def job(procnum=None, results_dict=None, wb=None, sc=None, jr=None): + RNS.log(f"Worker {procnum} starting for {jr} rounds...") # TODO: Remove rounds = 0 + found_stamp = None + found_time = None - stamp = os.urandom(256//8) - while not sv(stamp, self.stamp_cost, wb): - if rounds >= 500: - stamp = None + while True: + pstamp = os.urandom(256//8) + rounds += 1 + if sv(pstamp, sc, wb): + found_stamp = pstamp + found_time = time.time() + break + + if rounds >= jr: # RNS.log(f"Worker {procnum} found no result in {rounds} rounds") # TODO: Remove break - stamp = os.urandom(256//8) - rounds += 1 - - results_dict[procnum] = [stamp, rounds] + results_dict[procnum] = [found_stamp, rounds, found_time] for pnum in range(jobs): - process = multiprocessing.Process(target=job, kwargs={"procnum":pnum, "results_dict": results_dict, "wb": workblock},) + pargs = {"procnum":pnum, "results_dict": results_dict, "wb": workblock, "sc":self.stamp_cost, "jr":rounds_per_worker} + process = multiprocessing.Process(target=job, kwargs=pargs) job_procs.append(process) process.start() @@ -423,14 +463,13 @@ class LXMessage: for j in results_dict: r = results_dict[j] - # RNS.log(f"Result from {r}: {r[1]} rounds, stamp: {r[0]}") # TODO: Remove total_rounds += r[1] if r[0] != None: stamp = r[0] - # RNS.log(f"Found stamp: {stamp}") # TODO: Remove + found_time = r[2] if stamp == None: - elapsed = time.time() - start_time + elapsed = found_time - start_time speed = total_rounds/elapsed RNS.log(f"Stamp generation for {self} running. {total_rounds} rounds completed so far, {int(speed)} rounds per second", RNS.LOG_DEBUG) @@ -439,12 +478,7 @@ class LXMessage: speed = total_rounds/duration - # TODO: Remove stats output RNS.log(f"Stamp generated in {RNS.prettytime(duration)}, {rounds} rounds, {int(speed)} rounds per second", RNS.LOG_DEBUG) - # RNS.log(f"Rounds per second {int(rounds/duration)}", RNS.LOG_DEBUG) - # RNS.log(f"Stamp: {RNS.hexrep(stamp)}", RNS.LOG_DEBUG) - # RNS.log(f"Resulting hash: {RNS.hexrep(RNS.Identity.full_hash(workblock+stamp))}", RNS.LOG_DEBUG) - ########################### return stamp