Refactored Link method names

This commit is contained in:
Mark Qvist 2021-05-16 16:37:12 +02:00
parent d03b7d7a52
commit cd2f49272d
3 changed files with 16 additions and 16 deletions

View File

@ -180,14 +180,14 @@ class Destination:
plaintext = self.decrypt(packet.data) plaintext = self.decrypt(packet.data)
if plaintext != None: if plaintext != None:
if packet.packet_type == RNS.Packet.LINKREQUEST: if packet.packet_type == RNS.Packet.LINKREQUEST:
self.incomingLinkRequest(plaintext, packet) self.incoming_link_request(plaintext, packet)
if packet.packet_type == RNS.Packet.DATA: if packet.packet_type == RNS.Packet.DATA:
if self.callbacks.packet != None: if self.callbacks.packet != None:
self.callbacks.packet(plaintext, packet) self.callbacks.packet(plaintext, packet)
def incomingLinkRequest(self, data, packet): def incoming_link_request(self, data, packet):
link = RNS.Link.validateRequest(self, data, packet) link = RNS.Link.validate_request(self, data, packet)
if link != None: if link != None:
self.links.append(link) self.links.append(link)

View File

@ -53,11 +53,11 @@ class Link:
resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL] resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL]
@staticmethod @staticmethod
def validateRequest(owner, data, packet): def validate_request(owner, data, packet):
if len(data) == (Link.ECPUBSIZE): if len(data) == (Link.ECPUBSIZE):
try: try:
link = Link(owner = owner, peer_pub_bytes = data[:Link.ECPUBSIZE]) link = Link(owner = owner, peer_pub_bytes = data[:Link.ECPUBSIZE])
link.setLinkID(packet) link.set_link_id(packet)
link.destination = packet.destination link.destination = packet.destination
RNS.log("Validating link request "+RNS.prettyhexrep(link.link_id), RNS.LOG_VERBOSE) RNS.log("Validating link request "+RNS.prettyhexrep(link.link_id), RNS.LOG_VERBOSE)
link.handshake() link.handshake()
@ -127,13 +127,13 @@ class Link:
self.peer_pub = None self.peer_pub = None
self.peer_pub_bytes = None self.peer_pub_bytes = None
else: else:
self.loadPeer(peer_pub_bytes) self.load_peer(peer_pub_bytes)
if (self.initiator): if (self.initiator):
self.request_data = self.pub_bytes self.request_data = self.pub_bytes
self.packet = RNS.Packet(destination, self.request_data, packet_type=RNS.Packet.LINKREQUEST) self.packet = RNS.Packet(destination, self.request_data, packet_type=RNS.Packet.LINKREQUEST)
self.packet.pack() self.packet.pack()
self.setLinkID(self.packet) self.set_link_id(self.packet)
RNS.Transport.registerLink(self) RNS.Transport.registerLink(self)
self.request_time = time.time() self.request_time = time.time()
self.start_watchdog() self.start_watchdog()
@ -142,13 +142,13 @@ class Link:
RNS.log("Link request "+RNS.prettyhexrep(self.link_id)+" sent to "+str(self.destination), RNS.LOG_VERBOSE) RNS.log("Link request "+RNS.prettyhexrep(self.link_id)+" sent to "+str(self.destination), RNS.LOG_VERBOSE)
def loadPeer(self, peer_pub_bytes): def load_peer(self, peer_pub_bytes):
self.peer_pub_bytes = peer_pub_bytes self.peer_pub_bytes = peer_pub_bytes
self.peer_pub = serialization.load_der_public_key(peer_pub_bytes, backend=default_backend()) self.peer_pub = serialization.load_der_public_key(peer_pub_bytes, backend=default_backend())
if not hasattr(self.peer_pub, "curve"): if not hasattr(self.peer_pub, "curve"):
self.peer_pub.curve = Link.CURVE self.peer_pub.curve = Link.CURVE
def setLinkID(self, packet): def set_link_id(self, packet):
self.link_id = packet.getTruncatedHash() self.link_id = packet.getTruncatedHash()
self.hash = self.link_id self.hash = self.link_id
@ -158,8 +158,8 @@ class Link:
self.derived_key = HKDF( self.derived_key = HKDF(
algorithm=hashes.SHA256(), algorithm=hashes.SHA256(),
length=32, length=32,
salt=self.getSalt(), salt=self.get_salt(),
info=self.getContext(), info=self.get_context(),
backend=default_backend() backend=default_backend()
).derive(self.shared_key) ).derive(self.shared_key)
@ -185,14 +185,14 @@ class Link:
proof.send() proof.send()
self.had_outbound() self.had_outbound()
def validateProof(self, packet): def validate_proof(self, packet):
if self.initiator: if self.initiator:
peer_pub_bytes = packet.data[:Link.ECPUBSIZE] peer_pub_bytes = packet.data[:Link.ECPUBSIZE]
signed_data = self.link_id+peer_pub_bytes signed_data = self.link_id+peer_pub_bytes
signature = packet.data[Link.ECPUBSIZE:RNS.Identity.KEYSIZE//8+Link.ECPUBSIZE] signature = packet.data[Link.ECPUBSIZE:RNS.Identity.KEYSIZE//8+Link.ECPUBSIZE]
if self.destination.identity.validate(signature, signed_data): if self.destination.identity.validate(signature, signed_data):
self.loadPeer(peer_pub_bytes) self.load_peer(peer_pub_bytes)
self.handshake() self.handshake()
self.rtt = time.time() - self.request_time self.rtt = time.time() - self.request_time
self.attached_interface = packet.receiving_interface self.attached_interface = packet.receiving_interface
@ -236,10 +236,10 @@ class Link:
traceback.print_exc() traceback.print_exc()
self.teardown() self.teardown()
def getSalt(self): def get_salt(self):
return self.link_id return self.link_id
def getContext(self): def get_context(self):
return None return None
def no_inbound_for(self): def no_inbound_for(self):

View File

@ -828,7 +828,7 @@ class Transport:
# pending link # pending link
for link in Transport.pending_links: for link in Transport.pending_links:
if link.link_id == packet.destination_hash: if link.link_id == packet.destination_hash:
link.validateProof(packet) link.validate_proof(packet)
elif packet.context == RNS.Packet.RESOURCE_PRF: elif packet.context == RNS.Packet.RESOURCE_PRF:
for link in Transport.active_links: for link in Transport.active_links: