pantalaimon: Format the source tree using black.

This commit is contained in:
Damir Jelić 2019-06-19 12:37:44 +02:00
parent 531d686d8f
commit c9ebfd71ec
11 changed files with 755 additions and 969 deletions

4
.flake8 Normal file
View File

@ -0,0 +1,4 @@
[flake8]
max-line-length = 80
select = C,E,F,W,B,B950
ignore = E501,W503

6
.isort.cfg Normal file
View File

@ -0,0 +1,6 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88

View File

@ -1,5 +1,6 @@
test: test:
python3 -m pytest python3 -m pytest
python3 -m pytest --black pantalaimon
python3 -m pytest --flake8 pantalaimon python3 -m pytest --flake8 pantalaimon
python3 -m pytest --isort python3 -m pytest --isort
@ -14,3 +15,6 @@ run-local:
isort: isort:
isort -y -p pantalaimon isort -y -p pantalaimon
format:
black pantalaimon/

View File

@ -20,21 +20,39 @@ from typing import Any, Dict, Optional
from aiohttp.client_exceptions import ClientConnectionError from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators from jsonschema import Draft4Validator, FormatChecker, validators
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse, from nio import (
KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac, AsyncClient,
KeyVerificationStart, LocalProtocolError, MegolmEvent, ClientConfig,
RoomContextError, RoomEncryptedEvent, RoomEncryptedMedia, EncryptionError,
RoomMessageMedia, RoomMessageText, RoomNameEvent, KeysQueryResponse,
RoomTopicEvent, SyncResponse) KeyVerificationEvent,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
MegolmEvent,
RoomContextError,
RoomEncryptedEvent,
RoomEncryptedMedia,
RoomMessageMedia,
RoomMessageText,
RoomNameEvent,
RoomTopicEvent,
SyncResponse,
)
from nio.crypto import Sas from nio.crypto import Sas
from nio.store import SqliteStore from nio.store import SqliteStore
from pantalaimon.index import IndexStore from pantalaimon.index import IndexStore
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.store import FetchTask from pantalaimon.store import FetchTask
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal, from pantalaimon.thread_messages import (
SasDoneSignal, ShowSasSignal, DaemonResponse,
UpdateDevicesMessage) InviteSasSignal,
SasDoneSignal,
ShowSasSignal,
UpdateDevicesMessage,
)
SEARCH_KEYS = ["content.body", "content.name", "content.topic"] SEARCH_KEYS = ["content.body", "content.name", "content.topic"]
@ -51,7 +69,7 @@ SEARCH_TERMS_SCHEMA = {
"keys": { "keys": {
"type": "array", "type": "array",
"items": {"type": "string", "enum": SEARCH_KEYS}, "items": {"type": "string", "enum": SEARCH_KEYS},
"default": SEARCH_KEYS "default": SEARCH_KEYS,
}, },
"order_by": {"type": "string", "default": "rank"}, "order_by": {"type": "string", "default": "rank"},
"include_state": {"type": "boolean", "default": False}, "include_state": {"type": "boolean", "default": False},
@ -59,15 +77,13 @@ SEARCH_TERMS_SCHEMA = {
"event_context": {"type": "object"}, "event_context": {"type": "object"},
"groupings": {"type": "object", "default": {}}, "groupings": {"type": "object", "default": {}},
}, },
"required": ["search_term"] "required": ["search_term"],
},
} }
}, },
"required": ["room_events"]
}, },
"required": [ "required": ["room_events"],
"search_categories", },
], "required": ["search_categories"],
} }
@ -79,9 +95,7 @@ def extend_with_default(validator_class):
if "default" in subschema: if "default" in subschema:
instance.setdefault(prop, subschema["default"]) instance.setdefault(prop, subschema["default"])
for error in validate_properties( for error in validate_properties(validator, properties, instance, schema):
validator, properties, instance, schema
):
yield error yield error
return validators.extend(validator_class, {"properties": set_defaults}) return validators.extend(validator_class, {"properties": set_defaults})
@ -122,11 +136,10 @@ class PanClient(AsyncClient):
store_path="", store_path="",
config=None, config=None,
ssl=None, ssl=None,
proxy=None proxy=None,
): ):
config = config or ClientConfig(store=SqliteStore, store_name="pan.db") config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
super().__init__(homeserver, user_id, device_id, store_path, config, super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)
ssl, proxy)
index_dir = os.path.join(store_path, server_name, user_id) index_dir = os.path.join(store_path, server_name, user_id)
@ -150,31 +163,24 @@ class PanClient(AsyncClient):
self.history_fetcher_task = None self.history_fetcher_task = None
self.history_fetch_queue = asyncio.Queue() self.history_fetch_queue = asyncio.Queue()
self.add_to_device_callback( self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
self.key_verification_cb, self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
KeyVerificationEvent
)
self.add_event_callback(
self.undecrypted_event_cb,
MegolmEvent
)
self.add_event_callback( self.add_event_callback(
self.store_message_cb, self.store_message_cb,
(RoomMessageText, RoomMessageMedia, RoomEncryptedMedia, (
RoomTopicEvent, RoomNameEvent) RoomMessageText,
RoomMessageMedia,
RoomEncryptedMedia,
RoomTopicEvent,
RoomNameEvent,
),
) )
self.key_verificatins_tasks = [] self.key_verificatins_tasks = []
self.key_request_tasks = [] self.key_request_tasks = []
self.add_response_callback( self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
self.keys_query_cb,
KeysQueryResponse
)
self.add_response_callback( self.add_response_callback(self.sync_tasks, SyncResponse)
self.sync_tasks,
SyncResponse
)
def store_message_cb(self, room, event): def store_message_cb(self, room, event):
display_name = room.user_name(event.sender) display_name = room.user_name(event.sender)
@ -192,9 +198,11 @@ class PanClient(AsyncClient):
"type": "m.room.message", "type": "m.room.message",
"content": { "content": {
"msgtype": "m.text", "msgtype": "m.text",
"body": ("** Unable to decrypt: The sender's device has not " "body": (
"sent us the keys for this message. **") "** Unable to decrypt: The sender's device has not "
} "sent us the keys for this message. **"
),
},
} }
async def send_message(self, message): async def send_message(self, message):
@ -206,17 +214,10 @@ class PanClient(AsyncClient):
await self.queue.put(message) await self.queue.put(message)
def delete_fetcher_task(self, task): def delete_fetcher_task(self, task):
self.pan_store.delete_fetcher_task( self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)
self.server_name,
self.user_id,
task
)
async def fetcher_loop(self): async def fetcher_loop(self):
for t in self.pan_store.load_fetcher_tasks( for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
self.server_name,
self.user_id
):
await self.history_fetch_queue.put(t) await self.history_fetch_queue.put(t)
while True: while True:
@ -233,13 +234,13 @@ class PanClient(AsyncClient):
continue continue
try: try:
logger.debug("Fetching room history for {}".format( logger.debug(
room.display_name "Fetching room history for {}".format(room.display_name)
)) )
response = await self.room_messages( response = await self.room_messages(
fetch_task.room_id, fetch_task.room_id,
fetch_task.token, fetch_task.token,
limit=self.pan_conf.indexing_batch_size limit=self.pan_conf.indexing_batch_size,
) )
except ClientConnectionError: except ClientConnectionError:
self.history_fetch_queue.put(fetch_task) self.history_fetch_queue.put(fetch_task)
@ -250,31 +251,31 @@ class PanClient(AsyncClient):
continue continue
for event in response.chunk: for event in response.chunk:
if not isinstance(event, ( if not isinstance(
event,
(
RoomMessageText, RoomMessageText,
RoomMessageMedia, RoomMessageMedia,
RoomEncryptedMedia, RoomEncryptedMedia,
RoomTopicEvent, RoomTopicEvent,
RoomNameEvent RoomNameEvent,
)): ),
):
continue continue
display_name = room.user_name(event.sender) display_name = room.user_name(event.sender)
avatar_url = room.avatar_url(event.sender) avatar_url = room.avatar_url(event.sender)
self.index.add_event(event, room.room_id, display_name, self.index.add_event(event, room.room_id, display_name, avatar_url)
avatar_url)
last_event = response.chunk[-1] last_event = response.chunk[-1]
if not self.index.event_in_store( if not self.index.event_in_store(last_event.event_id, room.room_id):
last_event.event_id,
room.room_id
):
# There may be even more events to fetch, add a new task to # There may be even more events to fetch, add a new task to
# the queue. # the queue.
task = FetchTask(room.room_id, response.end) task = FetchTask(room.room_id, response.end)
self.pan_store.save_fetcher_task(self.server_name, self.pan_store.save_fetcher_task(
self.user_id, task) self.server_name, self.user_id, task
)
await self.history_fetch_queue.put(task) await self.history_fetch_queue.put(task)
await self.index.commit_events() await self.index.commit_events()
@ -294,11 +295,7 @@ class PanClient(AsyncClient):
self.key_request_tasks = [] self.key_request_tasks = []
await self.index.commit_events() await self.index.commit_events()
self.pan_store.save_token( self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)
self.server_name,
self.user_id,
self.next_batch
)
for room_id, room_info in response.rooms.join.items(): for room_id, room_info in response.rooms.join.items():
if room_info.timeline.limited: if room_info.timeline.limited:
@ -307,13 +304,12 @@ class PanClient(AsyncClient):
if not room.encrypted and self.pan_conf.index_encrypted_only: if not room.encrypted and self.pan_conf.index_encrypted_only:
continue continue
logger.info("Room {} had a limited timeline, queueing " logger.info(
"room for history fetching.".format( "Room {} had a limited timeline, queueing "
room.display_name "room for history fetching.".format(room.display_name)
)) )
task = FetchTask(room_id, room_info.timeline.prev_batch) task = FetchTask(room_id, room_info.timeline.prev_batch)
self.pan_store.save_fetcher_task(self.server_name, self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)
self.user_id, task)
await self.history_fetch_queue.put(task) await self.history_fetch_queue.put(task)
@ -323,10 +319,11 @@ class PanClient(AsyncClient):
def undecrypted_event_cb(self, room, event): def undecrypted_event_cb(self, room, event):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
logger.info("Unable to decrypt event from {} via {}.".format( logger.info(
event.sender, "Unable to decrypt event from {} via {}.".format(
event.device_id event.sender, event.device_id
)) )
)
if event.session_id not in self.outgoing_key_requests: if event.session_id not in self.outgoing_key_requests:
logger.info("Requesting room key for undecrypted event.") logger.info("Requesting room key for undecrypted event.")
@ -338,19 +335,16 @@ class PanClient(AsyncClient):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if isinstance(event, KeyVerificationStart): if isinstance(event, KeyVerificationStart):
logger.info(f"{event.sender} via {event.from_device} has started " logger.info(
f"a key verification process.") f"{event.sender} via {event.from_device} has started "
f"a key verification process."
)
message = InviteSasSignal( message = InviteSasSignal(
self.user_id, self.user_id, event.sender, event.from_device, event.transaction_id
event.sender,
event.from_device,
event.transaction_id
) )
task = loop.create_task( task = loop.create_task(self.queue.put(message))
self.queue.put(message)
)
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
elif isinstance(event, KeyVerificationKey): elif isinstance(event, KeyVerificationKey):
@ -362,16 +356,10 @@ class PanClient(AsyncClient):
emoji = sas.get_emoji() emoji = sas.get_emoji()
message = ShowSasSignal( message = ShowSasSignal(
self.user_id, self.user_id, device.user_id, device.id, sas.transaction_id, emoji
device.user_id,
device.id,
sas.transaction_id,
emoji
) )
task = loop.create_task( task = loop.create_task(self.queue.put(message))
self.queue.put(message)
)
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
elif isinstance(event, KeyVerificationMac): elif isinstance(event, KeyVerificationMac):
@ -381,14 +369,13 @@ class PanClient(AsyncClient):
device = sas.other_olm_device device = sas.other_olm_device
if sas.verified: if sas.verified:
task = loop.create_task(self.send_message( task = loop.create_task(
self.send_message(
SasDoneSignal( SasDoneSignal(
self.user_id, self.user_id, device.user_id, device.id, sas.transaction_id
device.user_id, )
device.id, )
sas.transaction_id
) )
))
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
task = loop.create_task(self.send_update_devcies()) task = loop.create_task(self.send_update_devcies())
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
@ -407,22 +394,13 @@ class PanClient(AsyncClient):
self.history_fetcher_task = loop.create_task(self.fetcher_loop()) self.history_fetcher_task = loop.create_task(self.fetcher_loop())
timeout = 30000 timeout = 30000
sync_filter = { sync_filter = {"room": {"state": {"lazy_load_members": True}}}
"room": {
"state": {"lazy_load_members": True}
}
}
next_batch = self.pan_store.load_token(self.server_name, self.user_id) next_batch = self.pan_store.load_token(self.server_name, self.user_id)
# We don't store any room state so initial sync needs to be with the # We don't store any room state so initial sync needs to be with the
# full_state parameter. Subsequent ones are normal. # full_state parameter. Subsequent ones are normal.
task = loop.create_task( task = loop.create_task(
self.sync_forever( self.sync_forever(timeout, sync_filter, full_state=True, since=next_batch)
timeout,
sync_filter,
full_state=True,
since=next_batch
)
) )
self.task = task self.task = task
@ -436,16 +414,15 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
"m.ok", "m.ok",
"Successfully started the key verification request" "Successfully started the key verification request",
)) )
)
except ClientConnectionError as e: except ClientConnectionError as e:
await self.send_message( await self.send_message(
DaemonResponse( DaemonResponse(
message.message_id, message.message_id, self.user_id, "m.connection_error", str(e)
self.user_id, )
"m.connection_error", )
str(e)
))
async def accept_sas(self, message): async def accept_sas(self, message):
user_id = message.user_id user_id = message.user_id
@ -459,9 +436,8 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
Sas._txid_error[0], Sas._txid_error[0],
Sas._txid_error[1] Sas._txid_error[1],
) )
) )
return return
@ -472,24 +448,24 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
"m.ok", "m.ok",
"Successfully accepted the key verification request" "Successfully accepted the key verification request",
)) )
)
except LocalProtocolError as e: except LocalProtocolError as e:
await self.send_message( await self.send_message(
DaemonResponse( DaemonResponse(
message.message_id, message.message_id,
self.user_id, self.user_id,
Sas._unexpected_message_error[0], Sas._unexpected_message_error[0],
str(e) str(e),
)) )
)
except ClientConnectionError as e: except ClientConnectionError as e:
await self.send_message( await self.send_message(
DaemonResponse( DaemonResponse(
message.message_id, message.message_id, self.user_id, "m.connection_error", str(e)
self.user_id, )
"m.connection_error", )
str(e)
))
async def cancel_sas(self, message): async def cancel_sas(self, message):
user_id = message.user_id user_id = message.user_id
@ -503,9 +479,8 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
Sas._txid_error[0], Sas._txid_error[0],
Sas._txid_error[1] Sas._txid_error[1],
) )
) )
return return
@ -516,16 +491,15 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
"m.ok", "m.ok",
"Successfully canceled the key verification request" "Successfully canceled the key verification request",
)) )
)
except ClientConnectionError as e: except ClientConnectionError as e:
await self.send_message( await self.send_message(
DaemonResponse( DaemonResponse(
message.message_id, message.message_id, self.user_id, "m.connection_error", str(e)
self.user_id, )
"m.connection_error", )
str(e)
))
async def confirm_sas(self, message): async def confirm_sas(self, message):
user_id = message.user_id user_id = message.user_id
@ -539,9 +513,8 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
Sas._txid_error[0], Sas._txid_error[0],
Sas._txid_error[1] Sas._txid_error[1],
) )
) )
return return
@ -550,11 +523,9 @@ class PanClient(AsyncClient):
except ClientConnectionError as e: except ClientConnectionError as e:
await self.send_message( await self.send_message(
DaemonResponse( DaemonResponse(
message.message_id, message.message_id, self.user_id, "m.connection_error", str(e)
self.user_id, )
"m.connection_error", )
str(e)
))
return return
@ -564,10 +535,7 @@ class PanClient(AsyncClient):
await self.send_update_devcies() await self.send_update_devcies()
await self.send_message( await self.send_message(
SasDoneSignal( SasDoneSignal(
self.user_id, self.user_id, device.user_id, device.id, sas.transaction_id
device.user_id,
device.id,
sas.transaction_id
) )
) )
else: else:
@ -576,8 +544,9 @@ class PanClient(AsyncClient):
message.message_id, message.message_id,
self.user_id, self.user_id,
"m.ok", "m.ok",
f"Waiting for {device.user_id} to confirm." f"Waiting for {device.user_id} to confirm.",
)) )
)
async def loop_stop(self): async def loop_stop(self):
"""Stop the client loop.""" """Stop the client loop."""
@ -605,18 +574,15 @@ class PanClient(AsyncClient):
self.history_fetch_queue = asyncio.Queue() self.history_fetch_queue = asyncio.Queue()
def pan_decrypt_event( def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
self,
event_dict,
room_id=None,
ignore_failures=True
):
# type: (Dict[Any, Any], Optional[str], bool) -> (bool) # type: (Dict[Any, Any], Optional[str], bool) -> (bool)
event = RoomEncryptedEvent.parse_event(event_dict) event = RoomEncryptedEvent.parse_event(event_dict)
if not isinstance(event, MegolmEvent): if not isinstance(event, MegolmEvent):
logger.warn("Encrypted event is not a megolm event:" logger.warn(
"\n{}".format(pformat(event_dict))) "Encrypted event is not a megolm event:"
"\n{}".format(pformat(event_dict))
)
return False return False
if not event.room_id: if not event.room_id:
@ -661,8 +627,7 @@ class PanClient(AsyncClient):
continue continue
if event["type"] != "m.room.encrypted": if event["type"] != "m.room.encrypted":
logger.debug("Event is not encrypted: " logger.debug("Event is not encrypted: " "\n{}".format(pformat(event)))
"\n{}".format(pformat(event)))
continue continue
self.pan_decrypt_event(event, ignore_failures=ignore_failures) self.pan_decrypt_event(event, ignore_failures=ignore_failures)
@ -682,9 +647,11 @@ class PanClient(AsyncClient):
for room_id, room_dict in body["rooms"]["join"].items(): for room_id, room_dict in body["rooms"]["join"].items():
try: try:
if not self.rooms[room_id].encrypted: if not self.rooms[room_id].encrypted:
logger.info("Room {} is not encrypted skipping...".format( logger.info(
"Room {} is not encrypted skipping...".format(
self.rooms[room_id].display_name self.rooms[room_id].display_name
)) )
)
continue continue
except KeyError: except KeyError:
logger.info("Unknown room {} skipping...".format(room_id)) logger.info("Unknown room {} skipping...".format(room_id))
@ -752,8 +719,9 @@ class PanClient(AsyncClient):
after_limit = event_context.get("before_limit", 5) after_limit = event_context.get("before_limit", 5)
if before_limit < 0 or after_limit < 0: if before_limit < 0 or after_limit < 0:
raise InvalidLimit("Invalid context limit, the limit must be a " raise InvalidLimit(
"positive number") "Invalid context limit, the limit must be a " "positive number"
)
response_dict = await self.index.search( response_dict = await self.index.search(
term, term,
@ -762,7 +730,7 @@ class PanClient(AsyncClient):
order_by_recent=order_by_recent, order_by_recent=order_by_recent,
include_profile=include_profile, include_profile=include_profile,
before_limit=before_limit, before_limit=before_limit,
after_limit=after_limit after_limit=after_limit,
) )
if (event_context or include_state) and self.pan_conf.search_requests: if (event_context or include_state) and self.pan_conf.search_requests:
@ -771,14 +739,10 @@ class PanClient(AsyncClient):
event_dict, event_dict,
event_dict["result"]["room_id"], event_dict["result"]["room_id"],
event_dict["result"]["event_id"], event_dict["result"]["event_id"],
include_state include_state,
) )
if include_state: if include_state:
response_dict["state"] = state_cache response_dict["state"] = state_cache
return { return {"search_categories": {"room_events": response_dict}}
"search_categories": {
"room_events": response_dict
}
}

View File

@ -43,7 +43,7 @@ class PanConfigParser(configparser.ConfigParser):
"address": parse_address, "address": parse_address,
"url": parse_url, "url": parse_url,
"loglevel": parse_log_level, "loglevel": parse_log_level,
} },
) )
@ -59,9 +59,10 @@ def parse_url(v):
# type: (str) -> ParseResult # type: (str) -> ParseResult
value = urlparse(v) value = urlparse(v)
if value.scheme not in ('http', 'https'): if value.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme {value.scheme}. " raise ValueError(
f"Only HTTP(s) URLs are allowed") f"Invalid URL scheme {value.scheme}. " f"Only HTTP(s) URLs are allowed"
)
value.port value.port
return value return value
@ -124,8 +125,7 @@ class ServerConfig:
name = attr.ib(type=str) name = attr.ib(type=str)
homeserver = attr.ib(type=ParseResult) homeserver = attr.ib(type=ParseResult)
listen_address = attr.ib( listen_address = attr.ib(
type=Union[IPv4Address, IPv6Address], type=Union[IPv4Address, IPv6Address], default=ip_address("127.0.0.1")
default=ip_address("127.0.0.1")
) )
listen_port = attr.ib(type=int, default=8009) listen_port = attr.ib(type=int, default=8009)
proxy = attr.ib(type=str, default="") proxy = attr.ib(type=str, default="")
@ -183,8 +183,9 @@ class PanConfig:
homeserver = section.geturl("Homeserver") homeserver = section.geturl("Homeserver")
if not homeserver: if not homeserver:
raise PanConfigError(f"Homserver is not set for " raise PanConfigError(
f"section {section_name}") f"Homserver is not set for " f"section {section_name}"
)
listen_address = section.getaddress("ListenAddress") listen_address = section.getaddress("ListenAddress")
listen_port = section.getint("ListenPort") listen_port = section.getint("ListenPort")
@ -198,23 +199,29 @@ class PanConfig:
indexing_batch_size = section.getint("IndexingBatchSize") indexing_batch_size = section.getint("IndexingBatchSize")
if not 1 < indexing_batch_size <= 1000: if not 1 < indexing_batch_size <= 1000:
raise PanConfigError("The indexing batch size needs to be " raise PanConfigError(
"The indexing batch size needs to be "
"a positive integer between 1 and " "a positive integer between 1 and "
"1000") "1000"
)
history_fetch_delay = section.getint("HistoryFetchDelay") history_fetch_delay = section.getint("HistoryFetchDelay")
if not 100 < history_fetch_delay <= 10000: if not 100 < history_fetch_delay <= 10000:
raise PanConfigError("The history fetch delay needs to be " raise PanConfigError(
"The history fetch delay needs to be "
"a positive integer between 100 and " "a positive integer between 100 and "
"10000") "10000"
)
listen_tuple = (listen_address, listen_port) listen_tuple = (listen_address, listen_port)
if listen_tuple in listen_set: if listen_tuple in listen_set:
raise PanConfigError(f"The listen address/port combination" raise PanConfigError(
f"The listen address/port combination"
f" for section {section_name} was " f" for section {section_name} was "
f"already defined before.") f"already defined before."
)
listen_set.add(listen_tuple) listen_set.add(listen_tuple)
server_conf = ServerConfig( server_conf = ServerConfig(
@ -229,7 +236,7 @@ class PanConfig:
search_requests, search_requests,
index_encrypted_only, index_encrypted_only,
indexing_batch_size, indexing_batch_size,
history_fetch_delay / 1000 history_fetch_delay / 1000,
) )
self.servers[section_name] = server_conf self.servers[section_name] = server_conf

View File

@ -25,29 +25,39 @@ from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from jsonschema import ValidationError from jsonschema import ValidationError
from multidict import CIMultiDict from multidict import CIMultiDict
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError, from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError
SendRetryError)
from pantalaimon.client import (SEARCH_TERMS_SCHEMA, InvalidLimit, from pantalaimon.client import (
InvalidOrderByError, PanClient, SEARCH_TERMS_SCHEMA,
UnknownRoomError, validate_json) InvalidLimit,
InvalidOrderByError,
PanClient,
UnknownRoomError,
validate_json,
)
from pantalaimon.index import InvalidQueryError from pantalaimon.index import InvalidQueryError
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore from pantalaimon.store import ClientInfo, PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage, from pantalaimon.thread_messages import (
AcceptSasMessage,
CancelSasMessage,
CancelSendingMessage, CancelSendingMessage,
ConfirmSasMessage, DaemonResponse, ConfirmSasMessage,
DaemonResponse,
DeviceBlacklistMessage, DeviceBlacklistMessage,
DeviceUnblacklistMessage, DeviceUnblacklistMessage,
DeviceUnverifyMessage, DeviceUnverifyMessage,
DeviceVerifyMessage, DeviceVerifyMessage,
ExportKeysMessage, ImportKeysMessage, ExportKeysMessage,
SasMessage, SendAnywaysMessage, ImportKeysMessage,
SasMessage,
SendAnywaysMessage,
StartSasMessage, StartSasMessage,
UnverifiedDevicesSignal, UnverifiedDevicesSignal,
UnverifiedResponse, UnverifiedResponse,
UpdateDevicesMessage, UpdateDevicesMessage,
UpdateUsersMessage) UpdateUsersMessage,
)
CORS_HEADERS = { CORS_HEADERS = {
"Access-Control-Allow-Headers": ( "Access-Control-Allow-Headers": (
@ -76,11 +86,7 @@ class ProxyDaemon:
homeserver_url = attr.ib(init=False, default=attr.Factory(dict)) homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
hostname = attr.ib(init=False, default=attr.Factory(dict)) hostname = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict)) pan_clients = attr.ib(init=False, default=attr.Factory(dict))
client_info = attr.ib( client_info = attr.ib(init=False, default=attr.Factory(dict), type=dict)
init=False,
default=attr.Factory(dict),
type=dict
)
default_session = attr.ib(init=False, default=None) default_session = attr.ib(init=False, default=None)
database_name = "pan.db" database_name = "pan.db"
@ -93,15 +99,16 @@ class ProxyDaemon:
for user_id, device_id in accounts: for user_id, device_id in accounts:
if self.conf.keyring: if self.conf.keyring:
token = keyring.get_password( token = keyring.get_password(
"pantalaimon", "pantalaimon", f"{user_id}-{device_id}-token"
f"{user_id}-{device_id}-token"
) )
else: else:
token = self.store.load_access_token(user_id, device_id) token = self.store.load_access_token(user_id, device_id)
if not token: if not token:
logger.warn(f"Not restoring client for {user_id} {device_id}, " logger.warn(
f"missing access token.") f"Not restoring client for {user_id} {device_id}, "
f"missing access token."
)
continue continue
logger.info(f"Restoring client for {user_id} {device_id}") logger.info(f"Restoring client for {user_id} {device_id}")
@ -116,7 +123,7 @@ class ProxyDaemon:
device_id, device_id,
store_path=self.data_dir, store_path=self.data_dir,
ssl=self.ssl, ssl=self.ssl,
proxy=self.proxy proxy=self.proxy,
) )
pan_client.user_id = user_id pan_client.user_id = user_id
pan_client.access_token = token pan_client.access_token = token
@ -136,7 +143,7 @@ class ProxyDaemon:
method, method,
self.homeserver_url + path, self.homeserver_url + path,
proxy=self.proxy, proxy=self.proxy,
ssl=self.ssl ssl=self.ssl,
) )
except ClientConnectionError: except ClientConnectionError:
return None return None
@ -155,12 +162,15 @@ class ProxyDaemon:
return None return None
if user_id not in self.pan_clients: if user_id not in self.pan_clients:
logger.warn(f"User {user_id} doesn't have a matching pan " logger.warn(
f"client.") f"User {user_id} doesn't have a matching pan " f"client."
)
return None return None
logger.info(f"Homeserver confirmed valid access token " logger.info(
f"for user {user_id}, caching info.") f"Homeserver confirmed valid access token "
f"for user {user_id}, caching info."
)
client_info = ClientInfo(user_id, access_token) client_info = ClientInfo(user_id, access_token)
self.client_info[access_token] = client_info self.client_info[access_token] = client_info
@ -173,12 +183,12 @@ class ProxyDaemon:
ret = client.verify_device(device) ret = client.verify_device(device)
if ret: if ret:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} succesfully verified.") f"Device {device.id} of user " f"{device.user_id} succesfully verified."
)
await self.send_update_devcies() await self.send_update_devcies()
else: else:
msg = (f"Device {device.id} of user " msg = f"Device {device.id} of user " f"{device.user_id} already verified."
f"{device.user_id} already verified.")
logger.info(msg) logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg) await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -187,12 +197,13 @@ class ProxyDaemon:
ret = client.unverify_device(device) ret = client.unverify_device(device)
if ret: if ret:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} succesfully unverified.") f"Device {device.id} of user "
f"{device.user_id} succesfully unverified."
)
await self.send_update_devcies() await self.send_update_devcies()
else: else:
msg = (f"Device {device.id} of user " msg = f"Device {device.id} of user " f"{device.user_id} already unverified."
f"{device.user_id} already unverified.")
logger.info(msg) logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg) await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -201,12 +212,15 @@ class ProxyDaemon:
ret = client.blacklist_device(device) ret = client.blacklist_device(device)
if ret: if ret:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} succesfully blacklisted.") f"Device {device.id} of user "
f"{device.user_id} succesfully blacklisted."
)
await self.send_update_devcies() await self.send_update_devcies()
else: else:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} already blacklisted.") f"Device {device.id} of user " f"{device.user_id} already blacklisted."
)
logger.info(msg) logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg) await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -215,12 +229,16 @@ class ProxyDaemon:
ret = client.unblacklist_device(device) ret = client.unblacklist_device(device)
if ret: if ret:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} succesfully unblacklisted.") f"Device {device.id} of user "
f"{device.user_id} succesfully unblacklisted."
)
await self.send_update_devcies() await self.send_update_devcies()
else: else:
msg = (f"Device {device.id} of user " msg = (
f"{device.user_id} already unblacklisted.") f"Device {device.id} of user "
f"{device.user_id} already unblacklisted."
)
logger.info(msg) logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg) await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -239,23 +257,23 @@ class ProxyDaemon:
if isinstance( if isinstance(
message, message,
(DeviceVerifyMessage, DeviceUnverifyMessage, StartSasMessage, (
DeviceBlacklistMessage, DeviceUnblacklistMessage) DeviceVerifyMessage,
DeviceUnverifyMessage,
StartSasMessage,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
),
): ):
device = client.device_store[message.user_id].get( device = client.device_store[message.user_id].get(message.device_id, None)
message.device_id,
None
)
if not device: if not device:
msg = (f"No device found for {message.user_id} and " msg = (
f"{message.device_id}") f"No device found for {message.user_id} and " f"{message.device_id}"
)
await self.send_response( await self.send_response(
message.message_id, message.message_id, message.pan_user, "m.unknown_device", msg
message.pan_user,
"m.unknown_device",
msg
) )
logger.info(msg) logger.info(msg)
return return
@ -265,11 +283,9 @@ class ProxyDaemon:
elif isinstance(message, DeviceUnverifyMessage): elif isinstance(message, DeviceUnverifyMessage):
await self._unverify_device(message.message_id, client, device) await self._unverify_device(message.message_id, client, device)
elif isinstance(message, DeviceBlacklistMessage): elif isinstance(message, DeviceBlacklistMessage):
await self._blacklist_device(message.message_id, client, await self._blacklist_device(message.message_id, client, device)
device)
elif isinstance(message, DeviceUnblacklistMessage): elif isinstance(message, DeviceUnblacklistMessage):
await self._unblacklist_device(message.message_id, client, await self._unblacklist_device(message.message_id, client, device)
device)
elif isinstance(message, StartSasMessage): elif isinstance(message, StartSasMessage):
await client.start_sas(message, device) await client.start_sas(message, device)
@ -288,25 +304,21 @@ class ProxyDaemon:
try: try:
await client.export_keys(path, message.passphrase) await client.export_keys(path, message.passphrase)
except OSError as e: except OSError as e:
info_msg = (f"Error exporting keys for {client.user_id} to" info_msg = (
f" {path} {e}") f"Error exporting keys for {client.user_id} to" f" {path} {e}"
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
message.message_id, message.message_id, client.user_id, "m.os_error", str(e)
client.user_id,
"m.os_error",
str(e)
) )
else: else:
info_msg = (f"Succesfully exported keys for {client.user_id} " info_msg = (
f"to {path}") f"Succesfully exported keys for {client.user_id} " f"to {path}"
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
message.message_id, message.message_id, client.user_id, "m.ok", info_msg
client.user_id,
"m.ok",
info_msg
) )
elif isinstance(message, ImportKeysMessage): elif isinstance(message, ImportKeysMessage):
@ -316,37 +328,32 @@ class ProxyDaemon:
try: try:
await client.import_keys(path, message.passphrase) await client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e: except (OSError, EncryptionError) as e:
info_msg = (f"Error importing keys for {client.user_id} " info_msg = (
f"from {path} {e}") f"Error importing keys for {client.user_id} " f"from {path} {e}"
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
message.message_id, message.message_id, client.user_id, "m.os_error", str(e)
client.user_id,
"m.os_error",
str(e)
) )
else: else:
info_msg = (f"Succesfully imported keys for {client.user_id} " info_msg = (
f"from {path}") f"Succesfully imported keys for {client.user_id} " f"from {path}"
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
message.message_id, message.message_id, client.user_id, "m.ok", info_msg
client.user_id,
"m.ok",
info_msg
) )
elif isinstance(message, UnverifiedResponse): elif isinstance(message, UnverifiedResponse):
client = self.pan_clients[message.pan_user] client = self.pan_clients[message.pan_user]
if message.room_id not in client.send_decision_queues: if message.room_id not in client.send_decision_queues:
msg = (f"No send request found for user {message.pan_user} " msg = (
f"and room {message.room_id}.") f"No send request found for user {message.pan_user} "
f"and room {message.room_id}."
)
await self.send_response( await self.send_response(
message.message_id, message.message_id, message.pan_user, "m.unknown_request", msg
message.pan_user,
"m.unknown_request",
msg
) )
return return
@ -365,10 +372,7 @@ class ProxyDaemon:
access_token = request.query.get("access_token", "") access_token = request.query.get("access_token", "")
if not access_token: if not access_token:
access_token = request.headers.get( access_token = request.headers.get("Authorization", "").strip("Bearer ")
"Authorization",
""
).strip("Bearer ")
return access_token return access_token
@ -404,7 +408,7 @@ class ProxyDaemon:
params=None, # type: CIMultiDict params=None, # type: CIMultiDict
data=None, # type: bytes data=None, # type: bytes
session=None, # type: aiohttp.ClientSession session=None, # type: aiohttp.ClientSession
token=None # type: str token=None, # type: str
): ):
# type: (...) -> aiohttp.ClientResponse # type: (...) -> aiohttp.ClientResponse
"""Forward the given request to our configured homeserver. """Forward the given request to our configured homeserver.
@ -454,16 +458,11 @@ class ProxyDaemon:
params=params, params=params,
headers=headers, headers=headers,
proxy=self.proxy, proxy=self.proxy,
ssl=self.ssl ssl=self.ssl,
) )
async def forward_to_web( async def forward_to_web(
self, self, request, params=None, data=None, session=None, token=None
request,
params=None,
data=None,
session=None,
token=None
): ):
"""Forward the given request and convert the response to a Response. """Forward the given request and convert the response to a Response.
@ -484,17 +483,13 @@ class ProxyDaemon:
""" """
try: try:
response = await self.forward_request( response = await self.forward_request(
request, request, params=params, data=data, session=session, token=token
params=params,
data=data,
session=session,
token=token
) )
return web.Response( return web.Response(
status=response.status, status=response.status,
content_type=response.content_type, content_type=response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.read() body=await response.read(),
) )
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
@ -522,8 +517,10 @@ class ProxyDaemon:
self.store.save_server_user(self.name, user_id) self.store.save_server_user(self.name, user_id)
if user_id in self.pan_clients: if user_id in self.pan_clients:
logger.info(f"Background sync client already exists for {user_id}," logger.info(
f" not starting new one") f"Background sync client already exists for {user_id},"
f" not starting new one"
)
return return
pan_client = PanClient( pan_client = PanClient(
@ -535,7 +532,7 @@ class ProxyDaemon:
user_id, user_id,
store_path=self.data_dir, store_path=self.data_dir,
ssl=self.ssl, ssl=self.ssl,
proxy=self.proxy proxy=self.proxy,
) )
response = await pan_client.login(password, "pantalaimon") response = await pan_client.login(password, "pantalaimon")
@ -543,8 +540,7 @@ class ProxyDaemon:
await pan_client.close() await pan_client.close()
return return
logger.info(f"Succesfully started new background sync client for " logger.info(f"Succesfully started new background sync client for " f"{user_id}")
f"{user_id}")
await self.send_queue.put(UpdateUsersMessage()) await self.send_queue.put(UpdateUsersMessage())
@ -554,13 +550,11 @@ class ProxyDaemon:
keyring.set_password( keyring.set_password(
"pantalaimon", "pantalaimon",
f"{user_id}-{pan_client.device_id}-token", f"{user_id}-{pan_client.device_id}-token",
pan_client.access_token pan_client.access_token,
) )
else: else:
self.store.save_access_token( self.store.save_access_token(
user_id, user_id, pan_client.device_id, pan_client.access_token
pan_client.device_id,
pan_client.access_token
) )
pan_client.start_loop() pan_client.start_loop()
@ -580,7 +574,7 @@ class ProxyDaemon:
return web.json_response( return web.json_response(
{ {
"errcode": "M_NOT_JSON", "errcode": "M_NOT_JSON",
"error": "Request did not contain valid JSON." "error": "Request did not contain valid JSON.",
}, },
status=500, status=500,
) )
@ -605,25 +599,23 @@ class ProxyDaemon:
access_token = json_response.get("access_token", None) access_token = json_response.get("access_token", None)
if user_id and access_token: if user_id and access_token:
logger.info(f"User: {user} succesfully logged in, starting " logger.info(
f"a background sync client.") f"User: {user} succesfully logged in, starting "
await self.start_pan_client(access_token, user, user_id, f"a background sync client."
password) )
await self.start_pan_client(access_token, user, user_id, password)
return web.Response( return web.Response(
status=response.status, status=response.status,
content_type=response.content_type, content_type=response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.read() body=await response.read(),
) )
@property @property
def _missing_token(self): def _missing_token(self):
return web.json_response( return web.json_response(
{ {"errcode": "M_MISSING_TOKEN", "error": "Missing access token."},
"errcode": "M_MISSING_TOKEN",
"error": "Missing access token."
},
headers=CORS_HEADERS, headers=CORS_HEADERS,
status=401, status=401,
) )
@ -631,10 +623,7 @@ class ProxyDaemon:
@property @property
def _unknown_token(self): def _unknown_token(self):
return web.json_response( return web.json_response(
{ {"errcode": "M_UNKNOWN_TOKEN", "error": "Unrecognised access token."},
"errcode": "M_UNKNOWN_TOKEN",
"error": "Unrecognised access token."
},
headers=CORS_HEADERS, headers=CORS_HEADERS,
status=401, status=401,
) )
@ -642,10 +631,7 @@ class ProxyDaemon:
@property @property
def _not_json(self): def _not_json(self):
return web.json_response( return web.json_response(
{ {"errcode": "M_NOT_JSON", "error": "Request did not contain valid JSON."},
"errcode": "M_NOT_JSON",
"error": "Request did not contain valid JSON."
},
headers=CORS_HEADERS, headers=CORS_HEADERS,
status=400, status=400,
) )
@ -660,23 +646,18 @@ class ProxyDaemon:
while True: while True:
try: try:
logger.info("Trying to decrypt sync") logger.info("Trying to decrypt sync")
return decryption_method( return decryption_method(body, ignore_failures=False)
body,
ignore_failures=False
)
except EncryptionError: except EncryptionError:
logger.info("Error decrypting sync, waiting for next pan " logger.info("Error decrypting sync, waiting for next pan " "sync")
"sync")
await client.synced.wait(), await client.synced.wait(),
logger.info("Pan synced, retrying decryption.") logger.info("Pan synced, retrying decryption.")
try: try:
return await asyncio.wait_for( return await asyncio.wait_for(
decrypt_loop(client, body), decrypt_loop(client, body), timeout=self.decryption_timeout
timeout=self.decryption_timeout) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.info("Decryption attempt timed out, decrypting with " logger.info("Decryption attempt timed out, decrypting with " "failures")
"failures")
return decryption_method(body, ignore_failures=True) return decryption_method(body, ignore_failures=True)
async def sync(self, request): async def sync(self, request):
@ -705,9 +686,7 @@ class ProxyDaemon:
try: try:
response = await self.forward_request( response = await self.forward_request(
request, request, params=query, token=client.access_token
params=query,
token=client.access_token
) )
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
@ -718,9 +697,7 @@ class ProxyDaemon:
json_response = await self.decrypt_body(client, json_response) json_response = await self.decrypt_body(client, json_response)
return web.json_response( return web.json_response(
json_response, json_response, headers=CORS_HEADERS, status=response.status
headers=CORS_HEADERS,
status=response.status,
) )
except (JSONDecodeError, ContentTypeError): except (JSONDecodeError, ContentTypeError):
pass pass
@ -729,7 +706,7 @@ class ProxyDaemon:
status=response.status, status=response.status,
content_type=response.content_type, content_type=response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.read() body=await response.read(),
) )
async def messages(self, request): async def messages(self, request):
@ -751,15 +728,11 @@ class ProxyDaemon:
try: try:
json_response = await response.json() json_response = await response.json()
json_response = await self.decrypt_body( json_response = await self.decrypt_body(
client, client, json_response, sync=False
json_response,
sync=False
) )
return web.json_response( return web.json_response(
json_response, json_response, headers=CORS_HEADERS, status=response.status
headers=CORS_HEADERS,
status=response.status,
) )
except (JSONDecodeError, ContentTypeError): except (JSONDecodeError, ContentTypeError):
pass pass
@ -768,7 +741,7 @@ class ProxyDaemon:
status=response.status, status=response.status,
content_type=response.content_type, content_type=response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.read() body=await response.read(),
) )
async def send_message(self, request): async def send_message(self, request):
@ -788,17 +761,11 @@ class ProxyDaemon:
room = client.rooms[room_id] room = client.rooms[room_id]
encrypt = room.encrypted encrypt = room.encrypted
except KeyError: except KeyError:
return await self.forward_to_web( return await self.forward_to_web(request, token=client.access_token)
request,
token=client.access_token
)
# The room isn't encrypted just forward the message. # The room isn't encrypted just forward the message.
if not encrypt: if not encrypt:
return await self.forward_to_web( return await self.forward_to_web(request, token=client.access_token)
request,
token=client.access_token
)
msgtype = request.match_info["event_type"] msgtype = request.match_info["event_type"]
txnid = request.match_info["txnid"] txnid = request.match_info["txnid"]
@ -810,14 +777,15 @@ class ProxyDaemon:
async def _send(ignore_unverified=False): async def _send(ignore_unverified=False):
try: try:
response = await client.room_send(room_id, msgtype, content, response = await client.room_send(
txnid, ignore_unverified) room_id, msgtype, content, txnid, ignore_unverified
)
return web.Response( return web.Response(
status=response.transport_response.status, status=response.transport_response.status,
content_type=response.transport_response.content_type, content_type=response.transport_response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.transport_response.read() body=await response.transport_response.read(),
) )
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
@ -849,45 +817,40 @@ class ProxyDaemon:
client.send_decision_queues[room_id] = queue client.send_decision_queues[room_id] = queue
message = UnverifiedDevicesSignal( message = UnverifiedDevicesSignal(
client.user_id, client.user_id, room_id, room.display_name
room_id,
room.display_name
) )
await self.send_queue.put(message) await self.send_queue.put(message)
try: try:
response = await asyncio.wait_for( response = await asyncio.wait_for(
queue.get(), queue.get(), self.unverified_send_timeout
self.unverified_send_timeout
) )
if isinstance(response, CancelSendingMessage): if isinstance(response, CancelSendingMessage):
# The send was canceled notify the client that sent the # The send was canceled notify the client that sent the
# request about this. # request about this.
info_msg = (f"Canceled message sending for room " info_msg = (
f"{room.display_name} ({room_id}).") f"Canceled message sending for room "
f"{room.display_name} ({room_id})."
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
response.message_id, response.message_id, client.user_id, "m.ok", info_msg
client.user_id,
"m.ok",
info_msg
) )
return web.Response(status=503, text=str(e)) return web.Response(status=503, text=str(e))
elif isinstance(response, SendAnywaysMessage): elif isinstance(response, SendAnywaysMessage):
# We are sending and ignoring devices along the way. # We are sending and ignoring devices along the way.
info_msg = (f"Ignoring unverified devices and sending " info_msg = (
f"Ignoring unverified devices and sending "
f"message to room " f"message to room "
f"{room.display_name} ({room_id}).") f"{room.display_name} ({room_id})."
)
logger.info(info_msg) logger.info(info_msg)
await self.send_response( await self.send_response(
response.message_id, response.message_id, client.user_id, "m.ok", info_msg
client.user_id,
"m.ok",
info_msg
) )
ret = await _send(True) ret = await _send(True)
@ -900,10 +863,12 @@ class ProxyDaemon:
return web.Response( return web.Response(
status=503, status=503,
text=(f"Room contains unverified devices and no " text=(
f"Room contains unverified devices and no "
f"action was taken for " f"action was taken for "
f"{self.unverified_send_timeout} seconds, " f"{self.unverified_send_timeout} seconds, "
f"request timed out") f"request timed out"
),
) )
finally: finally:
@ -922,10 +887,7 @@ class ProxyDaemon:
sanitized_content = self.sanitize_filter(content) sanitized_content = self.sanitize_filter(content)
return await self.forward_to_web( return await self.forward_to_web(request, data=json.dumps(sanitized_content))
request,
data=json.dumps(sanitized_content)
)
async def search_opts(self, request): async def search_opts(self, request):
return web.json_response({}, headers=CORS_HEADERS) return web.json_response({}, headers=CORS_HEADERS)
@ -950,10 +912,7 @@ class ProxyDaemon:
validate_json(content, SEARCH_TERMS_SCHEMA) validate_json(content, SEARCH_TERMS_SCHEMA)
except ValidationError: except ValidationError:
return web.json_response( return web.json_response(
{ {"errcode": "M_BAD_JSON", "error": "Invalid search query"},
"errcode": "M_BAD_JSON",
"error": "Invalid search query"
},
headers=CORS_HEADERS, headers=CORS_HEADERS,
status=400, status=400,
) )
@ -982,21 +941,14 @@ class ProxyDaemon:
result = await client.search(content) result = await client.search(content)
except (InvalidOrderByError, InvalidLimit, InvalidQueryError) as e: except (InvalidOrderByError, InvalidLimit, InvalidQueryError) as e:
return web.json_response( return web.json_response(
{ {"errcode": "M_INVALID_PARAM", "error": str(e)},
"errcode": "M_INVALID_PARAM",
"error": str(e)
},
headers=CORS_HEADERS, headers=CORS_HEADERS,
status=400, status=400,
) )
except UnknownRoomError: except UnknownRoomError:
return await self.forward_to_web(request) return await self.forward_to_web(request)
return web.json_response( return web.json_response(result, headers=CORS_HEADERS, status=200)
result,
headers=CORS_HEADERS,
status=200
)
async def shutdown(self, _): async def shutdown(self, _):
"""Shut the daemon down closing all the client sessions it has. """Shut the daemon down closing all the client sessions it has.

View File

@ -21,10 +21,14 @@ from typing import Any, Dict, List, Optional, Tuple
import attr import attr
import tantivy import tantivy
from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText, from nio import (
RoomNameEvent, RoomTopicEvent) RoomEncryptedMedia,
from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, RoomMessageMedia,
TextField) RoomMessageText,
RoomNameEvent,
RoomTopicEvent,
)
from peewee import SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, TextField
from pantalaimon.store import use_database from pantalaimon.store import use_database
@ -65,22 +69,15 @@ class Event(Model):
source = DictField() source = DictField()
profile = ForeignKeyField( profile = ForeignKeyField(model=Profile, column_name="profile_id")
model=Profile,
column_name="profile_id",
)
class Meta: class Meta:
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")] constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
class UserMessages(Model): class UserMessages(Model):
user = ForeignKeyField( user = ForeignKeyField(model=StoreUser, column_name="user_id")
model=StoreUser, event = ForeignKeyField(model=Event, column_name="event_id")
column_name="user_id")
event = ForeignKeyField(
model=Event,
column_name="event_id")
@attr.s @attr.s
@ -91,17 +88,11 @@ class MessageStore:
database = attr.ib(type=SqliteDatabase, init=False) database = attr.ib(type=SqliteDatabase, init=False)
database_path = attr.ib(type=str, init=False) database_path = attr.ib(type=str, init=False)
models = [ models = [StoreUser, Event, Profile, UserMessages]
StoreUser,
Event,
Profile,
UserMessages
]
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.database_path = os.path.join( self.database_path = os.path.join(
os.path.abspath(self.store_path), os.path.abspath(self.store_path), self.database_name
self.database_name
) )
self.database = self._create_database() self.database = self._create_database()
@ -112,21 +103,22 @@ class MessageStore:
def _create_database(self): def _create_database(self):
return SqliteDatabase( return SqliteDatabase(
self.database_path, self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
pragmas={
"foreign_keys": 1,
"secure_delete": 1,
}
) )
@use_database @use_database
def event_in_store(self, event_id, room_id): def event_in_store(self, event_id, room_id):
user, _ = StoreUser.get_or_create(user_id=self.user) user, _ = StoreUser.get_or_create(user_id=self.user)
query = Event.select().join(UserMessages).where( query = (
(Event.room_id == room_id) & Event.select()
(Event.event_id == event_id) & .join(UserMessages)
(UserMessages.user == user) .where(
).execute() (Event.room_id == room_id)
& (Event.event_id == event_id)
& (UserMessages.user == user)
)
.execute()
)
for _ in query: for _ in query:
return True return True
@ -137,32 +129,29 @@ class MessageStore:
user, _ = StoreUser.get_or_create(user_id=self.user) user, _ = StoreUser.get_or_create(user_id=self.user)
profile_id, _ = Profile.get_or_create( profile_id, _ = Profile.get_or_create(
user_id=event.sender, user_id=event.sender, display_name=display_name, avatar_url=avatar_url
display_name=display_name,
avatar_url=avatar_url
) )
event_source = event.source event_source = event.source
event_source["room_id"] = room_id event_source["room_id"] = room_id
event_id = Event.insert( event_id = (
Event.insert(
event_id=event.event_id, event_id=event.event_id,
sender=event.sender, sender=event.sender,
date=datetime.datetime.fromtimestamp( date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
event.server_timestamp / 1000
),
room_id=room_id, room_id=room_id,
source=event_source, source=event_source,
profile=profile_id profile=profile_id,
).on_conflict_ignore().execute() )
.on_conflict_ignore()
.execute()
)
if event_id <= 0: if event_id <= 0:
return None return None
_, created = UserMessages.get_or_create( _, created = UserMessages.get_or_create(user=user, event=event_id)
user=user,
event=event_id,
)
if created: if created:
return event_id return event_id
@ -173,24 +162,36 @@ class MessageStore:
context = {} context = {}
if before > 0: if before > 0:
query = Event.select().join(UserMessages).where( query = (
(Event.date <= event.date) & Event.select()
(Event.room_id == event.room_id) & .join(UserMessages)
(Event.id != event.id) & .where(
(UserMessages.user == user) (Event.date <= event.date)
).order_by(Event.date.desc()).limit(before) & (Event.room_id == event.room_id)
& (Event.id != event.id)
& (UserMessages.user == user)
)
.order_by(Event.date.desc())
.limit(before)
)
context["events_before"] = [e.source for e in query] context["events_before"] = [e.source for e in query]
else: else:
context["events_before"] = [] context["events_before"] = []
if after > 0: if after > 0:
query = Event.select().join(UserMessages).where( query = (
(Event.date >= event.date) & Event.select()
(Event.room_id == event.room_id) & .join(UserMessages)
(Event.id != event.id) & .where(
(UserMessages.user == user) (Event.date >= event.date)
).order_by(Event.date).limit(after) & (Event.room_id == event.room_id)
& (Event.id != event.id)
& (UserMessages.user == user)
)
.order_by(Event.date)
.limit(after)
)
context["events_after"] = [e.source for e in query] context["events_after"] = [e.source for e in query]
else: else:
@ -205,7 +206,7 @@ class MessageStore:
include_profile=False, # type: bool include_profile=False, # type: bool
order_by_recent=False, # type: bool order_by_recent=False, # type: bool
before=0, # type: int before=0, # type: int
after=0 # type: int after=0, # type: int
): ):
# type: (...) -> Dict[Any, Any] # type: (...) -> Dict[Any, Any]
user, _ = StoreUser.get_or_create(user_id=self.user) user, _ = StoreUser.get_or_create(user_id=self.user)
@ -213,13 +214,13 @@ class MessageStore:
search_dict = {r[1]: r[0] for r in search_result} search_dict = {r[1]: r[0] for r in search_result}
columns = list(search_dict.keys()) columns = list(search_dict.keys())
result_dict = { result_dict = {"results": []}
"results": []
}
query = UserMessages.select().where( query = (
(UserMessages.user_id == user) & (UserMessages.event.in_(columns)) UserMessages.select()
).execute() .where((UserMessages.user_id == user) & (UserMessages.event.in_(columns)))
.execute()
)
for message in query: for message in query:
@ -228,7 +229,7 @@ class MessageStore:
event_dict = { event_dict = {
"rank": 1 if order_by_recent else search_dict[event.id], "rank": 1 if order_by_recent else search_dict[event.id],
"result": event.source, "result": event.source,
"context": {} "context": {},
} }
if include_profile: if include_profile:
@ -256,8 +257,17 @@ def sanitize_room_id(room_id):
class Searcher: class Searcher:
def __init__(self, index, body_field, name_field, topic_field, def __init__(
column_field, room_field, timestamp_field, searcher): self,
index,
body_field,
name_field,
topic_field,
column_field,
room_field,
timestamp_field,
searcher,
):
self._index = index self._index = index
self._searcher = searcher self._searcher = searcher
@ -268,8 +278,7 @@ class Searcher:
self.room_field = room_field self.room_field = room_field
self.timestamp_field = timestamp_field self.timestamp_field = timestamp_field
def search(self, search_term, room=None, max_results=10, def search(self, search_term, room=None, max_results=10, order_by_recent=False):
order_by_recent=False):
# type (str, str, int, bool) -> List[int, int] # type (str, str, int, bool) -> List[int, int]
"""Search for events in the index. """Search for events in the index.
@ -277,21 +286,13 @@ class Searcher:
""" """
queryparser = tantivy.QueryParser.for_index( queryparser = tantivy.QueryParser.for_index(
self._index, self._index,
[ [self.body_field, self.name_field, self.topic_field, self.room_field],
self.body_field,
self.name_field,
self.topic_field,
self.room_field
]
) )
# This currently supports only a single room since the query parser # This currently supports only a single room since the query parser
# doesn't seem to work with multiple room fields here. # doesn't seem to work with multiple room fields here.
if room: if room:
query_string = "{} AND room:{}".format( query_string = "{} AND room:{}".format(search_term, sanitize_room_id(room))
search_term,
sanitize_room_id(room)
)
else: else:
query_string = search_term query_string = search_term
@ -301,8 +302,9 @@ class Searcher:
raise InvalidQueryError(f"Invalid search term: {search_term}") raise InvalidQueryError(f"Invalid search term: {search_term}")
if order_by_recent: if order_by_recent:
collector = tantivy.TopDocs(max_results, collector = tantivy.TopDocs(
order_by_field=self.timestamp_field) max_results, order_by_field=self.timestamp_field
)
else: else:
collector = tantivy.TopDocs(max_results) collector = tantivy.TopDocs(max_results)
@ -329,16 +331,11 @@ class Index:
self.timestamp_field = schema_builder.add_unsigned_field( self.timestamp_field = schema_builder.add_unsigned_field(
"server_timestamp", fast="single" "server_timestamp", fast="single"
) )
self.date_field = schema_builder.add_date_field( self.date_field = schema_builder.add_date_field("message_date")
"message_date"
)
self.room_field = schema_builder.add_facet_field("room") self.room_field = schema_builder.add_facet_field("room")
self.column_field = schema_builder.add_unsigned_field( self.column_field = schema_builder.add_unsigned_field(
"database_column", "database_column", indexed=True, stored=True, fast="single"
indexed=True,
stored=True,
fast="single"
) )
schema = schema_builder.build() schema = schema_builder.build()
@ -359,7 +356,7 @@ class Index:
doc.add_facet(self.room_field, room_facet) doc.add_facet(self.room_field, room_facet)
doc.add_date( doc.add_date(
self.date_field, self.date_field,
datetime.datetime.fromtimestamp(event.server_timestamp / 1000) datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
) )
doc.add_unsigned(self.timestamp_field, event.server_timestamp) doc.add_unsigned(self.timestamp_field, event.server_timestamp)
@ -389,7 +386,7 @@ class Index:
self.column_field, self.column_field,
self.room_field, self.room_field,
self.timestamp_field, self.timestamp_field,
self.reader.searcher() self.reader.searcher(),
) )
@ -430,10 +427,7 @@ class IndexStore:
with store.database.bind_ctx(store.models): with store.database.bind_ctx(store.models):
with store.database.atomic(): with store.database.atomic():
for item in event_queue: for item in event_queue:
column_id = store.save_event( column_id = store.save_event(item.event, item.room_id)
item.event,
item.room_id,
)
if column_id: if column_id:
index.add_event(column_id, item.event, item.room_id) index.add_event(column_id, item.event, item.room_id)
@ -451,10 +445,7 @@ class IndexStore:
async with self.write_lock: async with self.write_lock:
write_func = partial( write_func = partial(
IndexStore.write_events, IndexStore.write_events, self.store, self.index, event_queue
self.store,
self.index,
event_queue
) )
await loop.run_in_executor(None, write_func) await loop.run_in_executor(None, write_func)
@ -469,7 +460,7 @@ class IndexStore:
order_by_recent=False, # type: bool order_by_recent=False, # type: bool
include_profile=False, # type: bool include_profile=False, # type: bool
before_limit=0, # type: int before_limit=0, # type: int
after_limit=0 # type: int after_limit=0, # type: int
): ):
# type: (...) -> Dict[Any, Any] # type: (...) -> Dict[Any, Any]
"""Search the indexstore for an event.""" """Search the indexstore for an event."""
@ -480,9 +471,13 @@ class IndexStore:
# the number of CPUs and the semaphore has the same counter value. # the number of CPUs and the semaphore has the same counter value.
async with self.read_semaphore: async with self.read_semaphore:
searcher = self.index.searcher() searcher = self.index.searcher()
search_func = partial(searcher.search, search_term, room=room, search_func = partial(
searcher.search,
search_term,
room=room,
max_results=max_results, max_results=max_results,
order_by_recent=order_by_recent) order_by_recent=order_by_recent,
)
result = await loop.run_in_executor(None, search_func) result = await loop.run_in_executor(None, search_func)
@ -492,7 +487,7 @@ class IndexStore:
include_profile, include_profile,
order_by_recent, order_by_recent,
before_limit, before_limit,
after_limit after_limit,
) )
search_result = await loop.run_in_executor(None, load_event_func) search_result = await loop.run_in_executor(None, load_event_func)

View File

@ -53,40 +53,39 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
send_queue=send_queue, send_queue=send_queue,
recv_queue=recv_queue, recv_queue=recv_queue,
proxy=server_conf.proxy.geturl() if server_conf.proxy else None, proxy=server_conf.proxy.geturl() if server_conf.proxy else None,
ssl=None if server_conf.ssl is True else False ssl=None if server_conf.ssl is True else False,
) )
app = web.Application() app = web.Application()
app.add_routes([ app.add_routes(
[
web.post("/_matrix/client/r0/login", proxy.login), web.post("/_matrix/client/r0/login", proxy.login),
web.get("/_matrix/client/r0/sync", proxy.sync), web.get("/_matrix/client/r0/sync", proxy.sync),
web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages), web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
web.put( web.put(
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}", r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message proxy.send_message,
), ),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter), web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
web.post("/_matrix/client/r0/search", proxy.search), web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts), web.options("/_matrix/client/r0/search", proxy.search_opts),
]) ]
)
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
app.on_shutdown.append(proxy.shutdown) app.on_shutdown.append(proxy.shutdown)
runner = web.AppRunner(app) runner = web.AppRunner(app)
await runner.setup() await runner.setup()
site = web.TCPSite( site = web.TCPSite(runner, str(server_conf.listen_address), server_conf.listen_port)
runner,
str(server_conf.listen_address),
server_conf.listen_port
)
return proxy, runner, site return proxy, runner, site
async def message_router(receive_queue, send_queue, proxies): async def message_router(receive_queue, send_queue, proxies):
"""Find the recipient of a message and forward it to the right proxy.""" """Find the recipient of a message and forward it to the right proxy."""
def find_proxy_by_user(user): def find_proxy_by_user(user):
# type: (str) -> Optional[ProxyDaemon] # type: (str) -> Optional[ProxyDaemon]
for proxy in proxies: for proxy in proxies:
@ -109,35 +108,28 @@ async def message_router(receive_queue, send_queue, proxies):
msg = f"No pan client found for {message.pan_user}." msg = f"No pan client found for {message.pan_user}."
logger.warn(msg) logger.warn(msg)
await send_info( await send_info(
message.message_id, message.message_id, message.pan_user, "m.unknown_client", msg
message.pan_user,
"m.unknown_client",
msg
) )
await proxy.receive_message(message) await proxy.receive_message(message)
@click.command( @click.command(
help=("pantalaimon is a reverse proxy for matrix homeservers that " help=(
"pantalaimon is a reverse proxy for matrix homeservers that "
"transparently encrypts and decrypts messages for clients that " "transparently encrypts and decrypts messages for clients that "
"connect to pantalaimon.") "connect to pantalaimon."
)
) )
@click.version_option(version="0.1", prog_name="pantalaimon") @click.version_option(version="0.1", prog_name="pantalaimon")
@click.option("--log-level", type=click.Choice([ @click.option(
"error", "--log-level",
"warning", type=click.Choice(["error", "warning", "info", "debug"]),
"info", default=None,
"debug" )
]), default=None)
@click.option("-c", "--config", type=click.Path(exists=True)) @click.option("-c", "--config", type=click.Path(exists=True))
@click.pass_context @click.pass_context
def main( def main(context, log_level, config):
context,
log_level,
config
):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
conf_dir = user_config_dir("pantalaimon", "") conf_dir = user_config_dir("pantalaimon", "")
@ -172,12 +164,7 @@ def main(
for server_conf in pan_conf.servers.values(): for server_conf in pan_conf.servers.values():
proxy, runner, site = loop.run_until_complete( proxy, runner, site = loop.run_until_complete(
init( init(data_dir, server_conf, pan_queue.async_q, ui_queue.async_q)
data_dir,
server_conf,
pan_queue.async_q,
ui_queue.async_q,
)
) )
servers.append((proxy, runner, site)) servers.append((proxy, runner, site))
proxies.append(proxy) proxies.append(proxy)
@ -185,14 +172,12 @@ def main(
except keyring.errors.KeyringError as e: except keyring.errors.KeyringError as e:
context.fail(f"Error initializing keyring: {e}") context.fail(f"Error initializing keyring: {e}")
glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir, glib_thread = GlibT(
pan_conf.servers.values(), pan_conf) pan_queue.sync_q, ui_queue.sync_q, data_dir, pan_conf.servers.values(), pan_conf
glib_fut = loop.run_in_executor(
None,
glib_thread.run
) )
glib_fut = loop.run_in_executor(None, glib_thread.run)
async def wait_for_glib(glib_thread, fut): async def wait_for_glib(glib_thread, fut):
glib_thread.stop() glib_thread.stop()
await fut await fut
@ -211,8 +196,10 @@ def main(
try: try:
for proxy, _, site in servers: for proxy, _, site in servers:
click.echo(f"======== Starting daemon for homeserver " click.echo(
f"{proxy.name} on {site.name} ========") f"======== Starting daemon for homeserver "
f"{proxy.name} on {site.name} ========"
)
loop.run_until_complete(site.start()) loop.run_until_complete(site.start())
click.echo("(Press CTRL+C to quit)") click.echo("(Press CTRL+C to quit)")

View File

@ -43,15 +43,12 @@ class PanctlArgParse(argparse.ArgumentParser):
pass pass
def error(self, message): def error(self, message):
message = ( message = f"Error: {message} " f"(see help)"
f"Error: {message} "
f"(see help)"
)
print(message) print(message)
raise ParseError raise ParseError
class PanctlParser(): class PanctlParser:
def __init__(self, commands): def __init__(self, commands):
self.commands = commands self.commands = commands
self.parser = PanctlArgParse() self.parser = PanctlArgParse()
@ -194,19 +191,13 @@ class PanCompleter(Completer):
return "" return ""
def complete_key_file_cmds( def complete_key_file_cmds(
self, self, document, complete_event, command, last_word, words
document,
complete_event,
command,
last_word,
words
): ):
if len(words) == 2: if len(words) == 2:
return self.complete_pan_users(last_word) return self.complete_pan_users(last_word)
elif len(words) == 3: elif len(words) == 3:
return self.path_completer.get_completions( return self.path_completer.get_completions(
Document(last_word), Document(last_word), complete_event
complete_event
) )
return "" return ""
@ -264,16 +255,9 @@ class PanCompleter(Completer):
]: ]:
return self.complete_verification(command, last_word, words) return self.complete_verification(command, last_word, words)
elif command in [ elif command in ["export-keys", "import-keys"]:
"export-keys",
"import-keys",
]:
return self.complete_key_file_cmds( return self.complete_key_file_cmds(
document, document, complete_event, command, last_word, words
complete_event,
command,
last_word,
words
) )
elif command in ["send-anyways", "cancel-sending"]: elif command in ["send-anyways", "cancel-sending"]:
@ -300,7 +284,7 @@ def grouper(iterable, n, fillvalue=None):
def partition_key(key): def partition_key(key):
groups = grouper(key, 4, " ") groups = grouper(key, 4, " ")
return ' '.join(''.join(g) for g in groups) return " ".join("".join(g) for g in groups)
def get_color(string): def get_color(string):
@ -330,37 +314,59 @@ class PanCtl:
command_help = { command_help = {
"help": "Display help about commands.", "help": "Display help about commands.",
"list-servers": ("List the configured homeservers and pan users on " "list-servers": (
"each homeserver."), "List the configured homeservers and pan users on " "each homeserver."
"list-devices": ("List the devices of a user that are known to the " ),
"pan-user."), "list-devices": (
"start-verification": ("Start an interactive key verification between " "List the devices of a user that are known to the " "pan-user."
"the given pan-user and user."), ),
"accept-verification": ("Accept an interactive key verification that " "start-verification": (
"Start an interactive key verification between "
"the given pan-user and user."
),
"accept-verification": (
"Accept an interactive key verification that "
"the given user has started with our given " "the given user has started with our given "
"pan-user."), "pan-user."
"cancel-verification": ("Cancel an interactive key verification " ),
"between the given pan-user and user."), "cancel-verification": (
"confirm-verification": ("Confirm that the short authentication " "Cancel an interactive key verification "
"between the given pan-user and user."
),
"confirm-verification": (
"Confirm that the short authentication "
"string of the interactive key verification " "string of the interactive key verification "
"with the given pan-user and user is " "with the given pan-user and user is "
"matching."), "matching."
),
"verify-device": ("Manually mark the given device as verified."), "verify-device": ("Manually mark the given device as verified."),
"unverify-device": ("Mark a previously verified device of the given " "unverify-device": (
"user as unverified."), "Mark a previously verified device of the given " "user as unverified."
"blacklist-device": ("Manually mark the given device of the given " ),
"user as blacklisted."), "blacklist-device": (
"unblacklist-device": ("Mark a previously blacklisted device of the " "Manually mark the given device of the given " "user as blacklisted."
"given user as unblacklisted."), ),
"send-anyways": ("Send a room message despite having unverified " "unblacklist-device": (
"Mark a previously blacklisted device of the "
"given user as unblacklisted."
),
"send-anyways": (
"Send a room message despite having unverified "
"devices in the room and mark the devices as " "devices in the room and mark the devices as "
"ignored."), "ignored."
"cancel-sending": ("Cancel the send of a room message in a room that " ),
"contains unverified devices"), "cancel-sending": (
"import-keys": ("Import end-to-end encryption keys from the given " "Cancel the send of a room message in a room that "
"file for the given pan-user."), "contains unverified devices"
"export-keys": ("Export end-to-end encryption keys to the given file " ),
"for the given pan-user."), "import-keys": (
"Import end-to-end encryption keys from the given "
"file for the given pan-user."
),
"export-keys": (
"Export end-to-end encryption keys to the given file "
"for the given pan-user."
),
} }
commands = [ commands = [
@ -404,10 +410,12 @@ class PanCtl:
def unverified_devices(self, pan_user, room_id, display_name): def unverified_devices(self, pan_user, room_id, display_name):
self.completer.rooms[pan_user].add(room_id) self.completer.rooms[pan_user].add(room_id)
print(f"Error sending message for user {pan_user}, " print(
f"Error sending message for user {pan_user}, "
f"there are unverified devices in the room {display_name} " f"there are unverified devices in the room {display_name} "
f"({room_id}).\nUse the send-anyways or cancel-sending commands " f"({room_id}).\nUse the send-anyways or cancel-sending commands "
f"to ignore the devices or cancel the sending.") f"to ignore the devices or cancel the sending."
)
def show_response(self, response_id, pan_user, message): def show_response(self, response_id, pan_user, message):
if response_id not in self.own_message_ids: if response_id not in self.own_message_ids:
@ -418,14 +426,18 @@ class PanCtl:
print(message["message"]) print(message["message"])
def sas_done(self, pan_user, user_id, device_id, _): def sas_done(self, pan_user, user_id, device_id, _):
print(f"Device {device_id} of user {user_id}" print(
f" succesfully verified for pan user {pan_user}.") f"Device {device_id} of user {user_id}"
f" succesfully verified for pan user {pan_user}."
)
def show_sas_invite(self, pan_user, user_id, device_id, _): def show_sas_invite(self, pan_user, user_id, device_id, _):
print(f"{user_id} has started an interactive device " print(
f"{user_id} has started an interactive device "
f"verification for his device {device_id} with pan user " f"verification for his device {device_id} with pan user "
f"{pan_user}\n" f"{pan_user}\n"
f"Accept the invitation with the accept-verification command.") f"Accept the invitation with the accept-verification command."
)
# The emoji printing logic was taken from weechat-matrix and was written by # The emoji printing logic was taken from weechat-matrix and was written by
# dkasak. # dkasak.
@ -443,33 +455,26 @@ class PanCtl:
# that they are rendered with coloured glyphs. For these, we # that they are rendered with coloured glyphs. For these, we
# need to add an extra space after them so that they are # need to add an extra space after them so that they are
# rendered properly in weechat. # rendered properly in weechat.
variation_selector_emojis = [ variation_selector_emojis = ["☁️", "❤️", "☂️", "✏️", "✂️", "☎️", "✈️"]
'☁️',
'❤️',
'☂️',
'✏️',
'✂️',
'☎️',
'✈️'
]
if emoji in variation_selector_emojis: if emoji in variation_selector_emojis:
emoji += " " emoji += " "
# This is a trick to account for the fact that emojis are wider # This is a trick to account for the fact that emojis are wider
# than other monospace characters. # than other monospace characters.
placeholder = '.' * emoji_width placeholder = "." * emoji_width
return placeholder.center(width).replace(placeholder, emoji) return placeholder.center(width).replace(placeholder, emoji)
emoji_str = u"".join(center_emoji(e, centered_width) emoji_str = "".join(center_emoji(e, centered_width) for e in emojis)
for e in emojis) desc = "".join(d.center(centered_width) for d in descriptions)
desc = u"".join(d.center(centered_width) for d in descriptions) short_string = "\n".join([emoji_str, desc])
short_string = u"\n".join([emoji_str, desc])
print(f"Short authentication string for pan " print(
f"Short authentication string for pan "
f"user {pan_user} from {user_id} via " f"user {pan_user} from {user_id} via "
f"{device_id}:\n{short_string}") f"{device_id}:\n{short_string}"
)
def list_servers(self): def list_servers(self):
"""List the daemons users.""" """List the daemons users."""
@ -480,9 +485,7 @@ class PanCtl:
for server, server_users in servers.items(): for server, server_users in servers.items():
server_c = get_color(server) server_c = get_color(server)
print_formatted_text(HTML( print_formatted_text(HTML(f" - Name: <{server_c}>{server}</{server_c}>"))
f" - Name: <{server_c}>{server}</{server_c}>"
))
user_list = [] user_list = []
@ -490,8 +493,10 @@ class PanCtl:
user_c = get_color(user) user_c = get_color(user)
device_c = get_color(device) device_c = get_color(device)
user_list.append(f" - <{user_c}>{user}</{user_c}> " user_list.append(
f"<{device_c}>{device}</{device_c}>") f" - <{user_c}>{user}</{user_c}> "
f"<{device_c}>{device}</{device_c}>"
)
if user_list: if user_list:
print(f" - Pan users:") print(f" - Pan users:")
@ -501,9 +506,7 @@ class PanCtl:
def list_devices(self, args): def list_devices(self, args):
devices = self.devices.ListUserDevices(args.pan_user, args.user_id) devices = self.devices.ListUserDevices(args.pan_user, args.user_id)
print_formatted_text( print_formatted_text(HTML(f"Devices for user <b>{args.user_id}</b>:"))
HTML(f"Devices for user <b>{args.user_id}</b>:")
)
for device in devices: for device in devices:
if device["trust_state"] == "verified": if device["trust_state"] == "verified":
@ -517,7 +520,8 @@ class PanCtl:
key = partition_key(device["ed25519"]) key = partition_key(device["ed25519"])
color = get_color(device["device_id"]) color = get_color(device["device_id"])
print_formatted_text(HTML( print_formatted_text(
HTML(
f" - Display name: " f" - Display name: "
f"{device['device_display_name']}\n" f"{device['device_display_name']}\n"
f" - Device id: " f" - Device id: "
@ -526,7 +530,8 @@ class PanCtl:
f"<ansiyellow>{key}</ansiyellow>\n" f"<ansiyellow>{key}</ansiyellow>\n"
f" - Trust state: " f" - Trust state: "
f"{trust_state}" f"{trust_state}"
)) )
)
async def loop(self): async def loop(self):
"""Event loop for panctl.""" """Event loop for panctl."""
@ -559,105 +564,83 @@ class PanCtl:
elif command == "import-keys": elif command == "import-keys":
self.own_message_ids.append( self.own_message_ids.append(
self.ctl.ImportKeys( self.ctl.ImportKeys(args.pan_user, args.path, args.passphrase)
args.pan_user, )
args.path,
args.passphrase
))
elif command == "export-keys": elif command == "export-keys":
self.own_message_ids.append( self.own_message_ids.append(
self.ctl.ExportKeys( self.ctl.ExportKeys(args.pan_user, args.path, args.passphrase)
args.pan_user, )
args.path,
args.passphrase
))
elif command == "send-anyways": elif command == "send-anyways":
self.own_message_ids.append( self.own_message_ids.append(
self.ctl.SendAnyways( self.ctl.SendAnyways(args.pan_user, args.room_id)
args.pan_user, )
args.room_id,
))
elif command == "cancel-sending": elif command == "cancel-sending":
self.own_message_ids.append( self.own_message_ids.append(
self.ctl.CancelSending( self.ctl.CancelSending(args.pan_user, args.room_id)
args.pan_user, )
args.room_id,
))
elif command == "list-devices": elif command == "list-devices":
self.list_devices(args) self.list_devices(args)
elif command == "verify-device": elif command == "verify-device":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.Verify( self.devices.Verify(args.pan_user, args.user_id, args.device_id)
args.pan_user, )
args.user_id,
args.device_id
))
elif command == "unverify-device": elif command == "unverify-device":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.Unverify( self.devices.Unverify(args.pan_user, args.user_id, args.device_id)
args.pan_user, )
args.user_id,
args.device_id
))
elif command == "blacklist-device": elif command == "blacklist-device":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.Blacklist( self.devices.Blacklist(args.pan_user, args.user_id, args.device_id)
args.pan_user, )
args.user_id,
args.device_id
))
elif command == "unblacklist-device": elif command == "unblacklist-device":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.Unblacklist( self.devices.Unblacklist(
args.pan_user, args.pan_user, args.user_id, args.device_id
args.user_id, )
args.device_id )
))
elif command == "start-verification": elif command == "start-verification":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.StartKeyVerification( self.devices.StartKeyVerification(
args.pan_user, args.pan_user, args.user_id, args.device_id
args.user_id, )
args.device_id )
))
elif command == "cancel-verification": elif command == "cancel-verification":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.CancelKeyVerification( self.devices.CancelKeyVerification(
args.pan_user, args.pan_user, args.user_id, args.device_id
args.user_id, )
args.device_id )
))
elif command == "accept-verification": elif command == "accept-verification":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.AcceptKeyVerification( self.devices.AcceptKeyVerification(
args.pan_user, args.pan_user, args.user_id, args.device_id
args.user_id, )
args.device_id )
))
elif command == "confirm-verification": elif command == "confirm-verification":
self.own_message_ids.append( self.own_message_ids.append(
self.devices.ConfirmKeyVerification( self.devices.ConfirmKeyVerification(
args.pan_user, args.pan_user, args.user_id, args.device_id
args.user_id, )
args.device_id )
))
@click.command( @click.command(
help=("panctl is a small interactive repl to introspect and control" help=(
"the pantalaimon daemon.") "panctl is a small interactive repl to introspect and control"
"the pantalaimon daemon."
)
) )
@click.version_option(version="0.1", prog_name="panctl") @click.version_option(version="0.1", prog_name="panctl")
def main(): def main():
@ -670,10 +653,7 @@ def main():
print(f"Error, {e}") print(f"Error, {e}")
sys.exit(-1) sys.exit(-1)
fut = loop.run_in_executor( fut = loop.run_in_executor(None, glib_loop.run)
None,
glib_loop.run
)
try: try:
loop.run_until_complete(panctl.loop()) loop.run_until_complete(panctl.loop())
@ -684,5 +664,5 @@ def main():
loop.run_until_complete(fut) loop.run_until_complete(fut)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View File

@ -18,10 +18,8 @@ from collections import defaultdict
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import attr import attr
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState, from nio.store import Accounts, DeviceKeys, DeviceTrustState, TrustState, use_database
use_database) from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField)
@attr.s @attr.s
@ -41,10 +39,7 @@ class DictField(TextField):
class AccessTokens(Model): class AccessTokens(Model):
token = TextField() token = TextField()
account = ForeignKeyField( account = ForeignKeyField(
model=Accounts, model=Accounts, primary_key=True, backref="access_token", on_delete="CASCADE"
primary_key=True,
backref="access_token",
on_delete="CASCADE"
) )
@ -58,10 +53,7 @@ class Servers(Model):
class ServerUsers(Model): class ServerUsers(Model):
user_id = TextField() user_id = TextField()
server = ForeignKeyField( server = ForeignKeyField(
model=Servers, model=Servers, column_name="server_id", backref="users", on_delete="CASCADE"
column_name="server_id",
backref="users",
on_delete="CASCADE"
) )
class Meta: class Meta:
@ -70,9 +62,7 @@ class ServerUsers(Model):
class PanSyncTokens(Model): class PanSyncTokens(Model):
token = TextField() token = TextField()
user = ForeignKeyField( user = ForeignKeyField(model=ServerUsers, column_name="user_id")
model=ServerUsers,
column_name="user_id")
class Meta: class Meta:
constraints = [SQL("UNIQUE(user_id)")] constraints = [SQL("UNIQUE(user_id)")]
@ -80,9 +70,8 @@ class PanSyncTokens(Model):
class PanFetcherTasks(Model): class PanFetcherTasks(Model):
user = ForeignKeyField( user = ForeignKeyField(
model=ServerUsers, model=ServerUsers, column_name="user_id", backref="fetcher_tasks"
column_name="user_id", )
backref="fetcher_tasks")
room_id = TextField() room_id = TextField()
token = TextField() token = TextField()
@ -110,13 +99,12 @@ class PanStore:
DeviceKeys, DeviceKeys,
DeviceTrustState, DeviceTrustState,
PanSyncTokens, PanSyncTokens,
PanFetcherTasks PanFetcherTasks,
] ]
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.database_path = os.path.join( self.database_path = os.path.join(
os.path.abspath(self.store_path), os.path.abspath(self.store_path), self.database_name
self.database_name
) )
self.database = self._create_database() self.database = self._create_database()
@ -127,19 +115,14 @@ class PanStore:
def _create_database(self): def _create_database(self):
return SqliteDatabase( return SqliteDatabase(
self.database_path, self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
pragmas={
"foreign_keys": 1,
"secure_delete": 1,
}
) )
@use_database @use_database
def _get_account(self, user_id, device_id): def _get_account(self, user_id, device_id):
try: try:
return Accounts.get( return Accounts.get(
Accounts.user_id == user_id, Accounts.user_id == user_id, Accounts.device_id == device_id
Accounts.device_id == device_id,
) )
except DoesNotExist: except DoesNotExist:
return None return None
@ -150,9 +133,7 @@ class PanStore:
user = ServerUsers.get(server=server, user_id=pan_user) user = ServerUsers.get(server=server, user_id=pan_user)
PanFetcherTasks.replace( PanFetcherTasks.replace(
user=user, user=user, room_id=task.room_id, token=task.token
room_id=task.room_id,
token=task.token
).execute() ).execute()
def load_fetcher_tasks(self, server, pan_user): def load_fetcher_tasks(self, server, pan_user):
@ -173,7 +154,7 @@ class PanStore:
PanFetcherTasks.delete().where( PanFetcherTasks.delete().where(
PanFetcherTasks.user == user, PanFetcherTasks.user == user,
PanFetcherTasks.room_id == task.room_id, PanFetcherTasks.room_id == task.room_id,
PanFetcherTasks.token == task.token PanFetcherTasks.token == task.token,
).execute() ).execute()
@use_database @use_database
@ -208,18 +189,14 @@ class PanStore:
server, _ = Servers.get_or_create(name=server_name) server, _ = Servers.get_or_create(name=server_name)
ServerUsers.insert( ServerUsers.insert(
user_id=user_id, user_id=user_id, server=server
server=server
).on_conflict_ignore().execute() ).on_conflict_ignore().execute()
@use_database @use_database
def load_all_users(self): def load_all_users(self):
users = [] users = []
query = Accounts.select( query = Accounts.select(Accounts.user_id, Accounts.device_id)
Accounts.user_id,
Accounts.device_id,
)
for account in query: for account in query:
users.append((account.user_id, account.device_id)) users.append((account.user_id, account.device_id))
@ -241,10 +218,9 @@ class PanStore:
for u in server.users: for u in server.users:
server_users.append(u.user_id) server_users.append(u.user_id)
query = Accounts.select( query = Accounts.select(Accounts.user_id, Accounts.device_id).where(
Accounts.user_id, Accounts.user_id.in_(server_users)
Accounts.device_id, )
).where(Accounts.user_id.in_(server_users))
for account in query: for account in query:
users.append((account.user_id, account.device_id)) users.append((account.user_id, account.device_id))
@ -256,10 +232,7 @@ class PanStore:
account = self._get_account(user_id, device_id) account = self._get_account(user_id, device_id)
assert account assert account
AccessTokens.replace( AccessTokens.replace(account=account, token=access_token).execute()
account=account,
token=access_token
).execute()
@use_database @use_database
def load_access_token(self, user_id, device_id): def load_access_token(self, user_id, device_id):
@ -302,7 +275,7 @@ class PanStore:
"ed25519": keys["ed25519"], "ed25519": keys["ed25519"],
"curve25519": keys["curve25519"], "curve25519": keys["curve25519"],
"trust_state": trust_state.name, "trust_state": trust_state.name,
"device_display_name": d.display_name "device_display_name": d.display_name,
} }
store[account.user_id] = device_store store[account.user_id] = device_store

