daemon: Handle the case where we try to send to unverified devices.

This commit is contained in:
Damir Jelić 2019-05-21 10:25:59 +02:00
parent c40af38b33
commit ff2fc7e448
5 changed files with 106 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
from pprint import pformat from pprint import pformat
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from collections import defaultdict
from aiohttp.client_exceptions import ClientConnectionError from aiohttp.client_exceptions import ClientConnectionError
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse, from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
@ -37,6 +38,9 @@ class PanClient(AsyncClient):
self.task = None self.task = None
self.queue = queue self.queue = queue
self.send_semaphores = defaultdict(asyncio.Semaphore)
self.send_decision_queues = dict() # type: asyncio.Queue
self.add_to_device_callback( self.add_to_device_callback(
self.key_verification_cb, self.key_verification_cb,
KeyVerificationEvent KeyVerificationEvent

View File

@ -12,7 +12,7 @@ import keyring
from aiohttp import ClientSession, web from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from multidict import CIMultiDict from multidict import CIMultiDict
from nio import EncryptionError, LoginResponse, SendRetryError from nio import EncryptionError, LoginResponse, SendRetryError, OlmTrustError
from pantalaimon.client import PanClient from pantalaimon.client import PanClient
from pantalaimon.log import logger from pantalaimon.log import logger
@ -40,6 +40,7 @@ class ProxyDaemon:
ssl = attr.ib(default=None) ssl = attr.ib(default=None)
decryption_timeout = 10 decryption_timeout = 10
unverified_send_timeout = 10
store = attr.ib(type=PanStore, init=False) store = attr.ib(type=PanStore, init=False)
homeserver_url = attr.ib(init=False, default=attr.Factory(dict)) homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
@ -675,11 +676,16 @@ class ProxyDaemon:
room_id = request.match_info["room_id"] room_id = request.match_info["room_id"]
# The room is not in the joined rooms list, just forward it.
try: try:
encrypt = client.rooms[room_id].encrypted encrypt = client.rooms[room_id].encrypted
except KeyError: except KeyError:
return await self.forward_to_web(request) return await self.forward_to_web(
request,
token=client.access_token
)
# The room isn't encrypted just forward the message.
if not encrypt: if not encrypt:
return await self.forward_to_web( return await self.forward_to_web(
request, request,
@ -694,18 +700,74 @@ class ProxyDaemon:
except (JSONDecodeError, ContentTypeError): except (JSONDecodeError, ContentTypeError):
return self._not_json return self._not_json
async def _send(ignore_unverified=False):
try: try:
response = await client.room_send(room_id, msgtype, content, txnid) response = await client.room_send(room_id, msgtype, content,
except ClientConnectionError as e: txnid, ignore_unverified)
return web.Response(status=500, text=str(e))
except SendRetryError as e:
return web.Response(status=503, text=str(e))
return web.Response( return web.Response(
status=response.transport_response.status, status=response.transport_response.status,
content_type=response.transport_response.content_type, content_type=response.transport_response.content_type,
body=await response.transport_response.read() body=await response.transport_response.read()
) )
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except SendRetryError as e:
return web.Response(status=503, text=str(e))
# Aquire a semaphore here so we only send out one
# UnverifiedDevicesSignal
sem = client.send_semaphores[room_id]
async with sem:
try:
return await _send()
except OlmTrustError as e:
# There are unverified/unblocked devices in the room, notify
# the UI thread about this and wait for a response.
queue = asyncio.Queue()
message = UnverifiedDevicesSignal(client.user_id, room_id)
await self.send_queue.put(message)
# TODO allow dbus clients to answer us here.
try:
response = await asyncio.wait_for(
queue.get(),
self.unverified_send_timeout
)
if response == "cancel":
# The send was canceled notify the client that sent the
# request about this.
return web.Response(status=503, text=str(e))
elif response == "send-anyways":
# We are sending and ignoring devices along the way.
ret = await _send(True)
await self.send_update_devcies()
return ret
except asyncio.TimeoutError:
# We didn't get a response to our signal, send out an error
# response.
ret = await _send(True)
await self.send_update_devcies()
return ret
return web.Response(
status=503,
text=(f"Room contains unverified devices and no "
f"action was taken for "
f"{self.unverifiedsend_timeout} seconds, "
f"request timed out")
)
finally:
# Clear up the queue
pass
async def filter(self, request): async def filter(self, request):
access_token = self.get_access_token(request) access_token = self.get_access_token(request)

View File

@ -300,6 +300,11 @@ class PanCtl:
self.devices.VerificationInvite.connect(self.show_sas_invite) self.devices.VerificationInvite.connect(self.show_sas_invite)
self.devices.VerificationString.connect(self.show_sas) self.devices.VerificationString.connect(self.show_sas)
self.devices.VerificationDone.connect(self.sas_done) self.devices.VerificationDone.connect(self.sas_done)
self.devices.UnverifiedDevices.connect(self.unverified_devices)
def unverified_devices(self, pan_user, room_id):
print(f"Error sending message for user {pan_user}, "
f"there are unverified devices in the room {room_id}")
def show_response(self, response_id, pan_user, message): def show_response(self, response_id, pan_user, message):
if response_id not in self.own_message_ids: if response_id not in self.own_message_ids:

View File

@ -6,6 +6,12 @@ class Message:
pass pass
@attr.s
class UnverifiedDevicesSignal(Message):
pan_user = attr.ib()
room_id = attr.ib()
@attr.s @attr.s
class DaemonResponse(Message): class DaemonResponse(Message):
message_id = attr.ib() message_id = attr.ib()

View File

@ -18,7 +18,8 @@ from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
InviteSasSignal, SasDoneSignal, InviteSasSignal, SasDoneSignal,
ShowSasSignal, StartSasMessage, ShowSasSignal, StartSasMessage,
UpdateDevicesMessage, UpdateDevicesMessage,
UpdateUsersMessage) UpdateUsersMessage,
UnverifiedDevicesSignal)
class IdCounter: class IdCounter:
@ -60,11 +61,18 @@ class Control:
<arg direction="out" type="s" name="pan_user"/> <arg direction="out" type="s" name="pan_user"/>
<arg direction="out" type="a{ss}" name="message"/> <arg direction="out" type="a{ss}" name="message"/>
</signal> </signal>
<signal name="UnverifiedDevices">
<arg direction="out" type="s" name="pan_user"/>
<arg direction="out" type="s" name="room_id"/>
</signal>
</interface> </interface>
</node> </node>
""" """
Response = signal() Response = signal()
UnverifiedDevices = signal()
def __init__(self, queue, store, server_list, id_counter): def __init__(self, queue, store, server_list, id_counter):
self.server_list = server_list self.server_list = server_list
@ -380,9 +388,15 @@ class GlibT:
if isinstance(message, UpdateDevicesMessage): if isinstance(message, UpdateDevicesMessage):
self.device_if.update_devices() self.device_if.update_devices()
if isinstance(message, UpdateUsersMessage): elif isinstance(message, UpdateUsersMessage):
self.control_if.update_users() self.control_if.update_users()
elif isinstance(message, UnverifiedDevicesSignal):
self.control_if.UnverifiedDevices(
message.pan_user,
message.room_id
)
elif isinstance(message, InviteSasSignal): elif isinstance(message, InviteSasSignal):
self.device_if.VerificationInvite( self.device_if.VerificationInvite(
message.pan_user, message.pan_user,