daemon: Add support for key request handling.

This commit is contained in:
Damir Jelić 2019-08-07 13:00:19 +02:00
parent 20bfdce167
commit ddce830d8c
4 changed files with 211 additions and 1 deletions

View File

@ -39,6 +39,8 @@ from nio import (
RoomMessageText, RoomMessageText,
RoomNameEvent, RoomNameEvent,
RoomTopicEvent, RoomTopicEvent,
RoomKeyRequest,
RoomKeyRequestCancellation,
SyncResponse, SyncResponse,
) )
from nio.crypto import Sas from nio.crypto import Sas
@ -53,6 +55,9 @@ from pantalaimon.thread_messages import (
SasDoneSignal, SasDoneSignal,
ShowSasSignal, ShowSasSignal,
UpdateDevicesMessage, UpdateDevicesMessage,
KeyRequestMessage,
ContinueKeyShare,
CancelKeyShare,
) )
SEARCH_KEYS = ["content.body", "content.name", "content.topic"] SEARCH_KEYS = ["content.body", "content.name", "content.topic"]
@ -191,7 +196,14 @@ class PanClient(AsyncClient):
self.history_fetcher_task = None self.history_fetcher_task = None
self.history_fetch_queue = asyncio.Queue() self.history_fetch_queue = asyncio.Queue()
self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent) self.add_to_device_callback(
self.key_verification_cb,
KeyVerificationEvent
)
self.add_to_device_callback(
self.key_request_cb,
(RoomKeyRequest, RoomKeyRequestCancellation)
)
self.add_event_callback(self.undecrypted_event_cb, MegolmEvent) self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
if INDEXING_ENABLED: if INDEXING_ENABLED:
@ -392,6 +404,28 @@ class PanClient(AsyncClient):
except ClientConnectionError: except ClientConnectionError:
pass pass
async def key_request_cb(self, event):
if isinstance(event, RoomKeyRequest):
logger.info(
f"{event.sender} via {event.requesting_device_id} has "
f" requested room keys from us."
)
message = KeyRequestMessage(self.user_id, event)
await self.send_message(message)
elif isinstance(event, RoomKeyRequestCancellation):
logger.info(
f"{event.sender} via {event.requesting_device_id} has "
f" canceled its key request."
)
message = KeyRequestMessage(self.user_id, event)
await self.send_message(message)
else:
assert False
async def key_verification_cb(self, event): async def key_verification_cb(self, event):
logger.info("Received key verification event: {}".format(event)) logger.info("Received key verification event: {}".format(event))
if isinstance(event, KeyVerificationStart): if isinstance(event, KeyVerificationStart):
@ -610,6 +644,79 @@ class PanClient(AsyncClient):
) )
) )
async def handle_key_request_message(self, message):
if isinstance(message, ContinueKeyShare):
continued = False
for share in self.get_active_key_requests(
message.user_id,
message.device_id):
continued = True
if not self.continue_key_share(share):
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
"m.error",
(f"Unable to continue the key sharing for "
f"{message.user_id} via {message.device_id}: The "
f"device is still not verified.")
)
)
return
if continued:
try:
await self.send_to_device_messages()
except ClientConnectionError:
# We can safely ignore this since this will be retried
# after the next sync in the sync_forever method.
pass
response = (f"Succesfully continued the key requests from "
f"{message.user_id} via {message.device_id}")
ret = "m.ok"
else:
response = (f"No active key requests from {message.user_id} "
f"via {message.device_id} found.")
ret = "m.error"
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
ret,
response
)
)
elif isinstance(message, CancelKeyShare):
cancelled = False
for share in self.get_active_key_requests(
message.user_id,
message.device_id):
cancelled = self.cancel_key_share(share)
if cancelled:
response = (f"Succesfully cancelled key requests from "
f"{message.user_id} via {message.device_id}")
ret = "m.ok"
else:
response = (f"No active key requests from {message.user_id} "
f"via {message.device_id} found.")
ret = "m.error"
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
ret,
response
)
)
async def loop_stop(self): async def loop_stop(self):
"""Stop the client loop.""" """Stop the client loop."""
logger.info("Stopping the sync loop") logger.info("Stopping the sync loop")

View File