View File

@ -24,20 +24,27 @@ from pydbus.generic import signal
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.store import PanStore from pantalaimon.store import PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage, from pantalaimon.thread_messages import (
AcceptSasMessage,
CancelSasMessage,
CancelSendingMessage, CancelSendingMessage,
ConfirmSasMessage, DaemonResponse, ConfirmSasMessage,
DaemonResponse,
DeviceBlacklistMessage, DeviceBlacklistMessage,
DeviceUnblacklistMessage, DeviceUnblacklistMessage,
DeviceUnverifyMessage, DeviceUnverifyMessage,
DeviceVerifyMessage, DeviceVerifyMessage,
ExportKeysMessage, ImportKeysMessage, ExportKeysMessage,
InviteSasSignal, SasDoneSignal, ImportKeysMessage,
SendAnywaysMessage, ShowSasSignal, InviteSasSignal,
SasDoneSignal,
SendAnywaysMessage,
ShowSasSignal,
StartSasMessage, StartSasMessage,
UnverifiedDevicesSignal, UnverifiedDevicesSignal,
UpdateDevicesMessage, UpdateDevicesMessage,
UpdateUsersMessage) UpdateUsersMessage,
)
class IdCounter: class IdCounter:
@ -126,40 +133,22 @@ class Control:
return self.users return self.users
def ExportKeys(self, pan_user, filepath, passphrase): def ExportKeys(self, pan_user, filepath, passphrase):
message = ExportKeysMessage( message = ExportKeysMessage(self.message_id, pan_user, filepath, passphrase)
self.message_id,
pan_user,
filepath,
passphrase
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def ImportKeys(self, pan_user, filepath, passphrase): def ImportKeys(self, pan_user, filepath, passphrase):
message = ImportKeysMessage( message = ImportKeysMessage(self.message_id, pan_user, filepath, passphrase)
self.message_id,
pan_user,
filepath,
passphrase
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def SendAnyways(self, pan_user, room_id): def SendAnyways(self, pan_user, room_id):
message = SendAnywaysMessage( message = SendAnywaysMessage(self.message_id, pan_user, room_id)
self.message_id,
pan_user,
room_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def CancelSending(self, pan_user, room_id): def CancelSending(self, pan_user, room_id):
message = CancelSendingMessage( message = CancelSendingMessage(self.message_id, pan_user, room_id)
self.message_id,
pan_user,
room_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
@ -293,8 +282,9 @@ class Devices:
return [] return []
device_list = [ device_list = [
device for device_list in device_store.values() for device in device
device_list.values() for device_list in device_store.values()
for device in device_list.values()
] ]
return device_list return device_list
@ -313,82 +303,44 @@ class Devices:
return device_list.values() return device_list.values()
def Verify(self, pan_user, user_id, device_id): def Verify(self, pan_user, user_id, device_id):
message = DeviceVerifyMessage( message = DeviceVerifyMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def Unverify(self, pan_user, user_id, device_id): def Unverify(self, pan_user, user_id, device_id):
message = DeviceUnverifyMessage( message = DeviceUnverifyMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def Blacklist(self, pan_user, user_id, device_id): def Blacklist(self, pan_user, user_id, device_id):
message = DeviceBlacklistMessage( message = DeviceBlacklistMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def Unblacklist(self, pan_user, user_id, device_id): def Unblacklist(self, pan_user, user_id, device_id):
message = DeviceUnblacklistMessage( message = DeviceUnblacklistMessage(
self.message_id, self.message_id, pan_user, user_id, device_id
pan_user,
user_id,
device_id
) )
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def StartKeyVerification(self, pan_user, user_id, device_id): def StartKeyVerification(self, pan_user, user_id, device_id):
message = StartSasMessage( message = StartSasMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def CancelKeyVerification(self, pan_user, user_id, device_id): def CancelKeyVerification(self, pan_user, user_id, device_id):
message = CancelSasMessage( message = CancelSasMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def ConfirmKeyVerification(self, pan_user, user_id, device_id): def ConfirmKeyVerification(self, pan_user, user_id, device_id):
message = ConfirmSasMessage( message = ConfirmSasMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def AcceptKeyVerification(self, pan_user, user_id, device_id): def AcceptKeyVerification(self, pan_user, user_id, device_id):
message = AcceptSasMessage( message = AcceptSasMessage(self.message_id, pan_user, user_id, device_id)
self.message_id,
pan_user,
user_id,
device_id
)
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
@ -421,8 +373,9 @@ class GlibT:
id_counter = IdCounter() id_counter = IdCounter()
self.control_if = Control(self.send_queue, self.store, self.control_if = Control(
self.server_list, id_counter) self.send_queue, self.store, self.server_list, id_counter
)
self.device_if = Devices(self.send_queue, self.store, id_counter) self.device_if = Devices(self.send_queue, self.store, id_counter)
self.bus = SessionBus() self.bus = SessionBus()
@ -431,131 +384,95 @@ class GlibT:
def unverified_notification(self, message): def unverified_notification(self, message):
notificaton = notify2.Notification( notificaton = notify2.Notification(
"Unverified devices.", "Unverified devices.",
message=(f"There are unverified devices in the room " message=(
f"{message.room_display_name}.") f"There are unverified devices in the room "
f"{message.room_display_name}."
),
) )
notificaton.set_category("im") notificaton.set_category("im")
def send_cb(notification, action_key, user_data): def send_cb(notification, action_key, user_data):
message = user_data message = user_data
self.control_if.SendAnyways( self.control_if.SendAnyways(message.pan_user, message.room_id)
message.pan_user,
message.room_id
)
def cancel_cb(notification, action_key, user_data): def cancel_cb(notification, action_key, user_data):
message = user_data message = user_data
self.control_if.CancelSending( self.control_if.CancelSending(message.pan_user, message.room_id)
message.pan_user,
message.room_id
)
if "actions" in notify2.get_server_caps(): if "actions" in notify2.get_server_caps():
notificaton.add_action( notificaton.add_action("send", "Send anyways", send_cb, message)
"send", notificaton.add_action("cancel", "Cancel sending", cancel_cb, message)
"Send anyways",
send_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel sending",
cancel_cb,
message
)
notificaton.show() notificaton.show()
def sas_invite_notification(self, message): def sas_invite_notification(self, message):
notificaton = notify2.Notification( notificaton = notify2.Notification(
"Key verification invite", "Key verification invite",
message=(f"{message.user_id} via {message.device_id} has started " message=(
f"a key verification process.") f"{message.user_id} via {message.device_id} has started "
f"a key verification process."
),
) )
notificaton.set_category("im") notificaton.set_category("im")
def accept_cb(notification, action_key, user_data): def accept_cb(notification, action_key, user_data):
message = user_data message = user_data
self.device_if.AcceptKeyVerification( self.device_if.AcceptKeyVerification(
message.pan_user, message.pan_user, message.user_id, message.device_id
message.user_id,
message.device_id
) )
def cancel_cb(notification, action_key, user_data): def cancel_cb(notification, action_key, user_data):
message = user_data message = user_data
self.device_if.CancelKeyVerification( self.device_if.CancelKeyVerification(
message.pan_user, message.pan_user, message.user_id, message.device_id
message.user_id,
message.device_id,
) )
if "actions" in notify2.get_server_caps(): if "actions" in notify2.get_server_caps():
notificaton.add_action( notificaton.add_action("accept", "Accept", accept_cb, message)
"accept", notificaton.add_action("cancel", "Cancel", cancel_cb, message)
"Accept",
accept_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel",
cancel_cb,
message
)
notificaton.show() notificaton.show()
def sas_show_notification(self, message): def sas_show_notification(self, message):
emojis = [x[0] for x in message.emoji] emojis = [x[0] for x in message.emoji]
emoji_str = u" ".join(emojis) emoji_str = " ".join(emojis)
notificaton = notify2.Notification( notificaton = notify2.Notification(
"Short authentication string", "Short authentication string",
message=(f"Short authentication string for the key verification of" message=(
f"Short authentication string for the key verification of"
f" {message.user_id} via {message.device_id}:\n" f" {message.user_id} via {message.device_id}:\n"
f"{emoji_str}") f"{emoji_str}"
),
) )
notificaton.set_category("im") notificaton.set_category("im")
def confirm_cb(notification, action_key, user_data): def confirm_cb(notification, action_key, user_data):
message = user_data message = user_data
self.device_if.ConfirmKeyVerification( self.device_if.ConfirmKeyVerification(
message.pan_user, message.pan_user, message.user_id, message.device_id
message.user_id,
message.device_id
) )
def cancel_cb(notification, action_key, user_data): def cancel_cb(notification, action_key, user_data):
message = user_data message = user_data
self.device_if.CancelKeyVerification( self.device_if.CancelKeyVerification(
message.pan_user, message.pan_user, message.user_id, message.device_id
message.user_id,
message.device_id,
) )
if "actions" in notify2.get_server_caps(): if "actions" in notify2.get_server_caps():
notificaton.add_action( notificaton.add_action("confirm", "Confirm", confirm_cb, message)
"confirm", notificaton.add_action("cancel", "Cancel", cancel_cb, message)
"Confirm",
confirm_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel",
cancel_cb,
message
)
notificaton.show() notificaton.show()
def sas_done_notification(self, message): def sas_done_notification(self, message):
notificaton = notify2.Notification( notificaton = notify2.Notification(
"Device successfully verified.", "Device successfully verified.",
message=(f"Device {message.device_id} of user {message.user_id} " message=(
f"successfully verified.") f"Device {message.device_id} of user {message.user_id} "
f"successfully verified."
),
) )
notificaton.set_category("im") notificaton.set_category("im")
notificaton.show() notificaton.show()
@ -576,9 +493,7 @@ class GlibT:
elif isinstance(message, UnverifiedDevicesSignal): elif isinstance(message, UnverifiedDevicesSignal):
self.control_if.UnverifiedDevices( self.control_if.UnverifiedDevices(
message.pan_user, message.pan_user, message.room_id, message.room_display_name
message.room_id,
message.room_display_name
) )
if self.notifications: if self.notifications:
@ -589,7 +504,7 @@ class GlibT:
message.pan_user, message.pan_user,
message.user_id, message.user_id,
message.device_id, message.device_id,
message.transaction_id message.transaction_id,
) )
if self.notifications: if self.notifications:
@ -622,10 +537,7 @@ class GlibT:
self.control_if.Response( self.control_if.Response(
message.message_id, message.message_id,
message.pan_user, message.pan_user,
{ {"code": message.code, "message": message.message},
"code": message.code,
"message": message.message
}
) )
self.receive_queue.task_done() self.receive_queue.task_done()
@ -639,8 +551,10 @@ class GlibT:
notify2.init("pantalaimon", mainloop=self.loop) notify2.init("pantalaimon", mainloop=self.loop)
self.notifications = True self.notifications = True
except dbus.DBusException: except dbus.DBusException:
logger.error("Notifications are enabled but no notification " logger.error(
"server could be found, disabling notifications.") "Notifications are enabled but no notification "
"server could be found, disabling notifications."
)
self.notifications = False self.notifications = False
GLib.timeout_add(100, self.message_callback) GLib.timeout_add(100, self.message_callback)