Stamp validation on incoming propagation node transfers

This commit is contained in:
Mark Qvist 2025-10-31 02:19:24 +01:00
parent ebc8bb33c2
commit 9c646aead7
2 changed files with 48 additions and 52 deletions

View file

@ -2000,10 +2000,13 @@ class LXMRouter:
#######################################
# TODO: Check propagation stamps here #
#######################################
stamps_valid = False
target_cost = max(0, self.propagation_stamp_cost-self.propagation_stamp_cost_flexibility)
validated_messages = LXStamper.validate_pn_stamps(messages, target_cost)
for lxmf_data in messages:
self.lxmf_propagation(lxmf_data)
for validated_entry in validated_messages:
lxmf_data = validated_entry[1]
stamp_value = validated_entry[2]
self.lxmf_propagation(lxmf_data, stamp_value=stamp_value)
self.client_propagation_messages_received += 1
if stamps_valid: packet.prove()
@ -2093,10 +2096,14 @@ class LXMRouter:
#######################################
# TODO: Check propagation stamps here #
#######################################
target_cost = max(0, self.propagation_stamp_cost-self.propagation_stamp_cost_flexibility)
validated_messages = LXStamper.validate_pn_stamps(messages, target_cost)
for lxmf_data in messages:
peer = None
transient_id = RNS.Identity.full_hash(lxmf_data)
for validated_entry in validated_messages:
transient_id = validated_entry[0]
lxmf_data = validated_entry[1]
stamp_value = validated_entry[2]
peer = None
if remote_hash != None and remote_hash in self.peers:
peer = self.peers[remote_hash]
@ -2109,7 +2116,7 @@ class LXMRouter:
else:
self.client_propagation_messages_received += 1
self.lxmf_propagation(lxmf_data, from_peer=peer)
self.lxmf_propagation(lxmf_data, from_peer=peer, stamp_value=stamp_value)
if peer != None: peer.queue_handled_message(transient_id)
else:

View file

@ -4,6 +4,7 @@ import RNS.vendor.umsgpack as msgpack
import os
import time
import math
import itertools
import multiprocessing
WORKBLOCK_EXPAND_ROUNDS = 3000
@ -43,46 +44,42 @@ def stamp_valid(stamp, target_cost, workblock):
if int.from_bytes(result, byteorder="big") > target: return False
else: return True
def validate_pn_stamp(transient_id, stamp):
target_cost = 8
workblock = stamp_workblock(transient_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)
if stamp_valid(stamp, target_cost, workblock):
RNS.log(f"Stamp on {RNS.prettyhexrep(transient_id)} validated", RNS.LOG_DEBUG)
value = stamp_value(workblock, stamp)
return True
return False
def validate_pn_stamps_job_simple(transient_stamps):
for entry in transient_stamps:
# Get transient ID and stamp for validation
transient_id = transient_stamps[0]
stamp = transient_stamps[1]
def validate_pn_stamp(transient_data, target_cost):
from .LXMessage import LXMessage
if len(transient_data) <= LXMessage.LXMF_OVERHEAD+STAMP_SIZE: return False, None, None
else:
lxm_data = transient_data[:-STAMP_SIZE]
stamp = transient_data[-STAMP_SIZE:]
transient_id = RNS.Identity.full_hash(lxm_data)
workblock = stamp_workblock(transient_id, expand_rounds=WORKBLOCK_EXPAND_ROUNDS_PN)
# Store validation result back into list
transient_stamps[2] = validate_pn_stamp(transient_id, stamp)
if not stamp_valid(stamp, target_cost, workblock): return False, None, None
else:
value = stamp_value(workblock, stamp)
return True, transient_id, value
return transient_stamps
def validate_pn_stamps_job_simple(transient_list, target_cost):
validated_messages = []
for transient_data in transient_list:
stamp_valid, transient_id, value = validate_pn_stamp(transient_data, target_cost)
if stamp_valid: validated_messages.append([transient_id, transient_data, value])
def _validate_single_pn_stamp_entry(entry):
transient_id = entry[0]
stamp = entry[1]
entry[2] = validate_pn_stamp(transient_id, stamp)
return entry
return validated_messages
def validate_pn_stamps_job_multip(transient_stamps):
def validate_pn_stamps_job_multip(transient_list, target_cost):
cores = multiprocessing.cpu_count()
pool_count = min(cores, math.ceil(len(transient_stamps) / PN_VALIDATION_POOL_MIN_SIZE))
pool_count = min(cores, math.ceil(len(transient_list) / PN_VALIDATION_POOL_MIN_SIZE))
RNS.log(f"Validating {len(transient_stamps)} stamps using {pool_count} processes...", RNS.LOG_VERBOSE)
with multiprocessing.Pool(pool_count) as p: validated_entries = p.map(_validate_single_pn_stamp_entry, transient_stamps)
RNS.log(f"Validating {len(transient_list)} stamps using {pool_count} processes...", RNS.LOG_VERBOSE)
with multiprocessing.Pool(pool_count) as p:
validated_entries = p.starmap(validate_pn_stamp, zip(transient_list, itertools.repeat(target_cost)))
return validated_entries
return [e for e in validated_entries if e[0] == True]
def validate_pn_stamps(transient_stamps):
def validate_pn_stamps(transient_list, target_cost):
non_mp_platform = RNS.vendor.platformutils.is_android()
if len(transient_stamps) <= PN_VALIDATION_POOL_MIN_SIZE or non_mp_platform: validate_pn_stamps_job_simple(transient_stamps)
else: validate_pn_stamps_job_multip(transient_stamps)
if len(transient_list) <= PN_VALIDATION_POOL_MIN_SIZE or non_mp_platform: return validate_pn_stamps_job_simple(transient_list, target_cost)
else: return validate_pn_stamps_job_multip(transient_list, target_cost)
def generate_stamp(message_id, stamp_cost):
RNS.log(f"Generating stamp with cost {stamp_cost} for {RNS.prettyhexrep(message_id)}...", RNS.LOG_DEBUG)
@ -93,19 +90,13 @@ def generate_stamp(message_id, stamp_cost):
rounds = 0
value = 0
if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin():
stamp, rounds = job_simple(stamp_cost, workblock, message_id)
elif RNS.vendor.platformutils.is_android():
stamp, rounds = job_android(stamp_cost, workblock, message_id)
else:
stamp, rounds = job_linux(stamp_cost, workblock, message_id)
if RNS.vendor.platformutils.is_windows() or RNS.vendor.platformutils.is_darwin(): stamp, rounds = job_simple(stamp_cost, workblock, message_id)
elif RNS.vendor.platformutils.is_android(): stamp, rounds = job_android(stamp_cost, workblock, message_id)
else: stamp, rounds = job_linux(stamp_cost, workblock, message_id)
duration = time.time() - start_time
speed = rounds/duration
if stamp != None:
value = stamp_value(workblock, stamp)
if stamp != None: value = stamp_value(workblock, stamp)
RNS.log(f"Stamp with value {value} generated in {RNS.prettytime(duration)}, {rounds} rounds, {int(speed)} rounds per second", RNS.LOG_DEBUG)
@ -161,10 +152,8 @@ def job_simple(stamp_cost, workblock, message_id):
def sv(s, c, w):
target = 0b1<<256-c; m = w+s
result = RNS.Identity.full_hash(m)
if int.from_bytes(result, byteorder="big") > target:
return False
else:
return True
if int.from_bytes(result, byteorder="big") > target: return False
else: return True
while not sv(pstamp, stamp_cost, workblock) and not active_jobs[message_id]:
pstamp = os.urandom(256//8); rounds += 1