@ -56,6 +56,8 @@ from pantalaimon.thread_messages import (
UnverifiedDevicesSignal, UnverifiedDevicesSignal,
UnverifiedResponse, UnverifiedResponse,
UpdateUsersMessage, UpdateUsersMessage,
ContinueKeyShare,
CancelKeyShare,
) )
CORS_HEADERS = { CORS_HEADERS = {
@ -375,6 +377,10 @@ class ProxyDaemon:
queue = client.send_decision_queues[message.room_id] queue = client.send_decision_queues[message.room_id]
await queue.put(message) await queue.put(message)
elif isinstance(message, (ContinueKeyShare, CancelKeyShare)):
client = self.pan_clients[message.pan_user]
await client.handle_key_request_message(message)
def get_access_token(self, request): def get_access_token(self, request):
# type: (aiohttp.web.BaseRequest) -> str # type: (aiohttp.web.BaseRequest) -> str
"""Extract the access token from the request. """Extract the access token from the request.

View File

@ -44,6 +44,30 @@ class CancelSendingMessage(UnverifiedResponse):
pass pass
@attr.s
class KeyRequestMessage(Message):
pan_user = attr.ib(type=str)
event = attr.ib()
@attr.s
class _KeyShare(Message):
message_id = attr.ib()
pan_user = attr.ib()
user_id = attr.ib()
device_id = attr.ib()
@attr.s
class ContinueKeyShare(_KeyShare):
pass
@attr.s
class CancelKeyShare(_KeyShare):
pass
@attr.s @attr.s
class DaemonResponse(Message): class DaemonResponse(Message):
message_id = attr.ib() message_id = attr.ib()

View File

@ -29,6 +29,8 @@ if UI_ENABLED:
from pydbus import SessionBus from pydbus import SessionBus
from pydbus.generic import signal from pydbus.generic import signal
from nio import RoomKeyRequest, RoomKeyRequestCancellation
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.thread_messages import ( from pantalaimon.thread_messages import (
AcceptSasMessage, AcceptSasMessage,
@ -50,6 +52,9 @@ if UI_ENABLED:
UnverifiedDevicesSignal, UnverifiedDevicesSignal,
UpdateDevicesMessage, UpdateDevicesMessage,
UpdateUsersMessage, UpdateUsersMessage,
KeyRequestMessage,
ContinueKeyShare,
CancelKeyShare,
) )
UI_ENABLED = True UI_ENABLED = True
@ -257,6 +262,34 @@ if UI_ENABLED:
<arg direction="out" type="s" name="transaction_id"/> <arg direction="out" type="s" name="transaction_id"/>
</signal> </signal>
<method name='ContinueKeyShare'>
<arg type='s' name='pan_user' direction='in'/>
<arg type='s' name='user_id' direction='in'/>
<arg type='s' name='device_id' direction='in'/>
<arg type='u' name='id' direction='out'/>
</method>
<method name='CancelKeyShare'>
<arg type='s' name='pan_user' direction='in'/>
<arg type='s' name='user_id' direction='in'/>
<arg type='s' name='device_id' direction='in'/>
<arg type='u' name='id' direction='out'/>
</method>
<signal name="KeyRequest">
<arg direction="out" type="s" name="pan_user"/>
<arg direction="out" type="s" name="user_id"/>
<arg direction="out" type="s" name="device_id"/>
<arg direction="out" type="s" name="request_id"/>
</signal>
<signal name="KeyRequestCancel">
<arg direction="out" type="s" name="pan_user"/>
<arg direction="out" type="s" name="user_id"/>
<arg direction="out" type="s" name="device_id"/>
<arg direction="out" type="s" name="request_id"/>
</signal>
</interface> </interface>
</node> </node>
""" """
@ -266,11 +299,16 @@ if UI_ENABLED:
VerificationString = signal() VerificationString = signal()
VerificationDone = signal() VerificationDone = signal()
KeyRequest = signal()
KeyRequestCancel = signal()
def __init__(self, queue, id_counter): def __init__(self, queue, id_counter):
self.device_list = dict() self.device_list = dict()
self.queue = queue self.queue = queue
self.id_counter = id_counter self.id_counter = id_counter
self.key_requests = dict()
@property @property
def message_id(self): def message_id(self):
return self.id_counter.message_id return self.id_counter.message_id
@ -348,6 +386,16 @@ if UI_ENABLED:
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def ContinueKeyShare(self, pan_user, user_id, device_id):
message = ContinueKeyShare(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def CancelKeyShare(self, pan_user, user_id, device_id):
message = CancelKeyShare(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def update_devices(self, message): def update_devices(self, message):
if message.pan_user not in self.device_list: if message.pan_user not in self.device_list:
self.device_list[message.pan_user] = defaultdict(dict) self.device_list[message.pan_user] = defaultdict(dict)
@ -366,6 +414,28 @@ if UI_ENABLED:
device.pop("deleted") device.pop("deleted")
device_list[device["user_id"]][device["device_id"]] = device device_list[device["user_id"]][device["device_id"]] = device
def update_key_requests(self, message):
# type: (KeyRequestMessage) -> None
event = message.event
if isinstance(event, RoomKeyRequest):
self.key_requests[event.request_id] = event
self.KeyRequest(
message.pan_user,
event.sender,
event.requesting_device_id,
event.request_id,
)
elif isinstance(event, RoomKeyRequestCancellation):
self.key_requests.pop(event.request_id, None)
self.KeyRequestCancel(
message.pan_user,
event.sender,
event.requesting_device_id,
event.request_id,
)
@attr.s @attr.s
class GlibT: class GlibT:
receive_queue = attr.ib() receive_queue = attr.ib()
@ -555,6 +625,9 @@ if UI_ENABLED:
{"code": message.code, "message": message.message}, {"code": message.code, "message": message.message},
) )
elif isinstance(message, KeyRequestMessage):
self.device_if.update_key_requests(message)
self.receive_queue.task_done() self.receive_queue.task_done()
return True return True