pantalaimon: Format the source tree using black.

This commit is contained in:
Damir Jelić 2019-06-19 12:37:44 +02:00
parent 531d686d8f
commit c9ebfd71ec
11 changed files with 755 additions and 969 deletions

4
.flake8 Normal file
View File

@ -0,0 +1,4 @@
[flake8]
max-line-length = 80
select = C,E,F,W,B,B950
ignore = E501,W503

6
.isort.cfg Normal file
View File

@ -0,0 +1,6 @@
[settings]
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=88

View File

@ -1,5 +1,6 @@
test:
python3 -m pytest
python3 -m pytest --black pantalaimon
python3 -m pytest --flake8 pantalaimon
python3 -m pytest --isort
@ -14,3 +15,6 @@ run-local:
isort:
isort -y -p pantalaimon
format:
black pantalaimon/

View File

@ -20,21 +20,39 @@ from typing import Any, Dict, Optional
from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac,
KeyVerificationStart, LocalProtocolError, MegolmEvent,
RoomContextError, RoomEncryptedEvent, RoomEncryptedMedia,
RoomMessageMedia, RoomMessageText, RoomNameEvent,
RoomTopicEvent, SyncResponse)
from nio import (
AsyncClient,
ClientConfig,
EncryptionError,
KeysQueryResponse,
KeyVerificationEvent,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
MegolmEvent,
RoomContextError,
RoomEncryptedEvent,
RoomEncryptedMedia,
RoomMessageMedia,
RoomMessageText,
RoomNameEvent,
RoomTopicEvent,
SyncResponse,
)
from nio.crypto import Sas
from nio.store import SqliteStore
from pantalaimon.index import IndexStore
from pantalaimon.log import logger
from pantalaimon.store import FetchTask
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
SasDoneSignal, ShowSasSignal,
UpdateDevicesMessage)
from pantalaimon.thread_messages import (
DaemonResponse,
InviteSasSignal,
SasDoneSignal,
ShowSasSignal,
UpdateDevicesMessage,
)
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]
@ -51,7 +69,7 @@ SEARCH_TERMS_SCHEMA = {
"keys": {
"type": "array",
"items": {"type": "string", "enum": SEARCH_KEYS},
"default": SEARCH_KEYS
"default": SEARCH_KEYS,
},
"order_by": {"type": "string", "default": "rank"},
"include_state": {"type": "boolean", "default": False},
@ -59,15 +77,13 @@ SEARCH_TERMS_SCHEMA = {
"event_context": {"type": "object"},
"groupings": {"type": "object", "default": {}},
},
"required": ["search_term"]
},
}
"required": ["search_term"],
}
},
},
"required": ["room_events"]
"required": ["room_events"],
},
"required": [
"search_categories",
],
"required": ["search_categories"],
}
@ -79,9 +95,7 @@ def extend_with_default(validator_class):
if "default" in subschema:
instance.setdefault(prop, subschema["default"])
for error in validate_properties(
validator, properties, instance, schema
):
for error in validate_properties(validator, properties, instance, schema):
yield error
return validators.extend(validator_class, {"properties": set_defaults})
@ -111,22 +125,21 @@ class PanClient(AsyncClient):
"""A wrapper class around a nio AsyncClient extending its functionality."""
def __init__(
self,
server_name,
pan_store,
pan_conf,
homeserver,
queue=None,
user_id="",
device_id="",
store_path="",
config=None,
ssl=None,
proxy=None
self,
server_name,
pan_store,
pan_conf,
homeserver,
queue=None,
user_id="",
device_id="",
store_path="",
config=None,
ssl=None,
proxy=None,
):
config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
super().__init__(homeserver, user_id, device_id, store_path, config,
ssl, proxy)
super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)
index_dir = os.path.join(store_path, server_name, user_id)
@ -150,31 +163,24 @@ class PanClient(AsyncClient):
self.history_fetcher_task = None
self.history_fetch_queue = asyncio.Queue()
self.add_to_device_callback(
self.key_verification_cb,
KeyVerificationEvent
)
self.add_event_callback(
self.undecrypted_event_cb,
MegolmEvent
)
self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
self.add_event_callback(
self.store_message_cb,
(RoomMessageText, RoomMessageMedia, RoomEncryptedMedia,
RoomTopicEvent, RoomNameEvent)
(
RoomMessageText,
RoomMessageMedia,
RoomEncryptedMedia,
RoomTopicEvent,
RoomNameEvent,
),
)
self.key_verificatins_tasks = []
self.key_request_tasks = []
self.add_response_callback(
self.keys_query_cb,
KeysQueryResponse
)
self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
self.add_response_callback(
self.sync_tasks,
SyncResponse
)
self.add_response_callback(self.sync_tasks, SyncResponse)
def store_message_cb(self, room, event):
display_name = room.user_name(event.sender)
@ -192,9 +198,11 @@ class PanClient(AsyncClient):
"type": "m.room.message",
"content": {
"msgtype": "m.text",
"body": ("** Unable to decrypt: The sender's device has not "
"sent us the keys for this message. **")
}
"body": (
"** Unable to decrypt: The sender's device has not "
"sent us the keys for this message. **"
),
},
}
async def send_message(self, message):
@ -206,17 +214,10 @@ class PanClient(AsyncClient):
await self.queue.put(message)
def delete_fetcher_task(self, task):
self.pan_store.delete_fetcher_task(
self.server_name,
self.user_id,
task
)
self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)
async def fetcher_loop(self):
for t in self.pan_store.load_fetcher_tasks(
self.server_name,
self.user_id
):
for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
await self.history_fetch_queue.put(t)
while True:
@ -233,13 +234,13 @@ class PanClient(AsyncClient):
continue
try:
logger.debug("Fetching room history for {}".format(
room.display_name
))
logger.debug(
"Fetching room history for {}".format(room.display_name)
)
response = await self.room_messages(
fetch_task.room_id,
fetch_task.token,
limit=self.pan_conf.indexing_batch_size
limit=self.pan_conf.indexing_batch_size,
)
except ClientConnectionError:
self.history_fetch_queue.put(fetch_task)
@ -250,31 +251,31 @@ class PanClient(AsyncClient):
continue
for event in response.chunk:
if not isinstance(event, (
if not isinstance(
event,
(
RoomMessageText,
RoomMessageMedia,
RoomEncryptedMedia,
RoomTopicEvent,
RoomNameEvent
)):
RoomNameEvent,
),
):
continue
display_name = room.user_name(event.sender)
avatar_url = room.avatar_url(event.sender)
self.index.add_event(event, room.room_id, display_name,
avatar_url)
self.index.add_event(event, room.room_id, display_name, avatar_url)
last_event = response.chunk[-1]
if not self.index.event_in_store(
last_event.event_id,
room.room_id
):
if not self.index.event_in_store(last_event.event_id, room.room_id):
# There may be even more events to fetch, add a new task to
# the queue.
task = FetchTask(room.room_id, response.end)
self.pan_store.save_fetcher_task(self.server_name,
self.user_id, task)
self.pan_store.save_fetcher_task(
self.server_name, self.user_id, task
)
await self.history_fetch_queue.put(task)
await self.index.commit_events()
@ -294,11 +295,7 @@ class PanClient(AsyncClient):
self.key_request_tasks = []
await self.index.commit_events()
self.pan_store.save_token(
self.server_name,
self.user_id,
self.next_batch
)
self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)
for room_id, room_info in response.rooms.join.items():
if room_info.timeline.limited:
@ -307,13 +304,12 @@ class PanClient(AsyncClient):
if not room.encrypted and self.pan_conf.index_encrypted_only:
continue
logger.info("Room {} had a limited timeline, queueing "
"room for history fetching.".format(
room.display_name
))
logger.info(
"Room {} had a limited timeline, queueing "
"room for history fetching.".format(room.display_name)
)
task = FetchTask(room_id, room_info.timeline.prev_batch)
self.pan_store.save_fetcher_task(self.server_name,
self.user_id, task)
self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)
await self.history_fetch_queue.put(task)
@ -323,10 +319,11 @@ class PanClient(AsyncClient):
def undecrypted_event_cb(self, room, event):
loop = asyncio.get_event_loop()
logger.info("Unable to decrypt event from {} via {}.".format(
event.sender,
event.device_id
))
logger.info(
"Unable to decrypt event from {} via {}.".format(
event.sender, event.device_id
)
)
if event.session_id not in self.outgoing_key_requests:
logger.info("Requesting room key for undecrypted event.")
@ -338,19 +335,16 @@ class PanClient(AsyncClient):
loop = asyncio.get_event_loop()
if isinstance(event, KeyVerificationStart):
logger.info(f"{event.sender} via {event.from_device} has started "
f"a key verification process.")
logger.info(
f"{event.sender} via {event.from_device} has started "
f"a key verification process."
)
message = InviteSasSignal(
self.user_id,
event.sender,
event.from_device,
event.transaction_id
self.user_id, event.sender, event.from_device, event.transaction_id
)
task = loop.create_task(
self.queue.put(message)
)
task = loop.create_task(self.queue.put(message))
self.key_verificatins_tasks.append(task)
elif isinstance(event, KeyVerificationKey):
@ -362,16 +356,10 @@ class PanClient(AsyncClient):
emoji = sas.get_emoji()
message = ShowSasSignal(
self.user_id,
device.user_id,
device.id,
sas.transaction_id,
emoji
self.user_id, device.user_id, device.id, sas.transaction_id, emoji
)
task = loop.create_task(
self.queue.put(message)
)
task = loop.create_task(self.queue.put(message))
self.key_verificatins_tasks.append(task)
elif isinstance(event, KeyVerificationMac):
@ -381,14 +369,13 @@ class PanClient(AsyncClient):
device = sas.other_olm_device
if sas.verified:
task = loop.create_task(self.send_message(
SasDoneSignal(
self.user_id,
device.user_id,
device.id,
sas.transaction_id
task = loop.create_task(
self.send_message(
SasDoneSignal(
self.user_id, device.user_id, device.id, sas.transaction_id
)
)
))
)
self.key_verificatins_tasks.append(task)
task = loop.create_task(self.send_update_devcies())
self.key_verificatins_tasks.append(task)
@ -407,22 +394,13 @@ class PanClient(AsyncClient):
self.history_fetcher_task = loop.create_task(self.fetcher_loop())
timeout = 30000
sync_filter = {
"room": {
"state": {"lazy_load_members": True}
}
}
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
next_batch = self.pan_store.load_token(self.server_name, self.user_id)
# We don't store any room state so initial sync needs to be with the
# full_state parameter. Subsequent ones are normal.
task = loop.create_task(
self.sync_forever(
timeout,
sync_filter,
full_state=True,
since=next_batch
)
self.sync_forever(timeout, sync_filter, full_state=True, since=next_batch)
)
self.task = task
@ -436,16 +414,15 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
"m.ok",
"Successfully started the key verification request"
))
"Successfully started the key verification request",
)
)
except ClientConnectionError as e:
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
"m.connection_error",
str(e)
))
message.message_id, self.user_id, "m.connection_error", str(e)
)
)
async def accept_sas(self, message):
user_id = message.user_id
@ -459,9 +436,8 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
Sas._txid_error[0],
Sas._txid_error[1]
Sas._txid_error[1],
)
)
return
@ -472,24 +448,24 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
"m.ok",
"Successfully accepted the key verification request"
))
"Successfully accepted the key verification request",
)
)
except LocalProtocolError as e:
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
Sas._unexpected_message_error[0],
str(e)
))
str(e),
)
)
except ClientConnectionError as e:
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
"m.connection_error",
str(e)
))
message.message_id, self.user_id, "m.connection_error", str(e)
)
)
async def cancel_sas(self, message):
user_id = message.user_id
@ -503,9 +479,8 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
Sas._txid_error[0],
Sas._txid_error[1]
Sas._txid_error[1],
)
)
return
@ -516,16 +491,15 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
"m.ok",
"Successfully canceled the key verification request"
))
"Successfully canceled the key verification request",
)
)
except ClientConnectionError as e:
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
"m.connection_error",
str(e)
))
message.message_id, self.user_id, "m.connection_error", str(e)
)
)
async def confirm_sas(self, message):
user_id = message.user_id
@ -539,9 +513,8 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
Sas._txid_error[0],
Sas._txid_error[1]
Sas._txid_error[1],
)
)
return
@ -550,11 +523,9 @@ class PanClient(AsyncClient):
except ClientConnectionError as e:
await self.send_message(
DaemonResponse(
message.message_id,
self.user_id,
"m.connection_error",
str(e)
))
message.message_id, self.user_id, "m.connection_error", str(e)
)
)
return
@ -564,10 +535,7 @@ class PanClient(AsyncClient):
await self.send_update_devcies()
await self.send_message(
SasDoneSignal(
self.user_id,
device.user_id,
device.id,
sas.transaction_id
self.user_id, device.user_id, device.id, sas.transaction_id
)
)
else:
@ -576,8 +544,9 @@ class PanClient(AsyncClient):
message.message_id,
self.user_id,
"m.ok",
f"Waiting for {device.user_id} to confirm."
))
f"Waiting for {device.user_id} to confirm.",
)
)
async def loop_stop(self):
"""Stop the client loop."""
@ -605,18 +574,15 @@ class PanClient(AsyncClient):
self.history_fetch_queue = asyncio.Queue()
def pan_decrypt_event(
self,
event_dict,
room_id=None,
ignore_failures=True
):
def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
# type: (Dict[Any, Any], Optional[str], bool) -> (bool)
event = RoomEncryptedEvent.parse_event(event_dict)
if not isinstance(event, MegolmEvent):
logger.warn("Encrypted event is not a megolm event:"
"\n{}".format(pformat(event_dict)))
logger.warn(
"Encrypted event is not a megolm event:"
"\n{}".format(pformat(event_dict))
)
return False
if not event.room_id:
@ -661,8 +627,7 @@ class PanClient(AsyncClient):
continue
if event["type"] != "m.room.encrypted":
logger.debug("Event is not encrypted: "
"\n{}".format(pformat(event)))
logger.debug("Event is not encrypted: " "\n{}".format(pformat(event)))
continue
self.pan_decrypt_event(event, ignore_failures=ignore_failures)
@ -682,9 +647,11 @@ class PanClient(AsyncClient):
for room_id, room_dict in body["rooms"]["join"].items():
try:
if not self.rooms[room_id].encrypted:
logger.info("Room {} is not encrypted skipping...".format(
self.rooms[room_id].display_name
))
logger.info(
"Room {} is not encrypted skipping...".format(
self.rooms[room_id].display_name
)
)
continue
except KeyError:
logger.info("Unknown room {} skipping...".format(room_id))
@ -752,8 +719,9 @@ class PanClient(AsyncClient):
after_limit = event_context.get("before_limit", 5)
if before_limit < 0 or after_limit < 0:
raise InvalidLimit("Invalid context limit, the limit must be a "
"positive number")
raise InvalidLimit(
"Invalid context limit, the limit must be a " "positive number"
)
response_dict = await self.index.search(
term,
@ -762,7 +730,7 @@ class PanClient(AsyncClient):
order_by_recent=order_by_recent,
include_profile=include_profile,
before_limit=before_limit,
after_limit=after_limit
after_limit=after_limit,
)
if (event_context or include_state) and self.pan_conf.search_requests:
@ -771,14 +739,10 @@ class PanClient(AsyncClient):
event_dict,
event_dict["result"]["room_id"],
event_dict["result"]["event_id"],
include_state
include_state,
)
if include_state:
response_dict["state"] = state_cache
return {
"search_categories": {
"room_events": response_dict
}
}
return {"search_categories": {"room_events": response_dict}}

View File

@ -43,7 +43,7 @@ class PanConfigParser(configparser.ConfigParser):
"address": parse_address,
"url": parse_url,
"loglevel": parse_log_level,
}
},
)
@ -59,9 +59,10 @@ def parse_url(v):
# type: (str) -> ParseResult
value = urlparse(v)
if value.scheme not in ('http', 'https'):
raise ValueError(f"Invalid URL scheme {value.scheme}. "
f"Only HTTP(s) URLs are allowed")
if value.scheme not in ("http", "https"):
raise ValueError(
f"Invalid URL scheme {value.scheme}. " f"Only HTTP(s) URLs are allowed"
)
value.port
return value
@ -124,8 +125,7 @@ class ServerConfig:
name = attr.ib(type=str)
homeserver = attr.ib(type=ParseResult)
listen_address = attr.ib(
type=Union[IPv4Address, IPv6Address],
default=ip_address("127.0.0.1")
type=Union[IPv4Address, IPv6Address], default=ip_address("127.0.0.1")
)
listen_port = attr.ib(type=int, default=8009)
proxy = attr.ib(type=str, default="")
@ -183,8 +183,9 @@ class PanConfig:
homeserver = section.geturl("Homeserver")
if not homeserver:
raise PanConfigError(f"Homserver is not set for "
f"section {section_name}")
raise PanConfigError(
f"Homserver is not set for " f"section {section_name}"
)
listen_address = section.getaddress("ListenAddress")
listen_port = section.getint("ListenPort")
@ -198,23 +199,29 @@ class PanConfig:
indexing_batch_size = section.getint("IndexingBatchSize")
if not 1 < indexing_batch_size <= 1000:
raise PanConfigError("The indexing batch size needs to be "
"a positive integer between 1 and "
"1000")
raise PanConfigError(
"The indexing batch size needs to be "
"a positive integer between 1 and "
"1000"
)
history_fetch_delay = section.getint("HistoryFetchDelay")
if not 100 < history_fetch_delay <= 10000:
raise PanConfigError("The history fetch delay needs to be "
"a positive integer between 100 and "
"10000")
raise PanConfigError(
"The history fetch delay needs to be "
"a positive integer between 100 and "
"10000"
)
listen_tuple = (listen_address, listen_port)
if listen_tuple in listen_set:
raise PanConfigError(f"The listen address/port combination"
f" for section {section_name} was "
f"already defined before.")
raise PanConfigError(
f"The listen address/port combination"
f" for section {section_name} was "
f"already defined before."
)
listen_set.add(listen_tuple)
server_conf = ServerConfig(
@ -229,7 +236,7 @@ class PanConfig:
search_requests,
index_encrypted_only,
indexing_batch_size,
history_fetch_delay / 1000
history_fetch_delay / 1000,
)
self.servers[section_name] = server_conf

View File

@ -25,36 +25,46 @@ from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from jsonschema import ValidationError
from multidict import CIMultiDict
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
SendRetryError)
from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError
from pantalaimon.client import (SEARCH_TERMS_SCHEMA, InvalidLimit,
InvalidOrderByError, PanClient,
UnknownRoomError, validate_json)
from pantalaimon.client import (
SEARCH_TERMS_SCHEMA,
InvalidLimit,
InvalidOrderByError,
PanClient,
UnknownRoomError,
validate_json,
)
from pantalaimon.index import InvalidQueryError
from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
CancelSendingMessage,
ConfirmSasMessage, DaemonResponse,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
DeviceUnverifyMessage,
DeviceVerifyMessage,
ExportKeysMessage, ImportKeysMessage,
SasMessage, SendAnywaysMessage,
StartSasMessage,
UnverifiedDevicesSignal,
UnverifiedResponse,
UpdateDevicesMessage,
UpdateUsersMessage)
from pantalaimon.thread_messages import (
AcceptSasMessage,
CancelSasMessage,
CancelSendingMessage,
ConfirmSasMessage,
DaemonResponse,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
DeviceUnverifyMessage,
DeviceVerifyMessage,
ExportKeysMessage,
ImportKeysMessage,
SasMessage,
SendAnywaysMessage,
StartSasMessage,
UnverifiedDevicesSignal,
UnverifiedResponse,
UpdateDevicesMessage,
UpdateUsersMessage,
)
CORS_HEADERS = {
"Access-Control-Allow-Headers": (
"Origin, X-Requested-With, Content-Type, Accept, Authorization"
),
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Headers": (
"Origin, X-Requested-With, Content-Type, Accept, Authorization"
),
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Origin": "*",
}
@ -76,11 +86,7 @@ class ProxyDaemon:
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
hostname = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
client_info = attr.ib(
init=False,
default=attr.Factory(dict),
type=dict
)
client_info = attr.ib(init=False, default=attr.Factory(dict), type=dict)
default_session = attr.ib(init=False, default=None)
database_name = "pan.db"
@ -93,15 +99,16 @@ class ProxyDaemon:
for user_id, device_id in accounts:
if self.conf.keyring:
token = keyring.get_password(
"pantalaimon",
f"{user_id}-{device_id}-token"
"pantalaimon", f"{user_id}-{device_id}-token"
)
else:
token = self.store.load_access_token(user_id, device_id)
if not token:
logger.warn(f"Not restoring client for {user_id} {device_id}, "
f"missing access token.")
logger.warn(
f"Not restoring client for {user_id} {device_id}, "
f"missing access token."
)
continue
logger.info(f"Restoring client for {user_id} {device_id}")
@ -116,7 +123,7 @@ class ProxyDaemon:
device_id,
store_path=self.data_dir,
ssl=self.ssl,
proxy=self.proxy
proxy=self.proxy,
)
pan_client.user_id = user_id
pan_client.access_token = token
@ -136,7 +143,7 @@ class ProxyDaemon:
method,
self.homeserver_url + path,
proxy=self.proxy,
ssl=self.ssl
ssl=self.ssl,
)
except ClientConnectionError:
return None
@ -155,12 +162,15 @@ class ProxyDaemon:
return None
if user_id not in self.pan_clients:
logger.warn(f"User {user_id} doesn't have a matching pan "
f"client.")
logger.warn(
f"User {user_id} doesn't have a matching pan " f"client."
)
return None
logger.info(f"Homeserver confirmed valid access token "
f"for user {user_id}, caching info.")
logger.info(
f"Homeserver confirmed valid access token "
f"for user {user_id}, caching info."
)
client_info = ClientInfo(user_id, access_token)
self.client_info[access_token] = client_info
@ -173,12 +183,12 @@ class ProxyDaemon:
ret = client.verify_device(device)
if ret:
msg = (f"Device {device.id} of user "
f"{device.user_id} succesfully verified.")
msg = (
f"Device {device.id} of user " f"{device.user_id} succesfully verified."
)
await self.send_update_devcies()
else:
msg = (f"Device {device.id} of user "
f"{device.user_id} already verified.")
msg = f"Device {device.id} of user " f"{device.user_id} already verified."
logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -187,12 +197,13 @@ class ProxyDaemon:
ret = client.unverify_device(device)
if ret:
msg = (f"Device {device.id} of user "
f"{device.user_id} succesfully unverified.")
msg = (
f"Device {device.id} of user "
f"{device.user_id} succesfully unverified."
)
await self.send_update_devcies()
else:
msg = (f"Device {device.id} of user "
f"{device.user_id} already unverified.")
msg = f"Device {device.id} of user " f"{device.user_id} already unverified."
logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -201,12 +212,15 @@ class ProxyDaemon:
ret = client.blacklist_device(device)
if ret:
msg = (f"Device {device.id} of user "
f"{device.user_id} succesfully blacklisted.")
msg = (
f"Device {device.id} of user "
f"{device.user_id} succesfully blacklisted."
)
await self.send_update_devcies()
else:
msg = (f"Device {device.id} of user "
f"{device.user_id} already blacklisted.")
msg = (
f"Device {device.id} of user " f"{device.user_id} already blacklisted."
)
logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -215,12 +229,16 @@ class ProxyDaemon:
ret = client.unblacklist_device(device)
if ret:
msg = (f"Device {device.id} of user "
f"{device.user_id} succesfully unblacklisted.")
msg = (
f"Device {device.id} of user "
f"{device.user_id} succesfully unblacklisted."
)
await self.send_update_devcies()
else:
msg = (f"Device {device.id} of user "
f"{device.user_id} already unblacklisted.")
msg = (
f"Device {device.id} of user "
f"{device.user_id} already unblacklisted."
)
logger.info(msg)
await self.send_response(message_id, client.user_id, "m.ok", msg)
@ -239,23 +257,23 @@ class ProxyDaemon:
if isinstance(
message,
(DeviceVerifyMessage, DeviceUnverifyMessage, StartSasMessage,
DeviceBlacklistMessage, DeviceUnblacklistMessage)
(
DeviceVerifyMessage,
DeviceUnverifyMessage,
StartSasMessage,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
),
):
device = client.device_store[message.user_id].get(
message.device_id,
None
)
device = client.device_store[message.user_id].get(message.device_id, None)
if not device:
msg = (f"No device found for {message.user_id} and "
f"{message.device_id}")
msg = (
f"No device found for {message.user_id} and " f"{message.device_id}"
)
await self.send_response(
message.message_id,
message.pan_user,
"m.unknown_device",
msg
message.message_id, message.pan_user, "m.unknown_device", msg
)
logger.info(msg)
return
@ -265,11 +283,9 @@ class ProxyDaemon:
elif isinstance(message, DeviceUnverifyMessage):
await self._unverify_device(message.message_id, client, device)
elif isinstance(message, DeviceBlacklistMessage):
await self._blacklist_device(message.message_id, client,
device)
await self._blacklist_device(message.message_id, client, device)
elif isinstance(message, DeviceUnblacklistMessage):
await self._unblacklist_device(message.message_id, client,
device)
await self._unblacklist_device(message.message_id, client, device)
elif isinstance(message, StartSasMessage):
await client.start_sas(message, device)
@ -288,25 +304,21 @@ class ProxyDaemon:
try:
await client.export_keys(path, message.passphrase)
except OSError as e:
info_msg = (f"Error exporting keys for {client.user_id} to"
f" {path} {e}")
info_msg = (
f"Error exporting keys for {client.user_id} to" f" {path} {e}"
)
logger.info(info_msg)
await self.send_response(
message.message_id,
client.user_id,
"m.os_error",
str(e)
message.message_id, client.user_id, "m.os_error", str(e)
)
else:
info_msg = (f"Succesfully exported keys for {client.user_id} "
f"to {path}")
info_msg = (
f"Succesfully exported keys for {client.user_id} " f"to {path}"
)
logger.info(info_msg)
await self.send_response(
message.message_id,
client.user_id,
"m.ok",
info_msg
message.message_id, client.user_id, "m.ok", info_msg
)
elif isinstance(message, ImportKeysMessage):
@ -316,37 +328,32 @@ class ProxyDaemon:
try:
await client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e:
info_msg = (f"Error importing keys for {client.user_id} "
f"from {path} {e}")
info_msg = (
f"Error importing keys for {client.user_id} " f"from {path} {e}"
)
logger.info(info_msg)
await self.send_response(
message.message_id,
client.user_id,
"m.os_error",
str(e)
message.message_id, client.user_id, "m.os_error", str(e)
)
else:
info_msg = (f"Succesfully imported keys for {client.user_id} "
f"from {path}")
info_msg = (
f"Succesfully imported keys for {client.user_id} " f"from {path}"
)
logger.info(info_msg)
await self.send_response(
message.message_id,
client.user_id,
"m.ok",
info_msg
message.message_id, client.user_id, "m.ok", info_msg
)
elif isinstance(message, UnverifiedResponse):
client = self.pan_clients[message.pan_user]
if message.room_id not in client.send_decision_queues:
msg = (f"No send request found for user {message.pan_user} "
f"and room {message.room_id}.")
msg = (
f"No send request found for user {message.pan_user} "
f"and room {message.room_id}."
)
await self.send_response(
message.message_id,
message.pan_user,
"m.unknown_request",
msg
message.message_id, message.pan_user, "m.unknown_request", msg
)
return
@ -365,10 +372,7 @@ class ProxyDaemon:
access_token = request.query.get("access_token", "")
if not access_token:
access_token = request.headers.get(
"Authorization",
""
).strip("Bearer ")
access_token = request.headers.get("Authorization", "").strip("Bearer ")
return access_token
@ -400,11 +404,11 @@ class ProxyDaemon:
async def forward_request(
self,
request, # type: aiohttp.web.BaseRequest
params=None, # type: CIMultiDict
data=None, # type: bytes
request, # type: aiohttp.web.BaseRequest
params=None, # type: CIMultiDict
data=None, # type: bytes
session=None, # type: aiohttp.ClientSession
token=None # type: str
token=None, # type: str
):
# type: (...) -> aiohttp.ClientResponse
"""Forward the given request to our configured homeserver.
@ -454,16 +458,11 @@ class ProxyDaemon:
params=params,
headers=headers,
proxy=self.proxy,
ssl=self.ssl
ssl=self.ssl,
)
async def forward_to_web(
self,
request,
params=None,
data=None,
session=None,
token=None
self, request, params=None, data=None, session=None, token=None
):
"""Forward the given request and convert the response to a Response.
@ -484,17 +483,13 @@ class ProxyDaemon:
"""
try:
response = await self.forward_request(
request,
params=params,
data=data,
session=session,
token=token
request, params=params, data=data, session=session, token=token
)
return web.Response(
status=response.status,
content_type=response.content_type,
headers=CORS_HEADERS,
body=await response.read()
body=await response.read(),
)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
@ -522,8 +517,10 @@ class ProxyDaemon:
self.store.save_server_user(self.name, user_id)
if user_id in self.pan_clients:
logger.info(f"Background sync client already exists for {user_id},"
f" not starting new one")
logger.info(
f"Background sync client already exists for {user_id},"
f" not starting new one"
)
return
pan_client = PanClient(
@ -535,7 +532,7 @@ class ProxyDaemon:
user_id,
store_path=self.data_dir,
ssl=self.ssl,
proxy=self.proxy
proxy=self.proxy,
)
response = await pan_client.login(password, "pantalaimon")
@ -543,8 +540,7 @@ class ProxyDaemon:
await pan_client.close()
return
logger.info(f"Succesfully started new background sync client for "
f"{user_id}")
logger.info(f"Succesfully started new background sync client for " f"{user_id}")
await self.send_queue.put(UpdateUsersMessage())
@ -554,13 +550,11 @@ class ProxyDaemon:
keyring.set_password(
"pantalaimon",
f"{user_id}-{pan_client.device_id}-token",
pan_client.access_token
pan_client.access_token,
)
else:
self.store.save_access_token(
user_id,
pan_client.device_id,
pan_client.access_token
user_id, pan_client.device_id, pan_client.access_token
)
pan_client.start_loop()
@ -580,7 +574,7 @@ class ProxyDaemon:
return web.json_response(
{
"errcode": "M_NOT_JSON",
"error": "Request did not contain valid JSON."
"error": "Request did not contain valid JSON.",
},
status=500,
)
@ -605,25 +599,23 @@ class ProxyDaemon:
access_token = json_response.get("access_token", None)
if user_id and access_token:
logger.info(f"User: {user} succesfully logged in, starting "
f"a background sync client.")
await self.start_pan_client(access_token, user, user_id,
password)
logger.info(
f"User: {user} succesfully logged in, starting "
f"a background sync client."
)
await self.start_pan_client(access_token, user, user_id, password)
return web.Response(
status=response.status,
content_type=response.content_type,
headers=CORS_HEADERS,
body=await response.read()
body=await response.read(),
)
@property
def _missing_token(self):
return web.json_response(
{
"errcode": "M_MISSING_TOKEN",
"error": "Missing access token."
},
{"errcode": "M_MISSING_TOKEN", "error": "Missing access token."},
headers=CORS_HEADERS,
status=401,
)
@ -631,10 +623,7 @@ class ProxyDaemon:
@property
def _unknown_token(self):
return web.json_response(
{
"errcode": "M_UNKNOWN_TOKEN",
"error": "Unrecognised access token."
},
{"errcode": "M_UNKNOWN_TOKEN", "error": "Unrecognised access token."},
headers=CORS_HEADERS,
status=401,
)
@ -642,10 +631,7 @@ class ProxyDaemon:
@property
def _not_json(self):
return web.json_response(
{
"errcode": "M_NOT_JSON",
"error": "Request did not contain valid JSON."
},
{"errcode": "M_NOT_JSON", "error": "Request did not contain valid JSON."},
headers=CORS_HEADERS,
status=400,
)
@ -660,23 +646,18 @@ class ProxyDaemon:
while True:
try:
logger.info("Trying to decrypt sync")
return decryption_method(
body,
ignore_failures=False
)
return decryption_method(body, ignore_failures=False)
except EncryptionError:
logger.info("Error decrypting sync, waiting for next pan "
"sync")
logger.info("Error decrypting sync, waiting for next pan " "sync")
await client.synced.wait(),
logger.info("Pan synced, retrying decryption.")
try:
return await asyncio.wait_for(
decrypt_loop(client, body),
timeout=self.decryption_timeout)
decrypt_loop(client, body), timeout=self.decryption_timeout
)
except asyncio.TimeoutError:
logger.info("Decryption attempt timed out, decrypting with "
"failures")
logger.info("Decryption attempt timed out, decrypting with " "failures")
return decryption_method(body, ignore_failures=True)
async def sync(self, request):
@ -705,9 +686,7 @@ class ProxyDaemon:
try:
response = await self.forward_request(
request,
params=query,
token=client.access_token
request, params=query, token=client.access_token
)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
@ -718,9 +697,7 @@ class ProxyDaemon:
json_response = await self.decrypt_body(client, json_response)
return web.json_response(
json_response,
headers=CORS_HEADERS,
status=response.status,
json_response, headers=CORS_HEADERS, status=response.status
)
except (JSONDecodeError, ContentTypeError):
pass
@ -729,7 +706,7 @@ class ProxyDaemon:
status=response.status,
content_type=response.content_type,
headers=CORS_HEADERS,
body=await response.read()
body=await response.read(),
)
async def messages(self, request):
@ -751,15 +728,11 @@ class ProxyDaemon:
try:
json_response = await response.json()
json_response = await self.decrypt_body(
client,
json_response,
sync=False
client, json_response, sync=False
)
return web.json_response(
json_response,
headers=CORS_HEADERS,
status=response.status,
json_response, headers=CORS_HEADERS, status=response.status
)
except (JSONDecodeError, ContentTypeError):
pass
@ -768,7 +741,7 @@ class ProxyDaemon:
status=response.status,
content_type=response.content_type,
headers=CORS_HEADERS,
body=await response.read()
body=await response.read(),
)
async def send_message(self, request):
@ -788,17 +761,11 @@ class ProxyDaemon:
room = client.rooms[room_id]
encrypt = room.encrypted
except KeyError:
return await self.forward_to_web(
request,
token=client.access_token
)
return await self.forward_to_web(request, token=client.access_token)
# The room isn't encrypted just forward the message.
if not encrypt:
return await self.forward_to_web(
request,
token=client.access_token
)
return await self.forward_to_web(request, token=client.access_token)
msgtype = request.match_info["event_type"]
txnid = request.match_info["txnid"]
@ -810,14 +777,15 @@ class ProxyDaemon:
async def _send(ignore_unverified=False):
try:
response = await client.room_send(room_id, msgtype, content,
txnid, ignore_unverified)
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read()
body=await response.transport_response.read(),
)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
@ -849,45 +817,40 @@ class ProxyDaemon:
client.send_decision_queues[room_id] = queue
message = UnverifiedDevicesSignal(
client.user_id,
room_id,
room.display_name
client.user_id, room_id, room.display_name
)
await self.send_queue.put(message)
try:
response = await asyncio.wait_for(
queue.get(),
self.unverified_send_timeout
queue.get(), self.unverified_send_timeout
)
if isinstance(response, CancelSendingMessage):
# The send was canceled notify the client that sent the
# request about this.
info_msg = (f"Canceled message sending for room "
f"{room.display_name} ({room_id}).")
info_msg = (
f"Canceled message sending for room "
f"{room.display_name} ({room_id})."
)
logger.info(info_msg)
await self.send_response(
response.message_id,
client.user_id,
"m.ok",
info_msg
response.message_id, client.user_id, "m.ok", info_msg
)
return web.Response(status=503, text=str(e))
elif isinstance(response, SendAnywaysMessage):
# We are sending and ignoring devices along the way.
info_msg = (f"Ignoring unverified devices and sending "
f"message to room "
f"{room.display_name} ({room_id}).")
info_msg = (
f"Ignoring unverified devices and sending "
f"message to room "
f"{room.display_name} ({room_id})."
)
logger.info(info_msg)
await self.send_response(
response.message_id,
client.user_id,
"m.ok",
info_msg
response.message_id, client.user_id, "m.ok", info_msg
)
ret = await _send(True)
@ -900,10 +863,12 @@ class ProxyDaemon:
return web.Response(
status=503,
text=(f"Room contains unverified devices and no "
f"action was taken for "
f"{self.unverified_send_timeout} seconds, "
f"request timed out")
text=(
f"Room contains unverified devices and no "
f"action was taken for "
f"{self.unverified_send_timeout} seconds, "
f"request timed out"
),
)
finally:
@ -922,10 +887,7 @@ class ProxyDaemon:
sanitized_content = self.sanitize_filter(content)
return await self.forward_to_web(
request,
data=json.dumps(sanitized_content)
)
return await self.forward_to_web(request, data=json.dumps(sanitized_content))
async def search_opts(self, request):
return web.json_response({}, headers=CORS_HEADERS)
@ -950,10 +912,7 @@ class ProxyDaemon:
validate_json(content, SEARCH_TERMS_SCHEMA)
except ValidationError:
return web.json_response(
{
"errcode": "M_BAD_JSON",
"error": "Invalid search query"
},
{"errcode": "M_BAD_JSON", "error": "Invalid search query"},
headers=CORS_HEADERS,
status=400,
)
@ -982,21 +941,14 @@ class ProxyDaemon:
result = await client.search(content)
except (InvalidOrderByError, InvalidLimit, InvalidQueryError) as e:
return web.json_response(
{
"errcode": "M_INVALID_PARAM",
"error": str(e)
},
{"errcode": "M_INVALID_PARAM", "error": str(e)},
headers=CORS_HEADERS,
status=400,
)
except UnknownRoomError:
return await self.forward_to_web(request)
return web.json_response(
result,
headers=CORS_HEADERS,
status=200
)
return web.json_response(result, headers=CORS_HEADERS, status=200)
async def shutdown(self, _):
"""Shut the daemon down closing all the client sessions it has.

View File

@ -21,10 +21,14 @@ from typing import Any, Dict, List, Optional, Tuple
import attr
import tantivy
from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText,
RoomNameEvent, RoomTopicEvent)
from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase,
TextField)
from nio import (
RoomEncryptedMedia,
RoomMessageMedia,
RoomMessageText,
RoomNameEvent,
RoomTopicEvent,
)
from peewee import SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, TextField
from pantalaimon.store import use_database
@ -65,22 +69,15 @@ class Event(Model):
source = DictField()
profile = ForeignKeyField(
model=Profile,
column_name="profile_id",
)
profile = ForeignKeyField(model=Profile, column_name="profile_id")
class Meta:
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
class UserMessages(Model):
user = ForeignKeyField(
model=StoreUser,
column_name="user_id")
event = ForeignKeyField(
model=Event,
column_name="event_id")
user = ForeignKeyField(model=StoreUser, column_name="user_id")
event = ForeignKeyField(model=Event, column_name="event_id")
@attr.s
@ -91,17 +88,11 @@ class MessageStore:
database = attr.ib(type=SqliteDatabase, init=False)
database_path = attr.ib(type=str, init=False)
models = [
StoreUser,
Event,
Profile,
UserMessages
]
models = [StoreUser, Event, Profile, UserMessages]
def __attrs_post_init__(self):
self.database_path = os.path.join(
os.path.abspath(self.store_path),
self.database_name
os.path.abspath(self.store_path), self.database_name
)
self.database = self._create_database()
@ -112,21 +103,22 @@ class MessageStore:
def _create_database(self):
return SqliteDatabase(
self.database_path,
pragmas={
"foreign_keys": 1,
"secure_delete": 1,
}
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
)
@use_database
def event_in_store(self, event_id, room_id):
user, _ = StoreUser.get_or_create(user_id=self.user)
query = Event.select().join(UserMessages).where(
(Event.room_id == room_id) &
(Event.event_id == event_id) &
(UserMessages.user == user)
).execute()
query = (
Event.select()
.join(UserMessages)
.where(
(Event.room_id == room_id)
& (Event.event_id == event_id)
& (UserMessages.user == user)
)
.execute()
)
for _ in query:
return True
@ -137,32 +129,29 @@ class MessageStore:
user, _ = StoreUser.get_or_create(user_id=self.user)
profile_id, _ = Profile.get_or_create(
user_id=event.sender,
display_name=display_name,
avatar_url=avatar_url
user_id=event.sender, display_name=display_name, avatar_url=avatar_url
)
event_source = event.source
event_source["room_id"] = room_id
event_id = Event.insert(
event_id=event.event_id,
sender=event.sender,
date=datetime.datetime.fromtimestamp(
event.server_timestamp / 1000
),
room_id=room_id,
source=event_source,
profile=profile_id
).on_conflict_ignore().execute()
event_id = (
Event.insert(
event_id=event.event_id,
sender=event.sender,
date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
room_id=room_id,
source=event_source,
profile=profile_id,
)
.on_conflict_ignore()
.execute()
)
if event_id <= 0:
return None
_, created = UserMessages.get_or_create(
user=user,
event=event_id,
)
_, created = UserMessages.get_or_create(user=user, event=event_id)
if created:
return event_id
@ -173,24 +162,36 @@ class MessageStore:
context = {}
if before > 0:
query = Event.select().join(UserMessages).where(
(Event.date <= event.date) &
(Event.room_id == event.room_id) &
(Event.id != event.id) &
(UserMessages.user == user)
).order_by(Event.date.desc()).limit(before)
query = (
Event.select()
.join(UserMessages)
.where(
(Event.date <= event.date)
& (Event.room_id == event.room_id)
& (Event.id != event.id)
& (UserMessages.user == user)
)
.order_by(Event.date.desc())
.limit(before)
)
context["events_before"] = [e.source for e in query]
else:
context["events_before"] = []
if after > 0:
query = Event.select().join(UserMessages).where(
(Event.date >= event.date) &
(Event.room_id == event.room_id) &
(Event.id != event.id) &
(UserMessages.user == user)
).order_by(Event.date).limit(after)
query = (
Event.select()
.join(UserMessages)
.where(
(Event.date >= event.date)
& (Event.room_id == event.room_id)
& (Event.id != event.id)
& (UserMessages.user == user)
)
.order_by(Event.date)
.limit(after)
)
context["events_after"] = [e.source for e in query]
else:
@ -200,12 +201,12 @@ class MessageStore:
@use_database
def load_events(
self,
search_result, # type: List[Tuple[int, int]]
include_profile=False, # type: bool
order_by_recent=False, # type: bool
before=0, # type: int
after=0 # type: int
self,
search_result, # type: List[Tuple[int, int]]
include_profile=False, # type: bool
order_by_recent=False, # type: bool
before=0, # type: int
after=0, # type: int
):
# type: (...) -> Dict[Any, Any]
user, _ = StoreUser.get_or_create(user_id=self.user)
@ -213,13 +214,13 @@ class MessageStore:
search_dict = {r[1]: r[0] for r in search_result}
columns = list(search_dict.keys())
result_dict = {
"results": []
}
result_dict = {"results": []}
query = UserMessages.select().where(
(UserMessages.user_id == user) & (UserMessages.event.in_(columns))
).execute()
query = (
UserMessages.select()
.where((UserMessages.user_id == user) & (UserMessages.event.in_(columns)))
.execute()
)
for message in query:
@ -228,7 +229,7 @@ class MessageStore:
event_dict = {
"rank": 1 if order_by_recent else search_dict[event.id],
"result": event.source,
"context": {}
"context": {},
}
if include_profile:
@ -256,8 +257,17 @@ def sanitize_room_id(room_id):
class Searcher:
def __init__(self, index, body_field, name_field, topic_field,
column_field, room_field, timestamp_field, searcher):
def __init__(
self,
index,
body_field,
name_field,
topic_field,
column_field,
room_field,
timestamp_field,
searcher,
):
self._index = index
self._searcher = searcher
@ -268,8 +278,7 @@ class Searcher:
self.room_field = room_field
self.timestamp_field = timestamp_field
def search(self, search_term, room=None, max_results=10,
order_by_recent=False):
def search(self, search_term, room=None, max_results=10, order_by_recent=False):
# type (str, str, int, bool) -> List[int, int]
"""Search for events in the index.
@ -277,21 +286,13 @@ class Searcher:
"""
queryparser = tantivy.QueryParser.for_index(
self._index,
[
self.body_field,
self.name_field,
self.topic_field,
self.room_field
]
[self.body_field, self.name_field, self.topic_field, self.room_field],
)
# This currently supports only a single room since the query parser
# doesn't seem to work with multiple room fields here.
if room:
query_string = "{} AND room:{}".format(
search_term,
sanitize_room_id(room)
)
query_string = "{} AND room:{}".format(search_term, sanitize_room_id(room))
else:
query_string = search_term
@ -301,8 +302,9 @@ class Searcher:
raise InvalidQueryError(f"Invalid search term: {search_term}")
if order_by_recent:
collector = tantivy.TopDocs(max_results,
order_by_field=self.timestamp_field)
collector = tantivy.TopDocs(
max_results, order_by_field=self.timestamp_field
)
else:
collector = tantivy.TopDocs(max_results)
@ -329,16 +331,11 @@ class Index:
self.timestamp_field = schema_builder.add_unsigned_field(
"server_timestamp", fast="single"
)
self.date_field = schema_builder.add_date_field(
"message_date"
)
self.date_field = schema_builder.add_date_field("message_date")
self.room_field = schema_builder.add_facet_field("room")
self.column_field = schema_builder.add_unsigned_field(
"database_column",
indexed=True,
stored=True,
fast="single"
"database_column", indexed=True, stored=True, fast="single"
)
schema = schema_builder.build()
@ -359,7 +356,7 @@ class Index:
doc.add_facet(self.room_field, room_facet)
doc.add_date(
self.date_field,
datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
)
doc.add_unsigned(self.timestamp_field, event.server_timestamp)
@ -389,7 +386,7 @@ class Index:
self.column_field,
self.room_field,
self.timestamp_field,
self.reader.searcher()
self.reader.searcher(),
)
@ -430,10 +427,7 @@ class IndexStore:
with store.database.bind_ctx(store.models):
with store.database.atomic():
for item in event_queue:
column_id = store.save_event(
item.event,
item.room_id,
)
column_id = store.save_event(item.event, item.room_id)
if column_id:
index.add_event(column_id, item.event, item.room_id)
@ -451,10 +445,7 @@ class IndexStore:
async with self.write_lock:
write_func = partial(
IndexStore.write_events,
self.store,
self.index,
event_queue
IndexStore.write_events, self.store, self.index, event_queue
)
await loop.run_in_executor(None, write_func)
@ -462,14 +453,14 @@ class IndexStore:
return self.store.event_in_store(event_id, room_id)
async def search(
self,
search_term, # type: str
room=None, # type: Optional[str]
max_results=10, # type: int
order_by_recent=False, # type: bool
include_profile=False, # type: bool
before_limit=0, # type: int
after_limit=0 # type: int
self,
search_term, # type: str
room=None, # type: Optional[str]
max_results=10, # type: int
order_by_recent=False, # type: bool
include_profile=False, # type: bool
before_limit=0, # type: int
after_limit=0, # type: int
):
# type: (...) -> Dict[Any, Any]
"""Search the indexstore for an event."""
@ -480,9 +471,13 @@ class IndexStore:
# the number of CPUs and the semaphore has the same counter value.
async with self.read_semaphore:
searcher = self.index.searcher()
search_func = partial(searcher.search, search_term, room=room,
max_results=max_results,
order_by_recent=order_by_recent)
search_func = partial(
searcher.search,
search_term,
room=room,
max_results=max_results,
order_by_recent=order_by_recent,
)
result = await loop.run_in_executor(None, search_func)
@ -492,7 +487,7 @@ class IndexStore:
include_profile,
order_by_recent,
before_limit,
after_limit
after_limit,
)
search_result = await loop.run_in_executor(None, load_event_func)

View File

@ -53,40 +53,39 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
send_queue=send_queue,
recv_queue=recv_queue,
proxy=server_conf.proxy.geturl() if server_conf.proxy else None,
ssl=None if server_conf.ssl is True else False
ssl=None if server_conf.ssl is True else False,
)
app = web.Application()
app.add_routes([
web.post("/_matrix/client/r0/login", proxy.login),
web.get("/_matrix/client/r0/sync", proxy.sync),
web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
web.put(
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message
),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts),
])
app.add_routes(
[
web.post("/_matrix/client/r0/login", proxy.login),
web.get("/_matrix/client/r0/sync", proxy.sync),
web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
web.put(
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message,
),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts),
]
)
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
app.on_shutdown.append(proxy.shutdown)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(
runner,
str(server_conf.listen_address),
server_conf.listen_port
)
site = web.TCPSite(runner, str(server_conf.listen_address), server_conf.listen_port)
return proxy, runner, site
async def message_router(receive_queue, send_queue, proxies):
"""Find the recipient of a message and forward it to the right proxy."""
def find_proxy_by_user(user):
# type: (str) -> Optional[ProxyDaemon]
for proxy in proxies:
@ -109,35 +108,28 @@ async def message_router(receive_queue, send_queue, proxies):
msg = f"No pan client found for {message.pan_user}."
logger.warn(msg)
await send_info(
message.message_id,
message.pan_user,
"m.unknown_client",
msg
message.message_id, message.pan_user, "m.unknown_client", msg
)
await proxy.receive_message(message)
@click.command(
help=("pantalaimon is a reverse proxy for matrix homeservers that "
"transparently encrypts and decrypts messages for clients that "
"connect to pantalaimon.")
help=(
"pantalaimon is a reverse proxy for matrix homeservers that "
"transparently encrypts and decrypts messages for clients that "
"connect to pantalaimon."
)
)
@click.version_option(version="0.1", prog_name="pantalaimon")
@click.option("--log-level", type=click.Choice([
"error",
"warning",
"info",
"debug"
]), default=None)
@click.option(
"--log-level",
type=click.Choice(["error", "warning", "info", "debug"]),
default=None,
)
@click.option("-c", "--config", type=click.Path(exists=True))
@click.pass_context
def main(
context,
log_level,
config
):
def main(context, log_level, config):
loop = asyncio.get_event_loop()
conf_dir = user_config_dir("pantalaimon", "")
@ -172,12 +164,7 @@ def main(
for server_conf in pan_conf.servers.values():
proxy, runner, site = loop.run_until_complete(
init(
data_dir,
server_conf,
pan_queue.async_q,
ui_queue.async_q,
)
init(data_dir, server_conf, pan_queue.async_q, ui_queue.async_q)
)
servers.append((proxy, runner, site))
proxies.append(proxy)
@ -185,14 +172,12 @@ def main(
except keyring.errors.KeyringError as e:
context.fail(f"Error initializing keyring: {e}")
glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir,
pan_conf.servers.values(), pan_conf)
glib_fut = loop.run_in_executor(
None,
glib_thread.run
glib_thread = GlibT(
pan_queue.sync_q, ui_queue.sync_q, data_dir, pan_conf.servers.values(), pan_conf
)
glib_fut = loop.run_in_executor(None, glib_thread.run)
async def wait_for_glib(glib_thread, fut):
glib_thread.stop()
await fut
@ -211,8 +196,10 @@ def main(
try:
for proxy, _, site in servers:
click.echo(f"======== Starting daemon for homeserver "
f"{proxy.name} on {site.name} ========")
click.echo(
f"======== Starting daemon for homeserver "
f"{proxy.name} on {site.name} ========"
)
loop.run_until_complete(site.start())
click.echo("(Press CTRL+C to quit)")

View File

@ -43,15 +43,12 @@ class PanctlArgParse(argparse.ArgumentParser):
pass
def error(self, message):
message = (
f"Error: {message} "
f"(see help)"
)
message = f"Error: {message} " f"(see help)"
print(message)
raise ParseError
class PanctlParser():
class PanctlParser:
def __init__(self, commands):
self.commands = commands
self.parser = PanctlArgParse()
@ -194,19 +191,13 @@ class PanCompleter(Completer):
return ""
def complete_key_file_cmds(
self,
document,
complete_event,
command,
last_word,
words
self, document, complete_event, command, last_word, words
):
if len(words) == 2:
return self.complete_pan_users(last_word)
elif len(words) == 3:
return self.path_completer.get_completions(
Document(last_word),
complete_event
Document(last_word), complete_event
)
return ""
@ -264,16 +255,9 @@ class PanCompleter(Completer):
]:
return self.complete_verification(command, last_word, words)
elif command in [
"export-keys",
"import-keys",
]:
elif command in ["export-keys", "import-keys"]:
return self.complete_key_file_cmds(
document,
complete_event,
command,
last_word,
words
document, complete_event, command, last_word, words
)
elif command in ["send-anyways", "cancel-sending"]:
@ -300,7 +284,7 @@ def grouper(iterable, n, fillvalue=None):
def partition_key(key):
groups = grouper(key, 4, " ")
return ' '.join(''.join(g) for g in groups)
return " ".join("".join(g) for g in groups)
def get_color(string):
@ -330,37 +314,59 @@ class PanCtl:
command_help = {
"help": "Display help about commands.",
"list-servers": ("List the configured homeservers and pan users on "
"each homeserver."),
"list-devices": ("List the devices of a user that are known to the "
"pan-user."),
"start-verification": ("Start an interactive key verification between "
"the given pan-user and user."),
"accept-verification": ("Accept an interactive key verification that "
"the given user has started with our given "
"pan-user."),
"cancel-verification": ("Cancel an interactive key verification "
"between the given pan-user and user."),
"confirm-verification": ("Confirm that the short authentication "
"string of the interactive key verification "
"with the given pan-user and user is "
"matching."),
"list-servers": (
"List the configured homeservers and pan users on " "each homeserver."
),
"list-devices": (
"List the devices of a user that are known to the " "pan-user."
),
"start-verification": (
"Start an interactive key verification between "
"the given pan-user and user."
),
"accept-verification": (
"Accept an interactive key verification that "
"the given user has started with our given "
"pan-user."
),
"cancel-verification": (
"Cancel an interactive key verification "
"between the given pan-user and user."
),
"confirm-verification": (
"Confirm that the short authentication "
"string of the interactive key verification "
"with the given pan-user and user is "
"matching."
),
"verify-device": ("Manually mark the given device as verified."),
"unverify-device": ("Mark a previously verified device of the given "
"user as unverified."),
"blacklist-device": ("Manually mark the given device of the given "
"user as blacklisted."),
"unblacklist-device": ("Mark a previously blacklisted device of the "
"given user as unblacklisted."),
"send-anyways": ("Send a room message despite having unverified "
"devices in the room and mark the devices as "
"ignored."),
"cancel-sending": ("Cancel the send of a room message in a room that "
"contains unverified devices"),
"import-keys": ("Import end-to-end encryption keys from the given "
"file for the given pan-user."),
"export-keys": ("Export end-to-end encryption keys to the given file "
"for the given pan-user."),
"unverify-device": (
"Mark a previously verified device of the given " "user as unverified."
),
"blacklist-device": (
"Manually mark the given device of the given " "user as blacklisted."
),
"unblacklist-device": (
"Mark a previously blacklisted device of the "
"given user as unblacklisted."
),
"send-anyways": (
"Send a room message despite having unverified "
"devices in the room and mark the devices as "
"ignored."
),
"cancel-sending": (
"Cancel the send of a room message in a room that "
"contains unverified devices"
),
"import-keys": (
"Import end-to-end encryption keys from the given "
"file for the given pan-user."
),
"export-keys": (
"Export end-to-end encryption keys to the given file "
"for the given pan-user."
),
}
commands = [
@ -404,10 +410,12 @@ class PanCtl:
def unverified_devices(self, pan_user, room_id, display_name):
self.completer.rooms[pan_user].add(room_id)
print(f"Error sending message for user {pan_user}, "
f"there are unverified devices in the room {display_name} "
f"({room_id}).\nUse the send-anyways or cancel-sending commands "
f"to ignore the devices or cancel the sending.")
print(
f"Error sending message for user {pan_user}, "
f"there are unverified devices in the room {display_name} "
f"({room_id}).\nUse the send-anyways or cancel-sending commands "
f"to ignore the devices or cancel the sending."
)
def show_response(self, response_id, pan_user, message):
if response_id not in self.own_message_ids:
@ -418,14 +426,18 @@ class PanCtl:
print(message["message"])
def sas_done(self, pan_user, user_id, device_id, _):
print(f"Device {device_id} of user {user_id}"
f" succesfully verified for pan user {pan_user}.")
print(
f"Device {device_id} of user {user_id}"
f" succesfully verified for pan user {pan_user}."
)
def show_sas_invite(self, pan_user, user_id, device_id, _):
print(f"{user_id} has started an interactive device "
f"verification for his device {device_id} with pan user "
f"{pan_user}\n"
f"Accept the invitation with the accept-verification command.")
print(
f"{user_id} has started an interactive device "
f"verification for his device {device_id} with pan user "
f"{pan_user}\n"
f"Accept the invitation with the accept-verification command."
)
# The emoji printing logic was taken from weechat-matrix and was written by
# dkasak.
@ -443,33 +455,26 @@ class PanCtl:
# that they are rendered with coloured glyphs. For these, we
# need to add an extra space after them so that they are
# rendered properly in weechat.
variation_selector_emojis = [
'☁️',
'❤️',
'☂️',
'✏️',
'✂️',
'☎️',
'✈️'
]
variation_selector_emojis = ["☁️", "❤️", "☂️", "✏️", "✂️", "☎️", "✈️"]
if emoji in variation_selector_emojis:
emoji += " "
# This is a trick to account for the fact that emojis are wider
# than other monospace characters.
placeholder = '.' * emoji_width
placeholder = "." * emoji_width
return placeholder.center(width).replace(placeholder, emoji)
emoji_str = u"".join(center_emoji(e, centered_width)
for e in emojis)
desc = u"".join(d.center(centered_width) for d in descriptions)
short_string = u"\n".join([emoji_str, desc])
emoji_str = "".join(center_emoji(e, centered_width) for e in emojis)
desc = "".join(d.center(centered_width) for d in descriptions)
short_string = "\n".join([emoji_str, desc])
print(f"Short authentication string for pan "
f"user {pan_user} from {user_id} via "
f"{device_id}:\n{short_string}")
print(
f"Short authentication string for pan "
f"user {pan_user} from {user_id} via "
f"{device_id}:\n{short_string}"
)
def list_servers(self):
"""List the daemons users."""
@ -480,9 +485,7 @@ class PanCtl:
for server, server_users in servers.items():
server_c = get_color(server)
print_formatted_text(HTML(
f" - Name: <{server_c}>{server}</{server_c}>"
))
print_formatted_text(HTML(f" - Name: <{server_c}>{server}</{server_c}>"))
user_list = []
@ -490,8 +493,10 @@ class PanCtl:
user_c = get_color(user)
device_c = get_color(device)
user_list.append(f" - <{user_c}>{user}</{user_c}> "
f"<{device_c}>{device}</{device_c}>")
user_list.append(
f" - <{user_c}>{user}</{user_c}> "
f"<{device_c}>{device}</{device_c}>"
)
if user_list:
print(f" - Pan users:")
@ -501,9 +506,7 @@ class PanCtl:
def list_devices(self, args):
devices = self.devices.ListUserDevices(args.pan_user, args.user_id)
print_formatted_text(
HTML(f"Devices for user <b>{args.user_id}</b>:")
)
print_formatted_text(HTML(f"Devices for user <b>{args.user_id}</b>:"))
for device in devices:
if device["trust_state"] == "verified":
@ -517,16 +520,18 @@ class PanCtl:
key = partition_key(device["ed25519"])
color = get_color(device["device_id"])
print_formatted_text(HTML(
f" - Display name: "
f"{device['device_display_name']}\n"
f" - Device id: "
f"<{color}>{device['device_id']}</{color}>\n"
f" - Device key: "
f"<ansiyellow>{key}</ansiyellow>\n"
f" - Trust state: "
f"{trust_state}"
))
print_formatted_text(
HTML(
f" - Display name: "
f"{device['device_display_name']}\n"
f" - Device id: "
f"<{color}>{device['device_id']}</{color}>\n"
f" - Device key: "
f"<ansiyellow>{key}</ansiyellow>\n"
f" - Trust state: "
f"{trust_state}"
)
)
async def loop(self):
"""Event loop for panctl."""
@ -559,105 +564,83 @@ class PanCtl:
elif command == "import-keys":
self.own_message_ids.append(
self.ctl.ImportKeys(
args.pan_user,
args.path,
args.passphrase
))
self.ctl.ImportKeys(args.pan_user, args.path, args.passphrase)
)
elif command == "export-keys":
self.own_message_ids.append(
self.ctl.ExportKeys(
args.pan_user,
args.path,
args.passphrase
))
self.ctl.ExportKeys(args.pan_user, args.path, args.passphrase)
)
elif command == "send-anyways":
self.own_message_ids.append(
self.ctl.SendAnyways(
args.pan_user,
args.room_id,
))
self.ctl.SendAnyways(args.pan_user, args.room_id)
)
elif command == "cancel-sending":
self.own_message_ids.append(
self.ctl.CancelSending(
args.pan_user,
args.room_id,
))
self.ctl.CancelSending(args.pan_user, args.room_id)
)
elif command == "list-devices":
self.list_devices(args)
elif command == "verify-device":
self.own_message_ids.append(
self.devices.Verify(
args.pan_user,
args.user_id,
args.device_id
))
self.devices.Verify(args.pan_user, args.user_id, args.device_id)
)
elif command == "unverify-device":
self.own_message_ids.append(
self.devices.Unverify(
args.pan_user,
args.user_id,
args.device_id
))
self.devices.Unverify(args.pan_user, args.user_id, args.device_id)
)
elif command == "blacklist-device":
self.own_message_ids.append(
self.devices.Blacklist(
args.pan_user,
args.user_id,
args.device_id
))
self.devices.Blacklist(args.pan_user, args.user_id, args.device_id)
)
elif command == "unblacklist-device":
self.own_message_ids.append(
self.devices.Unblacklist(
args.pan_user,
args.user_id,
args.device_id
))
args.pan_user, args.user_id, args.device_id
)
)
elif command == "start-verification":
self.own_message_ids.append(
self.devices.StartKeyVerification(
args.pan_user,
args.user_id,
args.device_id
))
args.pan_user, args.user_id, args.device_id
)
)
elif command == "cancel-verification":
self.own_message_ids.append(
self.devices.CancelKeyVerification(
args.pan_user,
args.user_id,
args.device_id
))
args.pan_user, args.user_id, args.device_id
)
)
elif command == "accept-verification":
self.own_message_ids.append(
self.devices.AcceptKeyVerification(
args.pan_user,
args.user_id,
args.device_id
))
args.pan_user, args.user_id, args.device_id
)
)
elif command == "confirm-verification":
self.own_message_ids.append(
self.devices.ConfirmKeyVerification(
args.pan_user,
args.user_id,
args.device_id
))
args.pan_user, args.user_id, args.device_id
)
)
@click.command(
help=("panctl is a small interactive repl to introspect and control"
"the pantalaimon daemon.")
help=(
"panctl is a small interactive repl to introspect and control"
"the pantalaimon daemon."
)
)
@click.version_option(version="0.1", prog_name="panctl")
def main():
@ -670,10 +653,7 @@ def main():
print(f"Error, {e}")
sys.exit(-1)
fut = loop.run_in_executor(
None,
glib_loop.run
)
fut = loop.run_in_executor(None, glib_loop.run)
try:
loop.run_until_complete(panctl.loop())
@ -684,5 +664,5 @@ def main():
loop.run_until_complete(fut)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -18,10 +18,8 @@ from collections import defaultdict
from typing import List, Optional, Tuple
import attr
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
use_database)
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField)
from nio.store import Accounts, DeviceKeys, DeviceTrustState, TrustState, use_database
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
@attr.s
@ -41,10 +39,7 @@ class DictField(TextField):
class AccessTokens(Model):
token = TextField()
account = ForeignKeyField(
model=Accounts,
primary_key=True,
backref="access_token",
on_delete="CASCADE"
model=Accounts, primary_key=True, backref="access_token", on_delete="CASCADE"
)
@ -58,10 +53,7 @@ class Servers(Model):
class ServerUsers(Model):
user_id = TextField()
server = ForeignKeyField(
model=Servers,
column_name="server_id",
backref="users",
on_delete="CASCADE"
model=Servers, column_name="server_id", backref="users", on_delete="CASCADE"
)
class Meta:
@ -70,9 +62,7 @@ class ServerUsers(Model):
class PanSyncTokens(Model):
token = TextField()
user = ForeignKeyField(
model=ServerUsers,
column_name="user_id")
user = ForeignKeyField(model=ServerUsers, column_name="user_id")
class Meta:
constraints = [SQL("UNIQUE(user_id)")]
@ -80,9 +70,8 @@ class PanSyncTokens(Model):
class PanFetcherTasks(Model):
user = ForeignKeyField(
model=ServerUsers,
column_name="user_id",
backref="fetcher_tasks")
model=ServerUsers, column_name="user_id", backref="fetcher_tasks"
)
room_id = TextField()
token = TextField()
@ -110,13 +99,12 @@ class PanStore:
DeviceKeys,
DeviceTrustState,
PanSyncTokens,
PanFetcherTasks
PanFetcherTasks,
]
def __attrs_post_init__(self):
self.database_path = os.path.join(
os.path.abspath(self.store_path),
self.database_name
os.path.abspath(self.store_path), self.database_name
)
self.database = self._create_database()
@ -127,19 +115,14 @@ class PanStore:
def _create_database(self):
return SqliteDatabase(
self.database_path,
pragmas={
"foreign_keys": 1,
"secure_delete": 1,
}
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
)
@use_database
def _get_account(self, user_id, device_id):
try:
return Accounts.get(
Accounts.user_id == user_id,
Accounts.device_id == device_id,
Accounts.user_id == user_id, Accounts.device_id == device_id
)
except DoesNotExist:
return None
@ -150,9 +133,7 @@ class PanStore:
user = ServerUsers.get(server=server, user_id=pan_user)
PanFetcherTasks.replace(
user=user,
room_id=task.room_id,
token=task.token
user=user, room_id=task.room_id, token=task.token
).execute()
def load_fetcher_tasks(self, server, pan_user):
@ -173,7 +154,7 @@ class PanStore:
PanFetcherTasks.delete().where(
PanFetcherTasks.user == user,
PanFetcherTasks.room_id == task.room_id,
PanFetcherTasks.token == task.token
PanFetcherTasks.token == task.token,
).execute()
@use_database
@ -208,18 +189,14 @@ class PanStore:
server, _ = Servers.get_or_create(name=server_name)
ServerUsers.insert(
user_id=user_id,
server=server
user_id=user_id, server=server
).on_conflict_ignore().execute()
@use_database
def load_all_users(self):
users = []
query = Accounts.select(
Accounts.user_id,
Accounts.device_id,
)
query = Accounts.select(Accounts.user_id, Accounts.device_id)
for account in query:
users.append((account.user_id, account.device_id))
@ -241,10 +218,9 @@ class PanStore:
for u in server.users:
server_users.append(u.user_id)
query = Accounts.select(
Accounts.user_id,
Accounts.device_id,
).where(Accounts.user_id.in_(server_users))
query = Accounts.select(Accounts.user_id, Accounts.device_id).where(
Accounts.user_id.in_(server_users)
)
for account in query:
users.append((account.user_id, account.device_id))
@ -256,10 +232,7 @@ class PanStore:
account = self._get_account(user_id, device_id)
assert account
AccessTokens.replace(
account=account,
token=access_token
).execute()
AccessTokens.replace(account=account, token=access_token).execute()
@use_database
def load_access_token(self, user_id, device_id):
@ -302,7 +275,7 @@ class PanStore:
"ed25519": keys["ed25519"],
"curve25519": keys["curve25519"],
"trust_state": trust_state.name,
"device_display_name": d.display_name
"device_display_name": d.display_name,
}
store[account.user_id] = device_store

View File

@ -24,20 +24,27 @@ from pydbus.generic import signal
from pantalaimon.log import logger
from pantalaimon.store import PanStore
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
CancelSendingMessage,
ConfirmSasMessage, DaemonResponse,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
DeviceUnverifyMessage,
DeviceVerifyMessage,
ExportKeysMessage, ImportKeysMessage,
InviteSasSignal, SasDoneSignal,
SendAnywaysMessage, ShowSasSignal,
StartSasMessage,
UnverifiedDevicesSignal,
UpdateDevicesMessage,
UpdateUsersMessage)
from pantalaimon.thread_messages import (
AcceptSasMessage,
CancelSasMessage,
CancelSendingMessage,
ConfirmSasMessage,
DaemonResponse,
DeviceBlacklistMessage,
DeviceUnblacklistMessage,
DeviceUnverifyMessage,
DeviceVerifyMessage,
ExportKeysMessage,
ImportKeysMessage,
InviteSasSignal,
SasDoneSignal,
SendAnywaysMessage,
ShowSasSignal,
StartSasMessage,
UnverifiedDevicesSignal,
UpdateDevicesMessage,
UpdateUsersMessage,
)
class IdCounter:
@ -126,40 +133,22 @@ class Control:
return self.users
def ExportKeys(self, pan_user, filepath, passphrase):
message = ExportKeysMessage(
self.message_id,
pan_user,
filepath,
passphrase
)
message = ExportKeysMessage(self.message_id, pan_user, filepath, passphrase)
self.queue.put(message)
return message.message_id
def ImportKeys(self, pan_user, filepath, passphrase):
message = ImportKeysMessage(
self.message_id,
pan_user,
filepath,
passphrase
)
message = ImportKeysMessage(self.message_id, pan_user, filepath, passphrase)
self.queue.put(message)
return message.message_id
def SendAnyways(self, pan_user, room_id):
message = SendAnywaysMessage(
self.message_id,
pan_user,
room_id
)
message = SendAnywaysMessage(self.message_id, pan_user, room_id)
self.queue.put(message)
return message.message_id
def CancelSending(self, pan_user, room_id):
message = CancelSendingMessage(
self.message_id,
pan_user,
room_id
)
message = CancelSendingMessage(self.message_id, pan_user, room_id)
self.queue.put(message)
return message.message_id
@ -293,8 +282,9 @@ class Devices:
return []
device_list = [
device for device_list in device_store.values() for device in
device_list.values()
device
for device_list in device_store.values()
for device in device_list.values()
]
return device_list
@ -313,82 +303,44 @@ class Devices:
return device_list.values()
def Verify(self, pan_user, user_id, device_id):
message = DeviceVerifyMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = DeviceVerifyMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def Unverify(self, pan_user, user_id, device_id):
message = DeviceUnverifyMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = DeviceUnverifyMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def Blacklist(self, pan_user, user_id, device_id):
message = DeviceBlacklistMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = DeviceBlacklistMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def Unblacklist(self, pan_user, user_id, device_id):
message = DeviceUnblacklistMessage(
self.message_id,
pan_user,
user_id,
device_id
self.message_id, pan_user, user_id, device_id
)
self.queue.put(message)
return message.message_id
def StartKeyVerification(self, pan_user, user_id, device_id):
message = StartSasMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = StartSasMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def CancelKeyVerification(self, pan_user, user_id, device_id):
message = CancelSasMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = CancelSasMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def ConfirmKeyVerification(self, pan_user, user_id, device_id):
message = ConfirmSasMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = ConfirmSasMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
def AcceptKeyVerification(self, pan_user, user_id, device_id):
message = AcceptSasMessage(
self.message_id,
pan_user,
user_id,
device_id
)
message = AcceptSasMessage(self.message_id, pan_user, user_id, device_id)
self.queue.put(message)
return message.message_id
@ -421,8 +373,9 @@ class GlibT:
id_counter = IdCounter()
self.control_if = Control(self.send_queue, self.store,
self.server_list, id_counter)
self.control_if = Control(
self.send_queue, self.store, self.server_list, id_counter
)
self.device_if = Devices(self.send_queue, self.store, id_counter)
self.bus = SessionBus()
@ -431,131 +384,95 @@ class GlibT:
def unverified_notification(self, message):
notificaton = notify2.Notification(
"Unverified devices.",
message=(f"There are unverified devices in the room "
f"{message.room_display_name}.")
message=(
f"There are unverified devices in the room "
f"{message.room_display_name}."
),
)
notificaton.set_category("im")
def send_cb(notification, action_key, user_data):
message = user_data
self.control_if.SendAnyways(
message.pan_user,
message.room_id
)
self.control_if.SendAnyways(message.pan_user, message.room_id)
def cancel_cb(notification, action_key, user_data):
message = user_data
self.control_if.CancelSending(
message.pan_user,
message.room_id
)
self.control_if.CancelSending(message.pan_user, message.room_id)
if "actions" in notify2.get_server_caps():
notificaton.add_action(
"send",
"Send anyways",
send_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel sending",
cancel_cb,
message
)
notificaton.add_action("send", "Send anyways", send_cb, message)
notificaton.add_action("cancel", "Cancel sending", cancel_cb, message)
notificaton.show()
def sas_invite_notification(self, message):
notificaton = notify2.Notification(
"Key verification invite",
message=(f"{message.user_id} via {message.device_id} has started "
f"a key verification process.")
message=(
f"{message.user_id} via {message.device_id} has started "
f"a key verification process."
),
)
notificaton.set_category("im")
def accept_cb(notification, action_key, user_data):
message = user_data
self.device_if.AcceptKeyVerification(
message.pan_user,
message.user_id,
message.device_id
message.pan_user, message.user_id, message.device_id
)
def cancel_cb(notification, action_key, user_data):
message = user_data
self.device_if.CancelKeyVerification(
message.pan_user,
message.user_id,
message.device_id,
message.pan_user, message.user_id, message.device_id
)
if "actions" in notify2.get_server_caps():
notificaton.add_action(
"accept",
"Accept",
accept_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel",
cancel_cb,
message
)
notificaton.add_action("accept", "Accept", accept_cb, message)
notificaton.add_action("cancel", "Cancel", cancel_cb, message)
notificaton.show()
def sas_show_notification(self, message):
emojis = [x[0] for x in message.emoji]
emoji_str = u" ".join(emojis)
emoji_str = " ".join(emojis)
notificaton = notify2.Notification(
"Short authentication string",
message=(f"Short authentication string for the key verification of"
f" {message.user_id} via {message.device_id}:\n"
f"{emoji_str}")
message=(
f"Short authentication string for the key verification of"
f" {message.user_id} via {message.device_id}:\n"
f"{emoji_str}"
),
)
notificaton.set_category("im")
def confirm_cb(notification, action_key, user_data):
message = user_data
self.device_if.ConfirmKeyVerification(
message.pan_user,
message.user_id,
message.device_id
message.pan_user, message.user_id, message.device_id
)
def cancel_cb(notification, action_key, user_data):
message = user_data
self.device_if.CancelKeyVerification(
message.pan_user,
message.user_id,
message.device_id,
message.pan_user, message.user_id, message.device_id
)
if "actions" in notify2.get_server_caps():
notificaton.add_action(
"confirm",
"Confirm",
confirm_cb,
message
)
notificaton.add_action(
"cancel",
"Cancel",
cancel_cb,
message
)
notificaton.add_action("confirm", "Confirm", confirm_cb, message)
notificaton.add_action("cancel", "Cancel", cancel_cb, message)
notificaton.show()
def sas_done_notification(self, message):
notificaton = notify2.Notification(
"Device successfully verified.",
message=(f"Device {message.device_id} of user {message.user_id} "
f"successfully verified.")
message=(
f"Device {message.device_id} of user {message.user_id} "
f"successfully verified."
),
)
notificaton.set_category("im")
notificaton.show()
@ -576,9 +493,7 @@ class GlibT:
elif isinstance(message, UnverifiedDevicesSignal):
self.control_if.UnverifiedDevices(
message.pan_user,
message.room_id,
message.room_display_name
message.pan_user, message.room_id, message.room_display_name
)
if self.notifications:
@ -589,7 +504,7 @@ class GlibT:
message.pan_user,
message.user_id,
message.device_id,
message.transaction_id
message.transaction_id,
)
if self.notifications:
@ -622,10 +537,7 @@ class GlibT:
self.control_if.Response(
message.message_id,
message.pan_user,
{
"code": message.code,
"message": message.message
}
{"code": message.code, "message": message.message},
)
self.receive_queue.task_done()
@ -639,8 +551,10 @@ class GlibT:
notify2.init("pantalaimon", mainloop=self.loop)
self.notifications = True
except dbus.DBusException:
logger.error("Notifications are enabled but no notification "
"server could be found, disabling notifications.")
logger.error(
"Notifications are enabled but no notification "
"server could be found, disabling notifications."
)
self.notifications = False
GLib.timeout_add(100, self.message_callback)