mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2024-10-01 03:35:38 -04:00
pantalaimon: Format the source tree using black.
This commit is contained in:
parent
531d686d8f
commit
c9ebfd71ec
4
.flake8
Normal file
4
.flake8
Normal file
@ -0,0 +1,4 @@
|
||||
[flake8]
|
||||
max-line-length = 80
|
||||
select = C,E,F,W,B,B950
|
||||
ignore = E501,W503
|
6
.isort.cfg
Normal file
6
.isort.cfg
Normal file
@ -0,0 +1,6 @@
|
||||
[settings]
|
||||
multi_line_output=3
|
||||
include_trailing_comma=True
|
||||
force_grid_wrap=0
|
||||
use_parentheses=True
|
||||
line_length=88
|
4
Makefile
4
Makefile
@ -1,5 +1,6 @@
|
||||
test:
|
||||
python3 -m pytest
|
||||
python3 -m pytest --black pantalaimon
|
||||
python3 -m pytest --flake8 pantalaimon
|
||||
python3 -m pytest --isort
|
||||
|
||||
@ -14,3 +15,6 @@ run-local:
|
||||
|
||||
isort:
|
||||
isort -y -p pantalaimon
|
||||
|
||||
format:
|
||||
black pantalaimon/
|
||||
|
@ -20,21 +20,39 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from aiohttp.client_exceptions import ClientConnectionError
|
||||
from jsonschema import Draft4Validator, FormatChecker, validators
|
||||
from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse,
|
||||
KeyVerificationEvent, KeyVerificationKey, KeyVerificationMac,
|
||||
KeyVerificationStart, LocalProtocolError, MegolmEvent,
|
||||
RoomContextError, RoomEncryptedEvent, RoomEncryptedMedia,
|
||||
RoomMessageMedia, RoomMessageText, RoomNameEvent,
|
||||
RoomTopicEvent, SyncResponse)
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
ClientConfig,
|
||||
EncryptionError,
|
||||
KeysQueryResponse,
|
||||
KeyVerificationEvent,
|
||||
KeyVerificationKey,
|
||||
KeyVerificationMac,
|
||||
KeyVerificationStart,
|
||||
LocalProtocolError,
|
||||
MegolmEvent,
|
||||
RoomContextError,
|
||||
RoomEncryptedEvent,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomNameEvent,
|
||||
RoomTopicEvent,
|
||||
SyncResponse,
|
||||
)
|
||||
from nio.crypto import Sas
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.index import IndexStore
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import FetchTask
|
||||
from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal,
|
||||
SasDoneSignal, ShowSasSignal,
|
||||
UpdateDevicesMessage)
|
||||
from pantalaimon.thread_messages import (
|
||||
DaemonResponse,
|
||||
InviteSasSignal,
|
||||
SasDoneSignal,
|
||||
ShowSasSignal,
|
||||
UpdateDevicesMessage,
|
||||
)
|
||||
|
||||
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]
|
||||
|
||||
@ -51,7 +69,7 @@ SEARCH_TERMS_SCHEMA = {
|
||||
"keys": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": SEARCH_KEYS},
|
||||
"default": SEARCH_KEYS
|
||||
"default": SEARCH_KEYS,
|
||||
},
|
||||
"order_by": {"type": "string", "default": "rank"},
|
||||
"include_state": {"type": "boolean", "default": False},
|
||||
@ -59,15 +77,13 @@ SEARCH_TERMS_SCHEMA = {
|
||||
"event_context": {"type": "object"},
|
||||
"groupings": {"type": "object", "default": {}},
|
||||
},
|
||||
"required": ["search_term"]
|
||||
},
|
||||
"required": ["search_term"],
|
||||
}
|
||||
},
|
||||
"required": ["room_events"]
|
||||
},
|
||||
"required": [
|
||||
"search_categories",
|
||||
],
|
||||
"required": ["room_events"],
|
||||
},
|
||||
"required": ["search_categories"],
|
||||
}
|
||||
|
||||
|
||||
@ -79,9 +95,7 @@ def extend_with_default(validator_class):
|
||||
if "default" in subschema:
|
||||
instance.setdefault(prop, subschema["default"])
|
||||
|
||||
for error in validate_properties(
|
||||
validator, properties, instance, schema
|
||||
):
|
||||
for error in validate_properties(validator, properties, instance, schema):
|
||||
yield error
|
||||
|
||||
return validators.extend(validator_class, {"properties": set_defaults})
|
||||
@ -122,11 +136,10 @@ class PanClient(AsyncClient):
|
||||
store_path="",
|
||||
config=None,
|
||||
ssl=None,
|
||||
proxy=None
|
||||
proxy=None,
|
||||
):
|
||||
config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
|
||||
super().__init__(homeserver, user_id, device_id, store_path, config,
|
||||
ssl, proxy)
|
||||
super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)
|
||||
|
||||
index_dir = os.path.join(store_path, server_name, user_id)
|
||||
|
||||
@ -150,31 +163,24 @@ class PanClient(AsyncClient):
|
||||
self.history_fetcher_task = None
|
||||
self.history_fetch_queue = asyncio.Queue()
|
||||
|
||||
self.add_to_device_callback(
|
||||
self.key_verification_cb,
|
||||
KeyVerificationEvent
|
||||
)
|
||||
self.add_event_callback(
|
||||
self.undecrypted_event_cb,
|
||||
MegolmEvent
|
||||
)
|
||||
self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
|
||||
self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
|
||||
self.add_event_callback(
|
||||
self.store_message_cb,
|
||||
(RoomMessageText, RoomMessageMedia, RoomEncryptedMedia,
|
||||
RoomTopicEvent, RoomNameEvent)
|
||||
(
|
||||
RoomMessageText,
|
||||
RoomMessageMedia,
|
||||
RoomEncryptedMedia,
|
||||
RoomTopicEvent,
|
||||
RoomNameEvent,
|
||||
),
|
||||
)
|
||||
self.key_verificatins_tasks = []
|
||||
self.key_request_tasks = []
|
||||
|
||||
self.add_response_callback(
|
||||
self.keys_query_cb,
|
||||
KeysQueryResponse
|
||||
)
|
||||
self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
|
||||
|
||||
self.add_response_callback(
|
||||
self.sync_tasks,
|
||||
SyncResponse
|
||||
)
|
||||
self.add_response_callback(self.sync_tasks, SyncResponse)
|
||||
|
||||
def store_message_cb(self, room, event):
|
||||
display_name = room.user_name(event.sender)
|
||||
@ -192,9 +198,11 @@ class PanClient(AsyncClient):
|
||||
"type": "m.room.message",
|
||||
"content": {
|
||||
"msgtype": "m.text",
|
||||
"body": ("** Unable to decrypt: The sender's device has not "
|
||||
"sent us the keys for this message. **")
|
||||
}
|
||||
"body": (
|
||||
"** Unable to decrypt: The sender's device has not "
|
||||
"sent us the keys for this message. **"
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
async def send_message(self, message):
|
||||
@ -206,17 +214,10 @@ class PanClient(AsyncClient):
|
||||
await self.queue.put(message)
|
||||
|
||||
def delete_fetcher_task(self, task):
|
||||
self.pan_store.delete_fetcher_task(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
task
|
||||
)
|
||||
self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)
|
||||
|
||||
async def fetcher_loop(self):
|
||||
for t in self.pan_store.load_fetcher_tasks(
|
||||
self.server_name,
|
||||
self.user_id
|
||||
):
|
||||
for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
|
||||
await self.history_fetch_queue.put(t)
|
||||
|
||||
while True:
|
||||
@ -233,13 +234,13 @@ class PanClient(AsyncClient):
|
||||
continue
|
||||
|
||||
try:
|
||||
logger.debug("Fetching room history for {}".format(
|
||||
room.display_name
|
||||
))
|
||||
logger.debug(
|
||||
"Fetching room history for {}".format(room.display_name)
|
||||
)
|
||||
response = await self.room_messages(
|
||||
fetch_task.room_id,
|
||||
fetch_task.token,
|
||||
limit=self.pan_conf.indexing_batch_size
|
||||
limit=self.pan_conf.indexing_batch_size,
|
||||
)
|
||||
except ClientConnectionError:
|
||||
self.history_fetch_queue.put(fetch_task)
|
||||
@ -250,31 +251,31 @@ class PanClient(AsyncClient):
|
||||
continue
|
||||
|
||||
for event in response.chunk:
|
||||
if not isinstance(event, (
|
||||
if not isinstance(
|
||||
event,
|
||||
(
|
||||
RoomMessageText,
|
||||
RoomMessageMedia,
|
||||
RoomEncryptedMedia,
|
||||
RoomTopicEvent,
|
||||
RoomNameEvent
|
||||
)):
|
||||
RoomNameEvent,
|
||||
),
|
||||
):
|
||||
continue
|
||||
|
||||
display_name = room.user_name(event.sender)
|
||||
avatar_url = room.avatar_url(event.sender)
|
||||
self.index.add_event(event, room.room_id, display_name,
|
||||
avatar_url)
|
||||
self.index.add_event(event, room.room_id, display_name, avatar_url)
|
||||
|
||||
last_event = response.chunk[-1]
|
||||
|
||||
if not self.index.event_in_store(
|
||||
last_event.event_id,
|
||||
room.room_id
|
||||
):
|
||||
if not self.index.event_in_store(last_event.event_id, room.room_id):
|
||||
# There may be even more events to fetch, add a new task to
|
||||
# the queue.
|
||||
task = FetchTask(room.room_id, response.end)
|
||||
self.pan_store.save_fetcher_task(self.server_name,
|
||||
self.user_id, task)
|
||||
self.pan_store.save_fetcher_task(
|
||||
self.server_name, self.user_id, task
|
||||
)
|
||||
await self.history_fetch_queue.put(task)
|
||||
|
||||
await self.index.commit_events()
|
||||
@ -294,11 +295,7 @@ class PanClient(AsyncClient):
|
||||
self.key_request_tasks = []
|
||||
|
||||
await self.index.commit_events()
|
||||
self.pan_store.save_token(
|
||||
self.server_name,
|
||||
self.user_id,
|
||||
self.next_batch
|
||||
)
|
||||
self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)
|
||||
|
||||
for room_id, room_info in response.rooms.join.items():
|
||||
if room_info.timeline.limited:
|
||||
@ -307,13 +304,12 @@ class PanClient(AsyncClient):
|
||||
if not room.encrypted and self.pan_conf.index_encrypted_only:
|
||||
continue
|
||||
|
||||
logger.info("Room {} had a limited timeline, queueing "
|
||||
"room for history fetching.".format(
|
||||
room.display_name
|
||||
))
|
||||
logger.info(
|
||||
"Room {} had a limited timeline, queueing "
|
||||
"room for history fetching.".format(room.display_name)
|
||||
)
|
||||
task = FetchTask(room_id, room_info.timeline.prev_batch)
|
||||
self.pan_store.save_fetcher_task(self.server_name,
|
||||
self.user_id, task)
|
||||
self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)
|
||||
|
||||
await self.history_fetch_queue.put(task)
|
||||
|
||||
@ -323,10 +319,11 @@ class PanClient(AsyncClient):
|
||||
def undecrypted_event_cb(self, room, event):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
logger.info("Unable to decrypt event from {} via {}.".format(
|
||||
event.sender,
|
||||
event.device_id
|
||||
))
|
||||
logger.info(
|
||||
"Unable to decrypt event from {} via {}.".format(
|
||||
event.sender, event.device_id
|
||||
)
|
||||
)
|
||||
|
||||
if event.session_id not in self.outgoing_key_requests:
|
||||
logger.info("Requesting room key for undecrypted event.")
|
||||
@ -338,19 +335,16 @@ class PanClient(AsyncClient):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if isinstance(event, KeyVerificationStart):
|
||||
logger.info(f"{event.sender} via {event.from_device} has started "
|
||||
f"a key verification process.")
|
||||
logger.info(
|
||||
f"{event.sender} via {event.from_device} has started "
|
||||
f"a key verification process."
|
||||
)
|
||||
|
||||
message = InviteSasSignal(
|
||||
self.user_id,
|
||||
event.sender,
|
||||
event.from_device,
|
||||
event.transaction_id
|
||||
self.user_id, event.sender, event.from_device, event.transaction_id
|
||||
)
|
||||
|
||||
task = loop.create_task(
|
||||
self.queue.put(message)
|
||||
)
|
||||
task = loop.create_task(self.queue.put(message))
|
||||
self.key_verificatins_tasks.append(task)
|
||||
|
||||
elif isinstance(event, KeyVerificationKey):
|
||||
@ -362,16 +356,10 @@ class PanClient(AsyncClient):
|
||||
emoji = sas.get_emoji()
|
||||
|
||||
message = ShowSasSignal(
|
||||
self.user_id,
|
||||
device.user_id,
|
||||
device.id,
|
||||
sas.transaction_id,
|
||||
emoji
|
||||
self.user_id, device.user_id, device.id, sas.transaction_id, emoji
|
||||
)
|
||||
|
||||
task = loop.create_task(
|
||||
self.queue.put(message)
|
||||
)
|
||||
task = loop.create_task(self.queue.put(message))
|
||||
self.key_verificatins_tasks.append(task)
|
||||
|
||||
elif isinstance(event, KeyVerificationMac):
|
||||
@ -381,14 +369,13 @@ class PanClient(AsyncClient):
|
||||
device = sas.other_olm_device
|
||||
|
||||
if sas.verified:
|
||||
task = loop.create_task(self.send_message(
|
||||
task = loop.create_task(
|
||||
self.send_message(
|
||||
SasDoneSignal(
|
||||
self.user_id,
|
||||
device.user_id,
|
||||
device.id,
|
||||
sas.transaction_id
|
||||
self.user_id, device.user_id, device.id, sas.transaction_id
|
||||
)
|
||||
)
|
||||
)
|
||||
))
|
||||
self.key_verificatins_tasks.append(task)
|
||||
task = loop.create_task(self.send_update_devcies())
|
||||
self.key_verificatins_tasks.append(task)
|
||||
@ -407,22 +394,13 @@ class PanClient(AsyncClient):
|
||||
self.history_fetcher_task = loop.create_task(self.fetcher_loop())
|
||||
|
||||
timeout = 30000
|
||||
sync_filter = {
|
||||
"room": {
|
||||
"state": {"lazy_load_members": True}
|
||||
}
|
||||
}
|
||||
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
|
||||
next_batch = self.pan_store.load_token(self.server_name, self.user_id)
|
||||
|
||||
# We don't store any room state so initial sync needs to be with the
|
||||
# full_state parameter. Subsequent ones are normal.
|
||||
task = loop.create_task(
|
||||
self.sync_forever(
|
||||
timeout,
|
||||
sync_filter,
|
||||
full_state=True,
|
||||
since=next_batch
|
||||
)
|
||||
self.sync_forever(timeout, sync_filter, full_state=True, since=next_batch)
|
||||
)
|
||||
self.task = task
|
||||
|
||||
@ -436,16 +414,15 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.ok",
|
||||
"Successfully started the key verification request"
|
||||
))
|
||||
"Successfully started the key verification request",
|
||||
)
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
await self.send_message(
|
||||
DaemonResponse(
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.connection_error",
|
||||
str(e)
|
||||
))
|
||||
message.message_id, self.user_id, "m.connection_error", str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def accept_sas(self, message):
|
||||
user_id = message.user_id
|
||||
@ -459,9 +436,8 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
Sas._txid_error[0],
|
||||
Sas._txid_error[1]
|
||||
Sas._txid_error[1],
|
||||
)
|
||||
|
||||
)
|
||||
return
|
||||
|
||||
@ -472,24 +448,24 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.ok",
|
||||
"Successfully accepted the key verification request"
|
||||
))
|
||||
"Successfully accepted the key verification request",
|
||||
)
|
||||
)
|
||||
except LocalProtocolError as e:
|
||||
await self.send_message(
|
||||
DaemonResponse(
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
Sas._unexpected_message_error[0],
|
||||
str(e)
|
||||
))
|
||||
str(e),
|
||||
)
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
await self.send_message(
|
||||
DaemonResponse(
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.connection_error",
|
||||
str(e)
|
||||
))
|
||||
message.message_id, self.user_id, "m.connection_error", str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def cancel_sas(self, message):
|
||||
user_id = message.user_id
|
||||
@ -503,9 +479,8 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
Sas._txid_error[0],
|
||||
Sas._txid_error[1]
|
||||
Sas._txid_error[1],
|
||||
)
|
||||
|
||||
)
|
||||
return
|
||||
|
||||
@ -516,16 +491,15 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.ok",
|
||||
"Successfully canceled the key verification request"
|
||||
))
|
||||
"Successfully canceled the key verification request",
|
||||
)
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
await self.send_message(
|
||||
DaemonResponse(
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.connection_error",
|
||||
str(e)
|
||||
))
|
||||
message.message_id, self.user_id, "m.connection_error", str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def confirm_sas(self, message):
|
||||
user_id = message.user_id
|
||||
@ -539,9 +513,8 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
Sas._txid_error[0],
|
||||
Sas._txid_error[1]
|
||||
Sas._txid_error[1],
|
||||
)
|
||||
|
||||
)
|
||||
return
|
||||
|
||||
@ -550,11 +523,9 @@ class PanClient(AsyncClient):
|
||||
except ClientConnectionError as e:
|
||||
await self.send_message(
|
||||
DaemonResponse(
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.connection_error",
|
||||
str(e)
|
||||
))
|
||||
message.message_id, self.user_id, "m.connection_error", str(e)
|
||||
)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
@ -564,10 +535,7 @@ class PanClient(AsyncClient):
|
||||
await self.send_update_devcies()
|
||||
await self.send_message(
|
||||
SasDoneSignal(
|
||||
self.user_id,
|
||||
device.user_id,
|
||||
device.id,
|
||||
sas.transaction_id
|
||||
self.user_id, device.user_id, device.id, sas.transaction_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -576,8 +544,9 @@ class PanClient(AsyncClient):
|
||||
message.message_id,
|
||||
self.user_id,
|
||||
"m.ok",
|
||||
f"Waiting for {device.user_id} to confirm."
|
||||
))
|
||||
f"Waiting for {device.user_id} to confirm.",
|
||||
)
|
||||
)
|
||||
|
||||
async def loop_stop(self):
|
||||
"""Stop the client loop."""
|
||||
@ -605,18 +574,15 @@ class PanClient(AsyncClient):
|
||||
|
||||
self.history_fetch_queue = asyncio.Queue()
|
||||
|
||||
def pan_decrypt_event(
|
||||
self,
|
||||
event_dict,
|
||||
room_id=None,
|
||||
ignore_failures=True
|
||||
):
|
||||
def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
|
||||
# type: (Dict[Any, Any], Optional[str], bool) -> (bool)
|
||||
event = RoomEncryptedEvent.parse_event(event_dict)
|
||||
|
||||
if not isinstance(event, MegolmEvent):
|
||||
logger.warn("Encrypted event is not a megolm event:"
|
||||
"\n{}".format(pformat(event_dict)))
|
||||
logger.warn(
|
||||
"Encrypted event is not a megolm event:"
|
||||
"\n{}".format(pformat(event_dict))
|
||||
)
|
||||
return False
|
||||
|
||||
if not event.room_id:
|
||||
@ -661,8 +627,7 @@ class PanClient(AsyncClient):
|
||||
continue
|
||||
|
||||
if event["type"] != "m.room.encrypted":
|
||||
logger.debug("Event is not encrypted: "
|
||||
"\n{}".format(pformat(event)))
|
||||
logger.debug("Event is not encrypted: " "\n{}".format(pformat(event)))
|
||||
continue
|
||||
|
||||
self.pan_decrypt_event(event, ignore_failures=ignore_failures)
|
||||
@ -682,9 +647,11 @@ class PanClient(AsyncClient):
|
||||
for room_id, room_dict in body["rooms"]["join"].items():
|
||||
try:
|
||||
if not self.rooms[room_id].encrypted:
|
||||
logger.info("Room {} is not encrypted skipping...".format(
|
||||
logger.info(
|
||||
"Room {} is not encrypted skipping...".format(
|
||||
self.rooms[room_id].display_name
|
||||
))
|
||||
)
|
||||
)
|
||||
continue
|
||||
except KeyError:
|
||||
logger.info("Unknown room {} skipping...".format(room_id))
|
||||
@ -752,8 +719,9 @@ class PanClient(AsyncClient):
|
||||
after_limit = event_context.get("before_limit", 5)
|
||||
|
||||
if before_limit < 0 or after_limit < 0:
|
||||
raise InvalidLimit("Invalid context limit, the limit must be a "
|
||||
"positive number")
|
||||
raise InvalidLimit(
|
||||
"Invalid context limit, the limit must be a " "positive number"
|
||||
)
|
||||
|
||||
response_dict = await self.index.search(
|
||||
term,
|
||||
@ -762,7 +730,7 @@ class PanClient(AsyncClient):
|
||||
order_by_recent=order_by_recent,
|
||||
include_profile=include_profile,
|
||||
before_limit=before_limit,
|
||||
after_limit=after_limit
|
||||
after_limit=after_limit,
|
||||
)
|
||||
|
||||
if (event_context or include_state) and self.pan_conf.search_requests:
|
||||
@ -771,14 +739,10 @@ class PanClient(AsyncClient):
|
||||
event_dict,
|
||||
event_dict["result"]["room_id"],
|
||||
event_dict["result"]["event_id"],
|
||||
include_state
|
||||
include_state,
|
||||
)
|
||||
|
||||
if include_state:
|
||||
response_dict["state"] = state_cache
|
||||
|
||||
return {
|
||||
"search_categories": {
|
||||
"room_events": response_dict
|
||||
}
|
||||
}
|
||||
return {"search_categories": {"room_events": response_dict}}
|
||||
|
@ -43,7 +43,7 @@ class PanConfigParser(configparser.ConfigParser):
|
||||
"address": parse_address,
|
||||
"url": parse_url,
|
||||
"loglevel": parse_log_level,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -59,9 +59,10 @@ def parse_url(v):
|
||||
# type: (str) -> ParseResult
|
||||
value = urlparse(v)
|
||||
|
||||
if value.scheme not in ('http', 'https'):
|
||||
raise ValueError(f"Invalid URL scheme {value.scheme}. "
|
||||
f"Only HTTP(s) URLs are allowed")
|
||||
if value.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid URL scheme {value.scheme}. " f"Only HTTP(s) URLs are allowed"
|
||||
)
|
||||
value.port
|
||||
|
||||
return value
|
||||
@ -124,8 +125,7 @@ class ServerConfig:
|
||||
name = attr.ib(type=str)
|
||||
homeserver = attr.ib(type=ParseResult)
|
||||
listen_address = attr.ib(
|
||||
type=Union[IPv4Address, IPv6Address],
|
||||
default=ip_address("127.0.0.1")
|
||||
type=Union[IPv4Address, IPv6Address], default=ip_address("127.0.0.1")
|
||||
)
|
||||
listen_port = attr.ib(type=int, default=8009)
|
||||
proxy = attr.ib(type=str, default="")
|
||||
@ -183,8 +183,9 @@ class PanConfig:
|
||||
homeserver = section.geturl("Homeserver")
|
||||
|
||||
if not homeserver:
|
||||
raise PanConfigError(f"Homserver is not set for "
|
||||
f"section {section_name}")
|
||||
raise PanConfigError(
|
||||
f"Homserver is not set for " f"section {section_name}"
|
||||
)
|
||||
|
||||
listen_address = section.getaddress("ListenAddress")
|
||||
listen_port = section.getint("ListenPort")
|
||||
@ -198,23 +199,29 @@ class PanConfig:
|
||||
indexing_batch_size = section.getint("IndexingBatchSize")
|
||||
|
||||
if not 1 < indexing_batch_size <= 1000:
|
||||
raise PanConfigError("The indexing batch size needs to be "
|
||||
raise PanConfigError(
|
||||
"The indexing batch size needs to be "
|
||||
"a positive integer between 1 and "
|
||||
"1000")
|
||||
"1000"
|
||||
)
|
||||
|
||||
history_fetch_delay = section.getint("HistoryFetchDelay")
|
||||
|
||||
if not 100 < history_fetch_delay <= 10000:
|
||||
raise PanConfigError("The history fetch delay needs to be "
|
||||
raise PanConfigError(
|
||||
"The history fetch delay needs to be "
|
||||
"a positive integer between 100 and "
|
||||
"10000")
|
||||
"10000"
|
||||
)
|
||||
|
||||
listen_tuple = (listen_address, listen_port)
|
||||
|
||||
if listen_tuple in listen_set:
|
||||
raise PanConfigError(f"The listen address/port combination"
|
||||
raise PanConfigError(
|
||||
f"The listen address/port combination"
|
||||
f" for section {section_name} was "
|
||||
f"already defined before.")
|
||||
f"already defined before."
|
||||
)
|
||||
listen_set.add(listen_tuple)
|
||||
|
||||
server_conf = ServerConfig(
|
||||
@ -229,7 +236,7 @@ class PanConfig:
|
||||
search_requests,
|
||||
index_encrypted_only,
|
||||
indexing_batch_size,
|
||||
history_fetch_delay / 1000
|
||||
history_fetch_delay / 1000,
|
||||
)
|
||||
|
||||
self.servers[section_name] = server_conf
|
||||
|
@ -25,29 +25,39 @@ from aiohttp import ClientSession, web
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
|
||||
from jsonschema import ValidationError
|
||||
from multidict import CIMultiDict
|
||||
from nio import (Api, EncryptionError, LoginResponse, OlmTrustError,
|
||||
SendRetryError)
|
||||
from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError
|
||||
|
||||
from pantalaimon.client import (SEARCH_TERMS_SCHEMA, InvalidLimit,
|
||||
InvalidOrderByError, PanClient,
|
||||
UnknownRoomError, validate_json)
|
||||
from pantalaimon.client import (
|
||||
SEARCH_TERMS_SCHEMA,
|
||||
InvalidLimit,
|
||||
InvalidOrderByError,
|
||||
PanClient,
|
||||
UnknownRoomError,
|
||||
validate_json,
|
||||
)
|
||||
from pantalaimon.index import InvalidQueryError
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
|
||||
from pantalaimon.thread_messages import (
|
||||
AcceptSasMessage,
|
||||
CancelSasMessage,
|
||||
CancelSendingMessage,
|
||||
ConfirmSasMessage, DaemonResponse,
|
||||
ConfirmSasMessage,
|
||||
DaemonResponse,
|
||||
DeviceBlacklistMessage,
|
||||
DeviceUnblacklistMessage,
|
||||
DeviceUnverifyMessage,
|
||||
DeviceVerifyMessage,
|
||||
ExportKeysMessage, ImportKeysMessage,
|
||||
SasMessage, SendAnywaysMessage,
|
||||
ExportKeysMessage,
|
||||
ImportKeysMessage,
|
||||
SasMessage,
|
||||
SendAnywaysMessage,
|
||||
StartSasMessage,
|
||||
UnverifiedDevicesSignal,
|
||||
UnverifiedResponse,
|
||||
UpdateDevicesMessage,
|
||||
UpdateUsersMessage)
|
||||
UpdateUsersMessage,
|
||||
)
|
||||
|
||||
CORS_HEADERS = {
|
||||
"Access-Control-Allow-Headers": (
|
||||
@ -76,11 +86,7 @@ class ProxyDaemon:
|
||||
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
|
||||
hostname = attr.ib(init=False, default=attr.Factory(dict))
|
||||
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
|
||||
client_info = attr.ib(
|
||||
init=False,
|
||||
default=attr.Factory(dict),
|
||||
type=dict
|
||||
)
|
||||
client_info = attr.ib(init=False, default=attr.Factory(dict), type=dict)
|
||||
default_session = attr.ib(init=False, default=None)
|
||||
database_name = "pan.db"
|
||||
|
||||
@ -93,15 +99,16 @@ class ProxyDaemon:
|
||||
for user_id, device_id in accounts:
|
||||
if self.conf.keyring:
|
||||
token = keyring.get_password(
|
||||
"pantalaimon",
|
||||
f"{user_id}-{device_id}-token"
|
||||
"pantalaimon", f"{user_id}-{device_id}-token"
|
||||
)
|
||||
else:
|
||||
token = self.store.load_access_token(user_id, device_id)
|
||||
|
||||
if not token:
|
||||
logger.warn(f"Not restoring client for {user_id} {device_id}, "
|
||||
f"missing access token.")
|
||||
logger.warn(
|
||||
f"Not restoring client for {user_id} {device_id}, "
|
||||
f"missing access token."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Restoring client for {user_id} {device_id}")
|
||||
@ -116,7 +123,7 @@ class ProxyDaemon:
|
||||
device_id,
|
||||
store_path=self.data_dir,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
proxy=self.proxy,
|
||||
)
|
||||
pan_client.user_id = user_id
|
||||
pan_client.access_token = token
|
||||
@ -136,7 +143,7 @@ class ProxyDaemon:
|
||||
method,
|
||||
self.homeserver_url + path,
|
||||
proxy=self.proxy,
|
||||
ssl=self.ssl
|
||||
ssl=self.ssl,
|
||||
)
|
||||
except ClientConnectionError:
|
||||
return None
|
||||
@ -155,12 +162,15 @@ class ProxyDaemon:
|
||||
return None
|
||||
|
||||
if user_id not in self.pan_clients:
|
||||
logger.warn(f"User {user_id} doesn't have a matching pan "
|
||||
f"client.")
|
||||
logger.warn(
|
||||
f"User {user_id} doesn't have a matching pan " f"client."
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"Homeserver confirmed valid access token "
|
||||
f"for user {user_id}, caching info.")
|
||||
logger.info(
|
||||
f"Homeserver confirmed valid access token "
|
||||
f"for user {user_id}, caching info."
|
||||
)
|
||||
|
||||
client_info = ClientInfo(user_id, access_token)
|
||||
self.client_info[access_token] = client_info
|
||||
@ -173,12 +183,12 @@ class ProxyDaemon:
|
||||
ret = client.verify_device(device)
|
||||
|
||||
if ret:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully verified.")
|
||||
msg = (
|
||||
f"Device {device.id} of user " f"{device.user_id} succesfully verified."
|
||||
)
|
||||
await self.send_update_devcies()
|
||||
else:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} already verified.")
|
||||
msg = f"Device {device.id} of user " f"{device.user_id} already verified."
|
||||
|
||||
logger.info(msg)
|
||||
await self.send_response(message_id, client.user_id, "m.ok", msg)
|
||||
@ -187,12 +197,13 @@ class ProxyDaemon:
|
||||
ret = client.unverify_device(device)
|
||||
|
||||
if ret:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully unverified.")
|
||||
msg = (
|
||||
f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully unverified."
|
||||
)
|
||||
await self.send_update_devcies()
|
||||
else:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} already unverified.")
|
||||
msg = f"Device {device.id} of user " f"{device.user_id} already unverified."
|
||||
|
||||
logger.info(msg)
|
||||
await self.send_response(message_id, client.user_id, "m.ok", msg)
|
||||
@ -201,12 +212,15 @@ class ProxyDaemon:
|
||||
ret = client.blacklist_device(device)
|
||||
|
||||
if ret:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully blacklisted.")
|
||||
msg = (
|
||||
f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully blacklisted."
|
||||
)
|
||||
await self.send_update_devcies()
|
||||
else:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} already blacklisted.")
|
||||
msg = (
|
||||
f"Device {device.id} of user " f"{device.user_id} already blacklisted."
|
||||
)
|
||||
|
||||
logger.info(msg)
|
||||
await self.send_response(message_id, client.user_id, "m.ok", msg)
|
||||
@ -215,12 +229,16 @@ class ProxyDaemon:
|
||||
ret = client.unblacklist_device(device)
|
||||
|
||||
if ret:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully unblacklisted.")
|
||||
msg = (
|
||||
f"Device {device.id} of user "
|
||||
f"{device.user_id} succesfully unblacklisted."
|
||||
)
|
||||
await self.send_update_devcies()
|
||||
else:
|
||||
msg = (f"Device {device.id} of user "
|
||||
f"{device.user_id} already unblacklisted.")
|
||||
msg = (
|
||||
f"Device {device.id} of user "
|
||||
f"{device.user_id} already unblacklisted."
|
||||
)
|
||||
|
||||
logger.info(msg)
|
||||
await self.send_response(message_id, client.user_id, "m.ok", msg)
|
||||
@ -239,23 +257,23 @@ class ProxyDaemon:
|
||||
|
||||
if isinstance(
|
||||
message,
|
||||
(DeviceVerifyMessage, DeviceUnverifyMessage, StartSasMessage,
|
||||
DeviceBlacklistMessage, DeviceUnblacklistMessage)
|
||||
(
|
||||
DeviceVerifyMessage,
|
||||
DeviceUnverifyMessage,
|
||||
StartSasMessage,
|
||||
DeviceBlacklistMessage,
|
||||
DeviceUnblacklistMessage,
|
||||
),
|
||||
):
|
||||
|
||||
device = client.device_store[message.user_id].get(
|
||||
message.device_id,
|
||||
None
|
||||
)
|
||||
device = client.device_store[message.user_id].get(message.device_id, None)
|
||||
|
||||
if not device:
|
||||
msg = (f"No device found for {message.user_id} and "
|
||||
f"{message.device_id}")
|
||||
msg = (
|
||||
f"No device found for {message.user_id} and " f"{message.device_id}"
|
||||
)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
message.pan_user,
|
||||
"m.unknown_device",
|
||||
msg
|
||||
message.message_id, message.pan_user, "m.unknown_device", msg
|
||||
)
|
||||
logger.info(msg)
|
||||
return
|
||||
@ -265,11 +283,9 @@ class ProxyDaemon:
|
||||
elif isinstance(message, DeviceUnverifyMessage):
|
||||
await self._unverify_device(message.message_id, client, device)
|
||||
elif isinstance(message, DeviceBlacklistMessage):
|
||||
await self._blacklist_device(message.message_id, client,
|
||||
device)
|
||||
await self._blacklist_device(message.message_id, client, device)
|
||||
elif isinstance(message, DeviceUnblacklistMessage):
|
||||
await self._unblacklist_device(message.message_id, client,
|
||||
device)
|
||||
await self._unblacklist_device(message.message_id, client, device)
|
||||
elif isinstance(message, StartSasMessage):
|
||||
await client.start_sas(message, device)
|
||||
|
||||
@ -288,25 +304,21 @@ class ProxyDaemon:
|
||||
try:
|
||||
await client.export_keys(path, message.passphrase)
|
||||
except OSError as e:
|
||||
info_msg = (f"Error exporting keys for {client.user_id} to"
|
||||
f" {path} {e}")
|
||||
info_msg = (
|
||||
f"Error exporting keys for {client.user_id} to" f" {path} {e}"
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
client.user_id,
|
||||
"m.os_error",
|
||||
str(e)
|
||||
message.message_id, client.user_id, "m.os_error", str(e)
|
||||
)
|
||||
|
||||
else:
|
||||
info_msg = (f"Succesfully exported keys for {client.user_id} "
|
||||
f"to {path}")
|
||||
info_msg = (
|
||||
f"Succesfully exported keys for {client.user_id} " f"to {path}"
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
client.user_id,
|
||||
"m.ok",
|
||||
info_msg
|
||||
message.message_id, client.user_id, "m.ok", info_msg
|
||||
)
|
||||
|
||||
elif isinstance(message, ImportKeysMessage):
|
||||
@ -316,37 +328,32 @@ class ProxyDaemon:
|
||||
try:
|
||||
await client.import_keys(path, message.passphrase)
|
||||
except (OSError, EncryptionError) as e:
|
||||
info_msg = (f"Error importing keys for {client.user_id} "
|
||||
f"from {path} {e}")
|
||||
info_msg = (
|
||||
f"Error importing keys for {client.user_id} " f"from {path} {e}"
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
client.user_id,
|
||||
"m.os_error",
|
||||
str(e)
|
||||
message.message_id, client.user_id, "m.os_error", str(e)
|
||||
)
|
||||
else:
|
||||
info_msg = (f"Succesfully imported keys for {client.user_id} "
|
||||
f"from {path}")
|
||||
info_msg = (
|
||||
f"Succesfully imported keys for {client.user_id} " f"from {path}"
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
client.user_id,
|
||||
"m.ok",
|
||||
info_msg
|
||||
message.message_id, client.user_id, "m.ok", info_msg
|
||||
)
|
||||
|
||||
elif isinstance(message, UnverifiedResponse):
|
||||
client = self.pan_clients[message.pan_user]
|
||||
|
||||
if message.room_id not in client.send_decision_queues:
|
||||
msg = (f"No send request found for user {message.pan_user} "
|
||||
f"and room {message.room_id}.")
|
||||
msg = (
|
||||
f"No send request found for user {message.pan_user} "
|
||||
f"and room {message.room_id}."
|
||||
)
|
||||
await self.send_response(
|
||||
message.message_id,
|
||||
message.pan_user,
|
||||
"m.unknown_request",
|
||||
msg
|
||||
message.message_id, message.pan_user, "m.unknown_request", msg
|
||||
)
|
||||
return
|
||||
|
||||
@ -365,10 +372,7 @@ class ProxyDaemon:
|
||||
access_token = request.query.get("access_token", "")
|
||||
|
||||
if not access_token:
|
||||
access_token = request.headers.get(
|
||||
"Authorization",
|
||||
""
|
||||
).strip("Bearer ")
|
||||
access_token = request.headers.get("Authorization", "").strip("Bearer ")
|
||||
|
||||
return access_token
|
||||
|
||||
@ -404,7 +408,7 @@ class ProxyDaemon:
|
||||
params=None, # type: CIMultiDict
|
||||
data=None, # type: bytes
|
||||
session=None, # type: aiohttp.ClientSession
|
||||
token=None # type: str
|
||||
token=None, # type: str
|
||||
):
|
||||
# type: (...) -> aiohttp.ClientResponse
|
||||
"""Forward the given request to our configured homeserver.
|
||||
@ -454,16 +458,11 @@ class ProxyDaemon:
|
||||
params=params,
|
||||
headers=headers,
|
||||
proxy=self.proxy,
|
||||
ssl=self.ssl
|
||||
ssl=self.ssl,
|
||||
)
|
||||
|
||||
async def forward_to_web(
|
||||
self,
|
||||
request,
|
||||
params=None,
|
||||
data=None,
|
||||
session=None,
|
||||
token=None
|
||||
self, request, params=None, data=None, session=None, token=None
|
||||
):
|
||||
"""Forward the given request and convert the response to a Response.
|
||||
|
||||
@ -484,17 +483,13 @@ class ProxyDaemon:
|
||||
"""
|
||||
try:
|
||||
response = await self.forward_request(
|
||||
request,
|
||||
params=params,
|
||||
data=data,
|
||||
session=session,
|
||||
token=token
|
||||
request, params=params, data=data, session=session, token=token
|
||||
)
|
||||
return web.Response(
|
||||
status=response.status,
|
||||
content_type=response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.read()
|
||||
body=await response.read(),
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
return web.Response(status=500, text=str(e))
|
||||
@ -522,8 +517,10 @@ class ProxyDaemon:
|
||||
self.store.save_server_user(self.name, user_id)
|
||||
|
||||
if user_id in self.pan_clients:
|
||||
logger.info(f"Background sync client already exists for {user_id},"
|
||||
f" not starting new one")
|
||||
logger.info(
|
||||
f"Background sync client already exists for {user_id},"
|
||||
f" not starting new one"
|
||||
)
|
||||
return
|
||||
|
||||
pan_client = PanClient(
|
||||
@ -535,7 +532,7 @@ class ProxyDaemon:
|
||||
user_id,
|
||||
store_path=self.data_dir,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
proxy=self.proxy,
|
||||
)
|
||||
response = await pan_client.login(password, "pantalaimon")
|
||||
|
||||
@ -543,8 +540,7 @@ class ProxyDaemon:
|
||||
await pan_client.close()
|
||||
return
|
||||
|
||||
logger.info(f"Succesfully started new background sync client for "
|
||||
f"{user_id}")
|
||||
logger.info(f"Succesfully started new background sync client for " f"{user_id}")
|
||||
|
||||
await self.send_queue.put(UpdateUsersMessage())
|
||||
|
||||
@ -554,13 +550,11 @@ class ProxyDaemon:
|
||||
keyring.set_password(
|
||||
"pantalaimon",
|
||||
f"{user_id}-{pan_client.device_id}-token",
|
||||
pan_client.access_token
|
||||
pan_client.access_token,
|
||||
)
|
||||
else:
|
||||
self.store.save_access_token(
|
||||
user_id,
|
||||
pan_client.device_id,
|
||||
pan_client.access_token
|
||||
user_id, pan_client.device_id, pan_client.access_token
|
||||
)
|
||||
|
||||
pan_client.start_loop()
|
||||
@ -580,7 +574,7 @@ class ProxyDaemon:
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_NOT_JSON",
|
||||
"error": "Request did not contain valid JSON."
|
||||
"error": "Request did not contain valid JSON.",
|
||||
},
|
||||
status=500,
|
||||
)
|
||||
@ -605,25 +599,23 @@ class ProxyDaemon:
|
||||
access_token = json_response.get("access_token", None)
|
||||
|
||||
if user_id and access_token:
|
||||
logger.info(f"User: {user} succesfully logged in, starting "
|
||||
f"a background sync client.")
|
||||
await self.start_pan_client(access_token, user, user_id,
|
||||
password)
|
||||
logger.info(
|
||||
f"User: {user} succesfully logged in, starting "
|
||||
f"a background sync client."
|
||||
)
|
||||
await self.start_pan_client(access_token, user, user_id, password)
|
||||
|
||||
return web.Response(
|
||||
status=response.status,
|
||||
content_type=response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.read()
|
||||
body=await response.read(),
|
||||
)
|
||||
|
||||
@property
|
||||
def _missing_token(self):
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_MISSING_TOKEN",
|
||||
"error": "Missing access token."
|
||||
},
|
||||
{"errcode": "M_MISSING_TOKEN", "error": "Missing access token."},
|
||||
headers=CORS_HEADERS,
|
||||
status=401,
|
||||
)
|
||||
@ -631,10 +623,7 @@ class ProxyDaemon:
|
||||
@property
|
||||
def _unknown_token(self):
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_UNKNOWN_TOKEN",
|
||||
"error": "Unrecognised access token."
|
||||
},
|
||||
{"errcode": "M_UNKNOWN_TOKEN", "error": "Unrecognised access token."},
|
||||
headers=CORS_HEADERS,
|
||||
status=401,
|
||||
)
|
||||
@ -642,10 +631,7 @@ class ProxyDaemon:
|
||||
@property
|
||||
def _not_json(self):
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_NOT_JSON",
|
||||
"error": "Request did not contain valid JSON."
|
||||
},
|
||||
{"errcode": "M_NOT_JSON", "error": "Request did not contain valid JSON."},
|
||||
headers=CORS_HEADERS,
|
||||
status=400,
|
||||
)
|
||||
@ -660,23 +646,18 @@ class ProxyDaemon:
|
||||
while True:
|
||||
try:
|
||||
logger.info("Trying to decrypt sync")
|
||||
return decryption_method(
|
||||
body,
|
||||
ignore_failures=False
|
||||
)
|
||||
return decryption_method(body, ignore_failures=False)
|
||||
except EncryptionError:
|
||||
logger.info("Error decrypting sync, waiting for next pan "
|
||||
"sync")
|
||||
logger.info("Error decrypting sync, waiting for next pan " "sync")
|
||||
await client.synced.wait(),
|
||||
logger.info("Pan synced, retrying decryption.")
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
decrypt_loop(client, body),
|
||||
timeout=self.decryption_timeout)
|
||||
decrypt_loop(client, body), timeout=self.decryption_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.info("Decryption attempt timed out, decrypting with "
|
||||
"failures")
|
||||
logger.info("Decryption attempt timed out, decrypting with " "failures")
|
||||
return decryption_method(body, ignore_failures=True)
|
||||
|
||||
async def sync(self, request):
|
||||
@ -705,9 +686,7 @@ class ProxyDaemon:
|
||||
|
||||
try:
|
||||
response = await self.forward_request(
|
||||
request,
|
||||
params=query,
|
||||
token=client.access_token
|
||||
request, params=query, token=client.access_token
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
return web.Response(status=500, text=str(e))
|
||||
@ -718,9 +697,7 @@ class ProxyDaemon:
|
||||
json_response = await self.decrypt_body(client, json_response)
|
||||
|
||||
return web.json_response(
|
||||
json_response,
|
||||
headers=CORS_HEADERS,
|
||||
status=response.status,
|
||||
json_response, headers=CORS_HEADERS, status=response.status
|
||||
)
|
||||
except (JSONDecodeError, ContentTypeError):
|
||||
pass
|
||||
@ -729,7 +706,7 @@ class ProxyDaemon:
|
||||
status=response.status,
|
||||
content_type=response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.read()
|
||||
body=await response.read(),
|
||||
)
|
||||
|
||||
async def messages(self, request):
|
||||
@ -751,15 +728,11 @@ class ProxyDaemon:
|
||||
try:
|
||||
json_response = await response.json()
|
||||
json_response = await self.decrypt_body(
|
||||
client,
|
||||
json_response,
|
||||
sync=False
|
||||
client, json_response, sync=False
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
json_response,
|
||||
headers=CORS_HEADERS,
|
||||
status=response.status,
|
||||
json_response, headers=CORS_HEADERS, status=response.status
|
||||
)
|
||||
except (JSONDecodeError, ContentTypeError):
|
||||
pass
|
||||
@ -768,7 +741,7 @@ class ProxyDaemon:
|
||||
status=response.status,
|
||||
content_type=response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.read()
|
||||
body=await response.read(),
|
||||
)
|
||||
|
||||
async def send_message(self, request):
|
||||
@ -788,17 +761,11 @@ class ProxyDaemon:
|
||||
room = client.rooms[room_id]
|
||||
encrypt = room.encrypted
|
||||
except KeyError:
|
||||
return await self.forward_to_web(
|
||||
request,
|
||||
token=client.access_token
|
||||
)
|
||||
return await self.forward_to_web(request, token=client.access_token)
|
||||
|
||||
# The room isn't encrypted just forward the message.
|
||||
if not encrypt:
|
||||
return await self.forward_to_web(
|
||||
request,
|
||||
token=client.access_token
|
||||
)
|
||||
return await self.forward_to_web(request, token=client.access_token)
|
||||
|
||||
msgtype = request.match_info["event_type"]
|
||||
txnid = request.match_info["txnid"]
|
||||
@ -810,14 +777,15 @@ class ProxyDaemon:
|
||||
|
||||
async def _send(ignore_unverified=False):
|
||||
try:
|
||||
response = await client.room_send(room_id, msgtype, content,
|
||||
txnid, ignore_unverified)
|
||||
response = await client.room_send(
|
||||
room_id, msgtype, content, txnid, ignore_unverified
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
content_type=response.transport_response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.transport_response.read()
|
||||
body=await response.transport_response.read(),
|
||||
)
|
||||
except ClientConnectionError as e:
|
||||
return web.Response(status=500, text=str(e))
|
||||
@ -849,45 +817,40 @@ class ProxyDaemon:
|
||||
client.send_decision_queues[room_id] = queue
|
||||
|
||||
message = UnverifiedDevicesSignal(
|
||||
client.user_id,
|
||||
room_id,
|
||||
room.display_name
|
||||
client.user_id, room_id, room.display_name
|
||||
)
|
||||
|
||||
await self.send_queue.put(message)
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
queue.get(),
|
||||
self.unverified_send_timeout
|
||||
queue.get(), self.unverified_send_timeout
|
||||
)
|
||||
|
||||
if isinstance(response, CancelSendingMessage):
|
||||
# The send was canceled notify the client that sent the
|
||||
# request about this.
|
||||
info_msg = (f"Canceled message sending for room "
|
||||
f"{room.display_name} ({room_id}).")
|
||||
info_msg = (
|
||||
f"Canceled message sending for room "
|
||||
f"{room.display_name} ({room_id})."
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
response.message_id,
|
||||
client.user_id,
|
||||
"m.ok",
|
||||
info_msg
|
||||
response.message_id, client.user_id, "m.ok", info_msg
|
||||
)
|
||||
|
||||
return web.Response(status=503, text=str(e))
|
||||
|
||||
elif isinstance(response, SendAnywaysMessage):
|
||||
# We are sending and ignoring devices along the way.
|
||||
info_msg = (f"Ignoring unverified devices and sending "
|
||||
info_msg = (
|
||||
f"Ignoring unverified devices and sending "
|
||||
f"message to room "
|
||||
f"{room.display_name} ({room_id}).")
|
||||
f"{room.display_name} ({room_id})."
|
||||
)
|
||||
logger.info(info_msg)
|
||||
await self.send_response(
|
||||
response.message_id,
|
||||
client.user_id,
|
||||
"m.ok",
|
||||
info_msg
|
||||
response.message_id, client.user_id, "m.ok", info_msg
|
||||
)
|
||||
|
||||
ret = await _send(True)
|
||||
@ -900,10 +863,12 @@ class ProxyDaemon:
|
||||
|
||||
return web.Response(
|
||||
status=503,
|
||||
text=(f"Room contains unverified devices and no "
|
||||
text=(
|
||||
f"Room contains unverified devices and no "
|
||||
f"action was taken for "
|
||||
f"{self.unverified_send_timeout} seconds, "
|
||||
f"request timed out")
|
||||
f"request timed out"
|
||||
),
|
||||
)
|
||||
|
||||
finally:
|
||||
@ -922,10 +887,7 @@ class ProxyDaemon:
|
||||
|
||||
sanitized_content = self.sanitize_filter(content)
|
||||
|
||||
return await self.forward_to_web(
|
||||
request,
|
||||
data=json.dumps(sanitized_content)
|
||||
)
|
||||
return await self.forward_to_web(request, data=json.dumps(sanitized_content))
|
||||
|
||||
async def search_opts(self, request):
|
||||
return web.json_response({}, headers=CORS_HEADERS)
|
||||
@ -950,10 +912,7 @@ class ProxyDaemon:
|
||||
validate_json(content, SEARCH_TERMS_SCHEMA)
|
||||
except ValidationError:
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_BAD_JSON",
|
||||
"error": "Invalid search query"
|
||||
},
|
||||
{"errcode": "M_BAD_JSON", "error": "Invalid search query"},
|
||||
headers=CORS_HEADERS,
|
||||
status=400,
|
||||
)
|
||||
@ -982,21 +941,14 @@ class ProxyDaemon:
|
||||
result = await client.search(content)
|
||||
except (InvalidOrderByError, InvalidLimit, InvalidQueryError) as e:
|
||||
return web.json_response(
|
||||
{
|
||||
"errcode": "M_INVALID_PARAM",
|
||||
"error": str(e)
|
||||
},
|
||||
{"errcode": "M_INVALID_PARAM", "error": str(e)},
|
||||
headers=CORS_HEADERS,
|
||||
status=400,
|
||||
)
|
||||
except UnknownRoomError:
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
return web.json_response(
|
||||
result,
|
||||
headers=CORS_HEADERS,
|
||||
status=200
|
||||
)
|
||||
return web.json_response(result, headers=CORS_HEADERS, status=200)
|
||||
|
||||
async def shutdown(self, _):
|
||||
"""Shut the daemon down closing all the client sessions it has.
|
||||
|
@ -21,10 +21,14 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
import tantivy
|
||||
from nio import (RoomEncryptedMedia, RoomMessageMedia, RoomMessageText,
|
||||
RoomNameEvent, RoomTopicEvent)
|
||||
from peewee import (SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
from nio import (
|
||||
RoomEncryptedMedia,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomNameEvent,
|
||||
RoomTopicEvent,
|
||||
)
|
||||
from peewee import SQL, DateTimeField, ForeignKeyField, Model, SqliteDatabase, TextField
|
||||
|
||||
from pantalaimon.store import use_database
|
||||
|
||||
@ -65,22 +69,15 @@ class Event(Model):
|
||||
|
||||
source = DictField()
|
||||
|
||||
profile = ForeignKeyField(
|
||||
model=Profile,
|
||||
column_name="profile_id",
|
||||
)
|
||||
profile = ForeignKeyField(model=Profile, column_name="profile_id")
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(event_id, room_id, sender, profile_id)")]
|
||||
|
||||
|
||||
class UserMessages(Model):
|
||||
user = ForeignKeyField(
|
||||
model=StoreUser,
|
||||
column_name="user_id")
|
||||
event = ForeignKeyField(
|
||||
model=Event,
|
||||
column_name="event_id")
|
||||
user = ForeignKeyField(model=StoreUser, column_name="user_id")
|
||||
event = ForeignKeyField(model=Event, column_name="event_id")
|
||||
|
||||
|
||||
@attr.s
|
||||
@ -91,17 +88,11 @@ class MessageStore:
|
||||
database = attr.ib(type=SqliteDatabase, init=False)
|
||||
database_path = attr.ib(type=str, init=False)
|
||||
|
||||
models = [
|
||||
StoreUser,
|
||||
Event,
|
||||
Profile,
|
||||
UserMessages
|
||||
]
|
||||
models = [StoreUser, Event, Profile, UserMessages]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
os.path.abspath(self.store_path),
|
||||
self.database_name
|
||||
os.path.abspath(self.store_path), self.database_name
|
||||
)
|
||||
|
||||
self.database = self._create_database()
|
||||
@ -112,21 +103,22 @@ class MessageStore:
|
||||
|
||||
def _create_database(self):
|
||||
return SqliteDatabase(
|
||||
self.database_path,
|
||||
pragmas={
|
||||
"foreign_keys": 1,
|
||||
"secure_delete": 1,
|
||||
}
|
||||
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
|
||||
)
|
||||
|
||||
@use_database
|
||||
def event_in_store(self, event_id, room_id):
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.room_id == room_id) &
|
||||
(Event.event_id == event_id) &
|
||||
(UserMessages.user == user)
|
||||
).execute()
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.room_id == room_id)
|
||||
& (Event.event_id == event_id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
for _ in query:
|
||||
return True
|
||||
@ -137,32 +129,29 @@ class MessageStore:
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
|
||||
profile_id, _ = Profile.get_or_create(
|
||||
user_id=event.sender,
|
||||
display_name=display_name,
|
||||
avatar_url=avatar_url
|
||||
user_id=event.sender, display_name=display_name, avatar_url=avatar_url
|
||||
)
|
||||
|
||||
event_source = event.source
|
||||
event_source["room_id"] = room_id
|
||||
|
||||
event_id = Event.insert(
|
||||
event_id = (
|
||||
Event.insert(
|
||||
event_id=event.event_id,
|
||||
sender=event.sender,
|
||||
date=datetime.datetime.fromtimestamp(
|
||||
event.server_timestamp / 1000
|
||||
),
|
||||
date=datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
room_id=room_id,
|
||||
source=event_source,
|
||||
profile=profile_id
|
||||
).on_conflict_ignore().execute()
|
||||
profile=profile_id,
|
||||
)
|
||||
.on_conflict_ignore()
|
||||
.execute()
|
||||
)
|
||||
|
||||
if event_id <= 0:
|
||||
return None
|
||||
|
||||
_, created = UserMessages.get_or_create(
|
||||
user=user,
|
||||
event=event_id,
|
||||
)
|
||||
_, created = UserMessages.get_or_create(user=user, event=event_id)
|
||||
|
||||
if created:
|
||||
return event_id
|
||||
@ -173,24 +162,36 @@ class MessageStore:
|
||||
context = {}
|
||||
|
||||
if before > 0:
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.date <= event.date) &
|
||||
(Event.room_id == event.room_id) &
|
||||
(Event.id != event.id) &
|
||||
(UserMessages.user == user)
|
||||
).order_by(Event.date.desc()).limit(before)
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date <= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.order_by(Event.date.desc())
|
||||
.limit(before)
|
||||
)
|
||||
|
||||
context["events_before"] = [e.source for e in query]
|
||||
else:
|
||||
context["events_before"] = []
|
||||
|
||||
if after > 0:
|
||||
query = Event.select().join(UserMessages).where(
|
||||
(Event.date >= event.date) &
|
||||
(Event.room_id == event.room_id) &
|
||||
(Event.id != event.id) &
|
||||
(UserMessages.user == user)
|
||||
).order_by(Event.date).limit(after)
|
||||
query = (
|
||||
Event.select()
|
||||
.join(UserMessages)
|
||||
.where(
|
||||
(Event.date >= event.date)
|
||||
& (Event.room_id == event.room_id)
|
||||
& (Event.id != event.id)
|
||||
& (UserMessages.user == user)
|
||||
)
|
||||
.order_by(Event.date)
|
||||
.limit(after)
|
||||
)
|
||||
|
||||
context["events_after"] = [e.source for e in query]
|
||||
else:
|
||||
@ -205,7 +206,7 @@ class MessageStore:
|
||||
include_profile=False, # type: bool
|
||||
order_by_recent=False, # type: bool
|
||||
before=0, # type: int
|
||||
after=0 # type: int
|
||||
after=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
user, _ = StoreUser.get_or_create(user_id=self.user)
|
||||
@ -213,13 +214,13 @@ class MessageStore:
|
||||
search_dict = {r[1]: r[0] for r in search_result}
|
||||
columns = list(search_dict.keys())
|
||||
|
||||
result_dict = {
|
||||
"results": []
|
||||
}
|
||||
result_dict = {"results": []}
|
||||
|
||||
query = UserMessages.select().where(
|
||||
(UserMessages.user_id == user) & (UserMessages.event.in_(columns))
|
||||
).execute()
|
||||
query = (
|
||||
UserMessages.select()
|
||||
.where((UserMessages.user_id == user) & (UserMessages.event.in_(columns)))
|
||||
.execute()
|
||||
)
|
||||
|
||||
for message in query:
|
||||
|
||||
@ -228,7 +229,7 @@ class MessageStore:
|
||||
event_dict = {
|
||||
"rank": 1 if order_by_recent else search_dict[event.id],
|
||||
"result": event.source,
|
||||
"context": {}
|
||||
"context": {},
|
||||
}
|
||||
|
||||
if include_profile:
|
||||
@ -256,8 +257,17 @@ def sanitize_room_id(room_id):
|
||||
|
||||
|
||||
class Searcher:
|
||||
def __init__(self, index, body_field, name_field, topic_field,
|
||||
column_field, room_field, timestamp_field, searcher):
|
||||
def __init__(
|
||||
self,
|
||||
index,
|
||||
body_field,
|
||||
name_field,
|
||||
topic_field,
|
||||
column_field,
|
||||
room_field,
|
||||
timestamp_field,
|
||||
searcher,
|
||||
):
|
||||
self._index = index
|
||||
self._searcher = searcher
|
||||
|
||||
@ -268,8 +278,7 @@ class Searcher:
|
||||
self.room_field = room_field
|
||||
self.timestamp_field = timestamp_field
|
||||
|
||||
def search(self, search_term, room=None, max_results=10,
|
||||
order_by_recent=False):
|
||||
def search(self, search_term, room=None, max_results=10, order_by_recent=False):
|
||||
# type (str, str, int, bool) -> List[int, int]
|
||||
"""Search for events in the index.
|
||||
|
||||
@ -277,21 +286,13 @@ class Searcher:
|
||||
"""
|
||||
queryparser = tantivy.QueryParser.for_index(
|
||||
self._index,
|
||||
[
|
||||
self.body_field,
|
||||
self.name_field,
|
||||
self.topic_field,
|
||||
self.room_field
|
||||
]
|
||||
[self.body_field, self.name_field, self.topic_field, self.room_field],
|
||||
)
|
||||
|
||||
# This currently supports only a single room since the query parser
|
||||
# doesn't seem to work with multiple room fields here.
|
||||
if room:
|
||||
query_string = "{} AND room:{}".format(
|
||||
search_term,
|
||||
sanitize_room_id(room)
|
||||
)
|
||||
query_string = "{} AND room:{}".format(search_term, sanitize_room_id(room))
|
||||
else:
|
||||
query_string = search_term
|
||||
|
||||
@ -301,8 +302,9 @@ class Searcher:
|
||||
raise InvalidQueryError(f"Invalid search term: {search_term}")
|
||||
|
||||
if order_by_recent:
|
||||
collector = tantivy.TopDocs(max_results,
|
||||
order_by_field=self.timestamp_field)
|
||||
collector = tantivy.TopDocs(
|
||||
max_results, order_by_field=self.timestamp_field
|
||||
)
|
||||
else:
|
||||
collector = tantivy.TopDocs(max_results)
|
||||
|
||||
@ -329,16 +331,11 @@ class Index:
|
||||
self.timestamp_field = schema_builder.add_unsigned_field(
|
||||
"server_timestamp", fast="single"
|
||||
)
|
||||
self.date_field = schema_builder.add_date_field(
|
||||
"message_date"
|
||||
)
|
||||
self.date_field = schema_builder.add_date_field("message_date")
|
||||
self.room_field = schema_builder.add_facet_field("room")
|
||||
|
||||
self.column_field = schema_builder.add_unsigned_field(
|
||||
"database_column",
|
||||
indexed=True,
|
||||
stored=True,
|
||||
fast="single"
|
||||
"database_column", indexed=True, stored=True, fast="single"
|
||||
)
|
||||
|
||||
schema = schema_builder.build()
|
||||
@ -359,7 +356,7 @@ class Index:
|
||||
doc.add_facet(self.room_field, room_facet)
|
||||
doc.add_date(
|
||||
self.date_field,
|
||||
datetime.datetime.fromtimestamp(event.server_timestamp / 1000)
|
||||
datetime.datetime.fromtimestamp(event.server_timestamp / 1000),
|
||||
)
|
||||
doc.add_unsigned(self.timestamp_field, event.server_timestamp)
|
||||
|
||||
@ -389,7 +386,7 @@ class Index:
|
||||
self.column_field,
|
||||
self.room_field,
|
||||
self.timestamp_field,
|
||||
self.reader.searcher()
|
||||
self.reader.searcher(),
|
||||
)
|
||||
|
||||
|
||||
@ -430,10 +427,7 @@ class IndexStore:
|
||||
with store.database.bind_ctx(store.models):
|
||||
with store.database.atomic():
|
||||
for item in event_queue:
|
||||
column_id = store.save_event(
|
||||
item.event,
|
||||
item.room_id,
|
||||
)
|
||||
column_id = store.save_event(item.event, item.room_id)
|
||||
|
||||
if column_id:
|
||||
index.add_event(column_id, item.event, item.room_id)
|
||||
@ -451,10 +445,7 @@ class IndexStore:
|
||||
|
||||
async with self.write_lock:
|
||||
write_func = partial(
|
||||
IndexStore.write_events,
|
||||
self.store,
|
||||
self.index,
|
||||
event_queue
|
||||
IndexStore.write_events, self.store, self.index, event_queue
|
||||
)
|
||||
await loop.run_in_executor(None, write_func)
|
||||
|
||||
@ -469,7 +460,7 @@ class IndexStore:
|
||||
order_by_recent=False, # type: bool
|
||||
include_profile=False, # type: bool
|
||||
before_limit=0, # type: int
|
||||
after_limit=0 # type: int
|
||||
after_limit=0, # type: int
|
||||
):
|
||||
# type: (...) -> Dict[Any, Any]
|
||||
"""Search the indexstore for an event."""
|
||||
@ -480,9 +471,13 @@ class IndexStore:
|
||||
# the number of CPUs and the semaphore has the same counter value.
|
||||
async with self.read_semaphore:
|
||||
searcher = self.index.searcher()
|
||||
search_func = partial(searcher.search, search_term, room=room,
|
||||
search_func = partial(
|
||||
searcher.search,
|
||||
search_term,
|
||||
room=room,
|
||||
max_results=max_results,
|
||||
order_by_recent=order_by_recent)
|
||||
order_by_recent=order_by_recent,
|
||||
)
|
||||
|
||||
result = await loop.run_in_executor(None, search_func)
|
||||
|
||||
@ -492,7 +487,7 @@ class IndexStore:
|
||||
include_profile,
|
||||
order_by_recent,
|
||||
before_limit,
|
||||
after_limit
|
||||
after_limit,
|
||||
)
|
||||
|
||||
search_result = await loop.run_in_executor(None, load_event_func)
|
||||
|
@ -53,40 +53,39 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
|
||||
send_queue=send_queue,
|
||||
recv_queue=recv_queue,
|
||||
proxy=server_conf.proxy.geturl() if server_conf.proxy else None,
|
||||
ssl=None if server_conf.ssl is True else False
|
||||
ssl=None if server_conf.ssl is True else False,
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
|
||||
app.add_routes([
|
||||
app.add_routes(
|
||||
[
|
||||
web.post("/_matrix/client/r0/login", proxy.login),
|
||||
web.get("/_matrix/client/r0/sync", proxy.sync),
|
||||
web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
|
||||
web.put(
|
||||
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
|
||||
proxy.send_message
|
||||
proxy.send_message,
|
||||
),
|
||||
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
|
||||
web.post("/_matrix/client/r0/search", proxy.search),
|
||||
web.options("/_matrix/client/r0/search", proxy.search_opts),
|
||||
])
|
||||
]
|
||||
)
|
||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||
app.on_shutdown.append(proxy.shutdown)
|
||||
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
|
||||
site = web.TCPSite(
|
||||
runner,
|
||||
str(server_conf.listen_address),
|
||||
server_conf.listen_port
|
||||
)
|
||||
site = web.TCPSite(runner, str(server_conf.listen_address), server_conf.listen_port)
|
||||
|
||||
return proxy, runner, site
|
||||
|
||||
|
||||
async def message_router(receive_queue, send_queue, proxies):
|
||||
"""Find the recipient of a message and forward it to the right proxy."""
|
||||
|
||||
def find_proxy_by_user(user):
|
||||
# type: (str) -> Optional[ProxyDaemon]
|
||||
for proxy in proxies:
|
||||
@ -109,35 +108,28 @@ async def message_router(receive_queue, send_queue, proxies):
|
||||
msg = f"No pan client found for {message.pan_user}."
|
||||
logger.warn(msg)
|
||||
await send_info(
|
||||
message.message_id,
|
||||
message.pan_user,
|
||||
"m.unknown_client",
|
||||
msg
|
||||
message.message_id, message.pan_user, "m.unknown_client", msg
|
||||
)
|
||||
|
||||
await proxy.receive_message(message)
|
||||
|
||||
|
||||
@click.command(
|
||||
help=("pantalaimon is a reverse proxy for matrix homeservers that "
|
||||
help=(
|
||||
"pantalaimon is a reverse proxy for matrix homeservers that "
|
||||
"transparently encrypts and decrypts messages for clients that "
|
||||
"connect to pantalaimon.")
|
||||
|
||||
"connect to pantalaimon."
|
||||
)
|
||||
)
|
||||
@click.version_option(version="0.1", prog_name="pantalaimon")
|
||||
@click.option("--log-level", type=click.Choice([
|
||||
"error",
|
||||
"warning",
|
||||
"info",
|
||||
"debug"
|
||||
]), default=None)
|
||||
@click.option(
|
||||
"--log-level",
|
||||
type=click.Choice(["error", "warning", "info", "debug"]),
|
||||
default=None,
|
||||
)
|
||||
@click.option("-c", "--config", type=click.Path(exists=True))
|
||||
@click.pass_context
|
||||
def main(
|
||||
context,
|
||||
log_level,
|
||||
config
|
||||
):
|
||||
def main(context, log_level, config):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
conf_dir = user_config_dir("pantalaimon", "")
|
||||
@ -172,12 +164,7 @@ def main(
|
||||
|
||||
for server_conf in pan_conf.servers.values():
|
||||
proxy, runner, site = loop.run_until_complete(
|
||||
init(
|
||||
data_dir,
|
||||
server_conf,
|
||||
pan_queue.async_q,
|
||||
ui_queue.async_q,
|
||||
)
|
||||
init(data_dir, server_conf, pan_queue.async_q, ui_queue.async_q)
|
||||
)
|
||||
servers.append((proxy, runner, site))
|
||||
proxies.append(proxy)
|
||||
@ -185,14 +172,12 @@ def main(
|
||||
except keyring.errors.KeyringError as e:
|
||||
context.fail(f"Error initializing keyring: {e}")
|
||||
|
||||
glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir,
|
||||
pan_conf.servers.values(), pan_conf)
|
||||
|
||||
glib_fut = loop.run_in_executor(
|
||||
None,
|
||||
glib_thread.run
|
||||
glib_thread = GlibT(
|
||||
pan_queue.sync_q, ui_queue.sync_q, data_dir, pan_conf.servers.values(), pan_conf
|
||||
)
|
||||
|
||||
glib_fut = loop.run_in_executor(None, glib_thread.run)
|
||||
|
||||
async def wait_for_glib(glib_thread, fut):
|
||||
glib_thread.stop()
|
||||
await fut
|
||||
@ -211,8 +196,10 @@ def main(
|
||||
|
||||
try:
|
||||
for proxy, _, site in servers:
|
||||
click.echo(f"======== Starting daemon for homeserver "
|
||||
f"{proxy.name} on {site.name} ========")
|
||||
click.echo(
|
||||
f"======== Starting daemon for homeserver "
|
||||
f"{proxy.name} on {site.name} ========"
|
||||
)
|
||||
loop.run_until_complete(site.start())
|
||||
|
||||
click.echo("(Press CTRL+C to quit)")
|
||||
|
@ -43,15 +43,12 @@ class PanctlArgParse(argparse.ArgumentParser):
|
||||
pass
|
||||
|
||||
def error(self, message):
|
||||
message = (
|
||||
f"Error: {message} "
|
||||
f"(see help)"
|
||||
)
|
||||
message = f"Error: {message} " f"(see help)"
|
||||
print(message)
|
||||
raise ParseError
|
||||
|
||||
|
||||
class PanctlParser():
|
||||
class PanctlParser:
|
||||
def __init__(self, commands):
|
||||
self.commands = commands
|
||||
self.parser = PanctlArgParse()
|
||||
@ -194,19 +191,13 @@ class PanCompleter(Completer):
|
||||
return ""
|
||||
|
||||
def complete_key_file_cmds(
|
||||
self,
|
||||
document,
|
||||
complete_event,
|
||||
command,
|
||||
last_word,
|
||||
words
|
||||
self, document, complete_event, command, last_word, words
|
||||
):
|
||||
if len(words) == 2:
|
||||
return self.complete_pan_users(last_word)
|
||||
elif len(words) == 3:
|
||||
return self.path_completer.get_completions(
|
||||
Document(last_word),
|
||||
complete_event
|
||||
Document(last_word), complete_event
|
||||
)
|
||||
|
||||
return ""
|
||||
@ -264,16 +255,9 @@ class PanCompleter(Completer):
|
||||
]:
|
||||
return self.complete_verification(command, last_word, words)
|
||||
|
||||
elif command in [
|
||||
"export-keys",
|
||||
"import-keys",
|
||||
]:
|
||||
elif command in ["export-keys", "import-keys"]:
|
||||
return self.complete_key_file_cmds(
|
||||
document,
|
||||
complete_event,
|
||||
command,
|
||||
last_word,
|
||||
words
|
||||
document, complete_event, command, last_word, words
|
||||
)
|
||||
|
||||
elif command in ["send-anyways", "cancel-sending"]:
|
||||
@ -300,7 +284,7 @@ def grouper(iterable, n, fillvalue=None):
|
||||
|
||||
def partition_key(key):
|
||||
groups = grouper(key, 4, " ")
|
||||
return ' '.join(''.join(g) for g in groups)
|
||||
return " ".join("".join(g) for g in groups)
|
||||
|
||||
|
||||
def get_color(string):
|
||||
@ -330,37 +314,59 @@ class PanCtl:
|
||||
|
||||
command_help = {
|
||||
"help": "Display help about commands.",
|
||||
"list-servers": ("List the configured homeservers and pan users on "
|
||||
"each homeserver."),
|
||||
"list-devices": ("List the devices of a user that are known to the "
|
||||
"pan-user."),
|
||||
"start-verification": ("Start an interactive key verification between "
|
||||
"the given pan-user and user."),
|
||||
"accept-verification": ("Accept an interactive key verification that "
|
||||
"list-servers": (
|
||||
"List the configured homeservers and pan users on " "each homeserver."
|
||||
),
|
||||
"list-devices": (
|
||||
"List the devices of a user that are known to the " "pan-user."
|
||||
),
|
||||
"start-verification": (
|
||||
"Start an interactive key verification between "
|
||||
"the given pan-user and user."
|
||||
),
|
||||
"accept-verification": (
|
||||
"Accept an interactive key verification that "
|
||||
"the given user has started with our given "
|
||||
"pan-user."),
|
||||
"cancel-verification": ("Cancel an interactive key verification "
|
||||
"between the given pan-user and user."),
|
||||
"confirm-verification": ("Confirm that the short authentication "
|
||||
"pan-user."
|
||||
),
|
||||
"cancel-verification": (
|
||||
"Cancel an interactive key verification "
|
||||
"between the given pan-user and user."
|
||||
),
|
||||
"confirm-verification": (
|
||||
"Confirm that the short authentication "
|
||||
"string of the interactive key verification "
|
||||
"with the given pan-user and user is "
|
||||
"matching."),
|
||||
"matching."
|
||||
),
|
||||
"verify-device": ("Manually mark the given device as verified."),
|
||||
"unverify-device": ("Mark a previously verified device of the given "
|
||||
"user as unverified."),
|
||||
"blacklist-device": ("Manually mark the given device of the given "
|
||||
"user as blacklisted."),
|
||||
"unblacklist-device": ("Mark a previously blacklisted device of the "
|
||||
"given user as unblacklisted."),
|
||||
"send-anyways": ("Send a room message despite having unverified "
|
||||
"unverify-device": (
|
||||
"Mark a previously verified device of the given " "user as unverified."
|
||||
),
|
||||
"blacklist-device": (
|
||||
"Manually mark the given device of the given " "user as blacklisted."
|
||||
),
|
||||
"unblacklist-device": (
|
||||
"Mark a previously blacklisted device of the "
|
||||
"given user as unblacklisted."
|
||||
),
|
||||
"send-anyways": (
|
||||
"Send a room message despite having unverified "
|
||||
"devices in the room and mark the devices as "
|
||||
"ignored."),
|
||||
"cancel-sending": ("Cancel the send of a room message in a room that "
|
||||
"contains unverified devices"),
|
||||
"import-keys": ("Import end-to-end encryption keys from the given "
|
||||
"file for the given pan-user."),
|
||||
"export-keys": ("Export end-to-end encryption keys to the given file "
|
||||
"for the given pan-user."),
|
||||
"ignored."
|
||||
),
|
||||
"cancel-sending": (
|
||||
"Cancel the send of a room message in a room that "
|
||||
"contains unverified devices"
|
||||
),
|
||||
"import-keys": (
|
||||
"Import end-to-end encryption keys from the given "
|
||||
"file for the given pan-user."
|
||||
),
|
||||
"export-keys": (
|
||||
"Export end-to-end encryption keys to the given file "
|
||||
"for the given pan-user."
|
||||
),
|
||||
}
|
||||
|
||||
commands = [
|
||||
@ -404,10 +410,12 @@ class PanCtl:
|
||||
|
||||
def unverified_devices(self, pan_user, room_id, display_name):
|
||||
self.completer.rooms[pan_user].add(room_id)
|
||||
print(f"Error sending message for user {pan_user}, "
|
||||
print(
|
||||
f"Error sending message for user {pan_user}, "
|
||||
f"there are unverified devices in the room {display_name} "
|
||||
f"({room_id}).\nUse the send-anyways or cancel-sending commands "
|
||||
f"to ignore the devices or cancel the sending.")
|
||||
f"to ignore the devices or cancel the sending."
|
||||
)
|
||||
|
||||
def show_response(self, response_id, pan_user, message):
|
||||
if response_id not in self.own_message_ids:
|
||||
@ -418,14 +426,18 @@ class PanCtl:
|
||||
print(message["message"])
|
||||
|
||||
def sas_done(self, pan_user, user_id, device_id, _):
|
||||
print(f"Device {device_id} of user {user_id}"
|
||||
f" succesfully verified for pan user {pan_user}.")
|
||||
print(
|
||||
f"Device {device_id} of user {user_id}"
|
||||
f" succesfully verified for pan user {pan_user}."
|
||||
)
|
||||
|
||||
def show_sas_invite(self, pan_user, user_id, device_id, _):
|
||||
print(f"{user_id} has started an interactive device "
|
||||
print(
|
||||
f"{user_id} has started an interactive device "
|
||||
f"verification for his device {device_id} with pan user "
|
||||
f"{pan_user}\n"
|
||||
f"Accept the invitation with the accept-verification command.")
|
||||
f"Accept the invitation with the accept-verification command."
|
||||
)
|
||||
|
||||
# The emoji printing logic was taken from weechat-matrix and was written by
|
||||
# dkasak.
|
||||
@ -443,33 +455,26 @@ class PanCtl:
|
||||
# that they are rendered with coloured glyphs. For these, we
|
||||
# need to add an extra space after them so that they are
|
||||
# rendered properly in weechat.
|
||||
variation_selector_emojis = [
|
||||
'☁️',
|
||||
'❤️',
|
||||
'☂️',
|
||||
'✏️',
|
||||
'✂️',
|
||||
'☎️',
|
||||
'✈️'
|
||||
]
|
||||
variation_selector_emojis = ["☁️", "❤️", "☂️", "✏️", "✂️", "☎️", "✈️"]
|
||||
|
||||
if emoji in variation_selector_emojis:
|
||||
emoji += " "
|
||||
|
||||
# This is a trick to account for the fact that emojis are wider
|
||||
# than other monospace characters.
|
||||
placeholder = '.' * emoji_width
|
||||
placeholder = "." * emoji_width
|
||||
|
||||
return placeholder.center(width).replace(placeholder, emoji)
|
||||
|
||||
emoji_str = u"".join(center_emoji(e, centered_width)
|
||||
for e in emojis)
|
||||
desc = u"".join(d.center(centered_width) for d in descriptions)
|
||||
short_string = u"\n".join([emoji_str, desc])
|
||||
emoji_str = "".join(center_emoji(e, centered_width) for e in emojis)
|
||||
desc = "".join(d.center(centered_width) for d in descriptions)
|
||||
short_string = "\n".join([emoji_str, desc])
|
||||
|
||||
print(f"Short authentication string for pan "
|
||||
print(
|
||||
f"Short authentication string for pan "
|
||||
f"user {pan_user} from {user_id} via "
|
||||
f"{device_id}:\n{short_string}")
|
||||
f"{device_id}:\n{short_string}"
|
||||
)
|
||||
|
||||
def list_servers(self):
|
||||
"""List the daemons users."""
|
||||
@ -480,9 +485,7 @@ class PanCtl:
|
||||
for server, server_users in servers.items():
|
||||
server_c = get_color(server)
|
||||
|
||||
print_formatted_text(HTML(
|
||||
f" - Name: <{server_c}>{server}</{server_c}>"
|
||||
))
|
||||
print_formatted_text(HTML(f" - Name: <{server_c}>{server}</{server_c}>"))
|
||||
|
||||
user_list = []
|
||||
|
||||
@ -490,8 +493,10 @@ class PanCtl:
|
||||
user_c = get_color(user)
|
||||
device_c = get_color(device)
|
||||
|
||||
user_list.append(f" - <{user_c}>{user}</{user_c}> "
|
||||
f"<{device_c}>{device}</{device_c}>")
|
||||
user_list.append(
|
||||
f" - <{user_c}>{user}</{user_c}> "
|
||||
f"<{device_c}>{device}</{device_c}>"
|
||||
)
|
||||
|
||||
if user_list:
|
||||
print(f" - Pan users:")
|
||||
@ -501,9 +506,7 @@ class PanCtl:
|
||||
def list_devices(self, args):
|
||||
devices = self.devices.ListUserDevices(args.pan_user, args.user_id)
|
||||
|
||||
print_formatted_text(
|
||||
HTML(f"Devices for user <b>{args.user_id}</b>:")
|
||||
)
|
||||
print_formatted_text(HTML(f"Devices for user <b>{args.user_id}</b>:"))
|
||||
|
||||
for device in devices:
|
||||
if device["trust_state"] == "verified":
|
||||
@ -517,7 +520,8 @@ class PanCtl:
|
||||
|
||||
key = partition_key(device["ed25519"])
|
||||
color = get_color(device["device_id"])
|
||||
print_formatted_text(HTML(
|
||||
print_formatted_text(
|
||||
HTML(
|
||||
f" - Display name: "
|
||||
f"{device['device_display_name']}\n"
|
||||
f" - Device id: "
|
||||
@ -526,7 +530,8 @@ class PanCtl:
|
||||
f"<ansiyellow>{key}</ansiyellow>\n"
|
||||
f" - Trust state: "
|
||||
f"{trust_state}"
|
||||
))
|
||||
)
|
||||
)
|
||||
|
||||
async def loop(self):
|
||||
"""Event loop for panctl."""
|
||||
@ -559,105 +564,83 @@ class PanCtl:
|
||||
|
||||
elif command == "import-keys":
|
||||
self.own_message_ids.append(
|
||||
self.ctl.ImportKeys(
|
||||
args.pan_user,
|
||||
args.path,
|
||||
args.passphrase
|
||||
))
|
||||
self.ctl.ImportKeys(args.pan_user, args.path, args.passphrase)
|
||||
)
|
||||
|
||||
elif command == "export-keys":
|
||||
self.own_message_ids.append(
|
||||
self.ctl.ExportKeys(
|
||||
args.pan_user,
|
||||
args.path,
|
||||
args.passphrase
|
||||
))
|
||||
self.ctl.ExportKeys(args.pan_user, args.path, args.passphrase)
|
||||
)
|
||||
|
||||
elif command == "send-anyways":
|
||||
self.own_message_ids.append(
|
||||
self.ctl.SendAnyways(
|
||||
args.pan_user,
|
||||
args.room_id,
|
||||
))
|
||||
self.ctl.SendAnyways(args.pan_user, args.room_id)
|
||||
)
|
||||
|
||||
elif command == "cancel-sending":
|
||||
self.own_message_ids.append(
|
||||
self.ctl.CancelSending(
|
||||
args.pan_user,
|
||||
args.room_id,
|
||||
))
|
||||
self.ctl.CancelSending(args.pan_user, args.room_id)
|
||||
)
|
||||
|
||||
elif command == "list-devices":
|
||||
self.list_devices(args)
|
||||
|
||||
elif command == "verify-device":
|
||||
self.own_message_ids.append(
|
||||
self.devices.Verify(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
self.devices.Verify(args.pan_user, args.user_id, args.device_id)
|
||||
)
|
||||
|
||||
elif command == "unverify-device":
|
||||
self.own_message_ids.append(
|
||||
self.devices.Unverify(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
self.devices.Unverify(args.pan_user, args.user_id, args.device_id)
|
||||
)
|
||||
|
||||
elif command == "blacklist-device":
|
||||
self.own_message_ids.append(
|
||||
self.devices.Blacklist(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
self.devices.Blacklist(args.pan_user, args.user_id, args.device_id)
|
||||
)
|
||||
|
||||
elif command == "unblacklist-device":
|
||||
self.own_message_ids.append(
|
||||
self.devices.Unblacklist(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
args.pan_user, args.user_id, args.device_id
|
||||
)
|
||||
)
|
||||
|
||||
elif command == "start-verification":
|
||||
self.own_message_ids.append(
|
||||
self.devices.StartKeyVerification(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
args.pan_user, args.user_id, args.device_id
|
||||
)
|
||||
)
|
||||
|
||||
elif command == "cancel-verification":
|
||||
self.own_message_ids.append(
|
||||
self.devices.CancelKeyVerification(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
args.pan_user, args.user_id, args.device_id
|
||||
)
|
||||
)
|
||||
|
||||
elif command == "accept-verification":
|
||||
self.own_message_ids.append(
|
||||
self.devices.AcceptKeyVerification(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
args.pan_user, args.user_id, args.device_id
|
||||
)
|
||||
)
|
||||
|
||||
elif command == "confirm-verification":
|
||||
self.own_message_ids.append(
|
||||
self.devices.ConfirmKeyVerification(
|
||||
args.pan_user,
|
||||
args.user_id,
|
||||
args.device_id
|
||||
))
|
||||
args.pan_user, args.user_id, args.device_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@click.command(
|
||||
help=("panctl is a small interactive repl to introspect and control"
|
||||
"the pantalaimon daemon.")
|
||||
help=(
|
||||
"panctl is a small interactive repl to introspect and control"
|
||||
"the pantalaimon daemon."
|
||||
)
|
||||
)
|
||||
@click.version_option(version="0.1", prog_name="panctl")
|
||||
def main():
|
||||
@ -670,10 +653,7 @@ def main():
|
||||
print(f"Error, {e}")
|
||||
sys.exit(-1)
|
||||
|
||||
fut = loop.run_in_executor(
|
||||
None,
|
||||
glib_loop.run
|
||||
)
|
||||
fut = loop.run_in_executor(None, glib_loop.run)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(panctl.loop())
|
||||
@ -684,5 +664,5 @@ def main():
|
||||
loop.run_until_complete(fut)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -18,10 +18,8 @@ from collections import defaultdict
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
|
||||
use_database)
|
||||
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
from nio.store import Accounts, DeviceKeys, DeviceTrustState, TrustState, use_database
|
||||
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
|
||||
|
||||
|
||||
@attr.s
|
||||
@ -41,10 +39,7 @@ class DictField(TextField):
|
||||
class AccessTokens(Model):
|
||||
token = TextField()
|
||||
account = ForeignKeyField(
|
||||
model=Accounts,
|
||||
primary_key=True,
|
||||
backref="access_token",
|
||||
on_delete="CASCADE"
|
||||
model=Accounts, primary_key=True, backref="access_token", on_delete="CASCADE"
|
||||
)
|
||||
|
||||
|
||||
@ -58,10 +53,7 @@ class Servers(Model):
|
||||
class ServerUsers(Model):
|
||||
user_id = TextField()
|
||||
server = ForeignKeyField(
|
||||
model=Servers,
|
||||
column_name="server_id",
|
||||
backref="users",
|
||||
on_delete="CASCADE"
|
||||
model=Servers, column_name="server_id", backref="users", on_delete="CASCADE"
|
||||
)
|
||||
|
||||
class Meta:
|
||||
@ -70,9 +62,7 @@ class ServerUsers(Model):
|
||||
|
||||
class PanSyncTokens(Model):
|
||||
token = TextField()
|
||||
user = ForeignKeyField(
|
||||
model=ServerUsers,
|
||||
column_name="user_id")
|
||||
user = ForeignKeyField(model=ServerUsers, column_name="user_id")
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id)")]
|
||||
@ -80,9 +70,8 @@ class PanSyncTokens(Model):
|
||||
|
||||
class PanFetcherTasks(Model):
|
||||
user = ForeignKeyField(
|
||||
model=ServerUsers,
|
||||
column_name="user_id",
|
||||
backref="fetcher_tasks")
|
||||
model=ServerUsers, column_name="user_id", backref="fetcher_tasks"
|
||||
)
|
||||
room_id = TextField()
|
||||
token = TextField()
|
||||
|
||||
@ -110,13 +99,12 @@ class PanStore:
|
||||
DeviceKeys,
|
||||
DeviceTrustState,
|
||||
PanSyncTokens,
|
||||
PanFetcherTasks
|
||||
PanFetcherTasks,
|
||||
]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
os.path.abspath(self.store_path),
|
||||
self.database_name
|
||||
os.path.abspath(self.store_path), self.database_name
|
||||
)
|
||||
|
||||
self.database = self._create_database()
|
||||
@ -127,19 +115,14 @@ class PanStore:
|
||||
|
||||
def _create_database(self):
|
||||
return SqliteDatabase(
|
||||
self.database_path,
|
||||
pragmas={
|
||||
"foreign_keys": 1,
|
||||
"secure_delete": 1,
|
||||
}
|
||||
self.database_path, pragmas={"foreign_keys": 1, "secure_delete": 1}
|
||||
)
|
||||
|
||||
@use_database
|
||||
def _get_account(self, user_id, device_id):
|
||||
try:
|
||||
return Accounts.get(
|
||||
Accounts.user_id == user_id,
|
||||
Accounts.device_id == device_id,
|
||||
Accounts.user_id == user_id, Accounts.device_id == device_id
|
||||
)
|
||||
except DoesNotExist:
|
||||
return None
|
||||
@ -150,9 +133,7 @@ class PanStore:
|
||||
user = ServerUsers.get(server=server, user_id=pan_user)
|
||||
|
||||
PanFetcherTasks.replace(
|
||||
user=user,
|
||||
room_id=task.room_id,
|
||||
token=task.token
|
||||
user=user, room_id=task.room_id, token=task.token
|
||||
).execute()
|
||||
|
||||
def load_fetcher_tasks(self, server, pan_user):
|
||||
@ -173,7 +154,7 @@ class PanStore:
|
||||
PanFetcherTasks.delete().where(
|
||||
PanFetcherTasks.user == user,
|
||||
PanFetcherTasks.room_id == task.room_id,
|
||||
PanFetcherTasks.token == task.token
|
||||
PanFetcherTasks.token == task.token,
|
||||
).execute()
|
||||
|
||||
@use_database
|
||||
@ -208,18 +189,14 @@ class PanStore:
|
||||
server, _ = Servers.get_or_create(name=server_name)
|
||||
|
||||
ServerUsers.insert(
|
||||
user_id=user_id,
|
||||
server=server
|
||||
user_id=user_id, server=server
|
||||
).on_conflict_ignore().execute()
|
||||
|
||||
@use_database
|
||||
def load_all_users(self):
|
||||
users = []
|
||||
|
||||
query = Accounts.select(
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
)
|
||||
query = Accounts.select(Accounts.user_id, Accounts.device_id)
|
||||
|
||||
for account in query:
|
||||
users.append((account.user_id, account.device_id))
|
||||
@ -241,10 +218,9 @@ class PanStore:
|
||||
for u in server.users:
|
||||
server_users.append(u.user_id)
|
||||
|
||||
query = Accounts.select(
|
||||
Accounts.user_id,
|
||||
Accounts.device_id,
|
||||
).where(Accounts.user_id.in_(server_users))
|
||||
query = Accounts.select(Accounts.user_id, Accounts.device_id).where(
|
||||
Accounts.user_id.in_(server_users)
|
||||
)
|
||||
|
||||
for account in query:
|
||||
users.append((account.user_id, account.device_id))
|
||||
@ -256,10 +232,7 @@ class PanStore:
|
||||
account = self._get_account(user_id, device_id)
|
||||
assert account
|
||||
|
||||
AccessTokens.replace(
|
||||
account=account,
|
||||
token=access_token
|
||||
).execute()
|
||||
AccessTokens.replace(account=account, token=access_token).execute()
|
||||
|
||||
@use_database
|
||||
def load_access_token(self, user_id, device_id):
|
||||
@ -302,7 +275,7 @@ class PanStore:
|
||||
"ed25519": keys["ed25519"],
|
||||
"curve25519": keys["curve25519"],
|
||||
"trust_state": trust_state.name,
|
||||
"device_display_name": d.display_name
|
||||
"device_display_name": d.display_name,
|
||||
}
|
||||
|
||||
store[account.user_id] = device_store
|
||||
|
@ -24,20 +24,27 @@ from pydbus.generic import signal
|
||||
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import PanStore
|
||||
from pantalaimon.thread_messages import (AcceptSasMessage, CancelSasMessage,
|
||||
from pantalaimon.thread_messages import (
|
||||
AcceptSasMessage,
|
||||
CancelSasMessage,
|
||||
CancelSendingMessage,
|
||||
ConfirmSasMessage, DaemonResponse,
|
||||
ConfirmSasMessage,
|
||||
DaemonResponse,
|
||||
DeviceBlacklistMessage,
|
||||
DeviceUnblacklistMessage,
|
||||
DeviceUnverifyMessage,
|
||||
DeviceVerifyMessage,
|
||||
ExportKeysMessage, ImportKeysMessage,
|
||||
InviteSasSignal, SasDoneSignal,
|
||||
SendAnywaysMessage, ShowSasSignal,
|
||||
ExportKeysMessage,
|
||||
ImportKeysMessage,
|
||||
InviteSasSignal,
|
||||
SasDoneSignal,
|
||||
SendAnywaysMessage,
|
||||
ShowSasSignal,
|
||||
StartSasMessage,
|
||||
UnverifiedDevicesSignal,
|
||||
UpdateDevicesMessage,
|
||||
UpdateUsersMessage)
|
||||
UpdateUsersMessage,
|
||||
)
|
||||
|
||||
|
||||
class IdCounter:
|
||||
@ -126,40 +133,22 @@ class Control:
|
||||
return self.users
|
||||
|
||||
def ExportKeys(self, pan_user, filepath, passphrase):
|
||||
message = ExportKeysMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
filepath,
|
||||
passphrase
|
||||
)
|
||||
message = ExportKeysMessage(self.message_id, pan_user, filepath, passphrase)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def ImportKeys(self, pan_user, filepath, passphrase):
|
||||
message = ImportKeysMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
filepath,
|
||||
passphrase
|
||||
)
|
||||
message = ImportKeysMessage(self.message_id, pan_user, filepath, passphrase)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def SendAnyways(self, pan_user, room_id):
|
||||
message = SendAnywaysMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
room_id
|
||||
)
|
||||
message = SendAnywaysMessage(self.message_id, pan_user, room_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def CancelSending(self, pan_user, room_id):
|
||||
message = CancelSendingMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
room_id
|
||||
)
|
||||
message = CancelSendingMessage(self.message_id, pan_user, room_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
@ -293,8 +282,9 @@ class Devices:
|
||||
return []
|
||||
|
||||
device_list = [
|
||||
device for device_list in device_store.values() for device in
|
||||
device_list.values()
|
||||
device
|
||||
for device_list in device_store.values()
|
||||
for device in device_list.values()
|
||||
]
|
||||
|
||||
return device_list
|
||||
@ -313,82 +303,44 @@ class Devices:
|
||||
return device_list.values()
|
||||
|
||||
def Verify(self, pan_user, user_id, device_id):
|
||||
message = DeviceVerifyMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = DeviceVerifyMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def Unverify(self, pan_user, user_id, device_id):
|
||||
message = DeviceUnverifyMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = DeviceUnverifyMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def Blacklist(self, pan_user, user_id, device_id):
|
||||
message = DeviceBlacklistMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = DeviceBlacklistMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def Unblacklist(self, pan_user, user_id, device_id):
|
||||
message = DeviceUnblacklistMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
self.message_id, pan_user, user_id, device_id
|
||||
)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def StartKeyVerification(self, pan_user, user_id, device_id):
|
||||
message = StartSasMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = StartSasMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def CancelKeyVerification(self, pan_user, user_id, device_id):
|
||||
message = CancelSasMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = CancelSasMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def ConfirmKeyVerification(self, pan_user, user_id, device_id):
|
||||
message = ConfirmSasMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = ConfirmSasMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
def AcceptKeyVerification(self, pan_user, user_id, device_id):
|
||||
message = AcceptSasMessage(
|
||||
self.message_id,
|
||||
pan_user,
|
||||
user_id,
|
||||
device_id
|
||||
)
|
||||
message = AcceptSasMessage(self.message_id, pan_user, user_id, device_id)
|
||||
self.queue.put(message)
|
||||
return message.message_id
|
||||
|
||||
@ -421,8 +373,9 @@ class GlibT:
|
||||
|
||||
id_counter = IdCounter()
|
||||
|
||||
self.control_if = Control(self.send_queue, self.store,
|
||||
self.server_list, id_counter)
|
||||
self.control_if = Control(
|
||||
self.send_queue, self.store, self.server_list, id_counter
|
||||
)
|
||||
self.device_if = Devices(self.send_queue, self.store, id_counter)
|
||||
|
||||
self.bus = SessionBus()
|
||||
@ -431,131 +384,95 @@ class GlibT:
|
||||
def unverified_notification(self, message):
|
||||
notificaton = notify2.Notification(
|
||||
"Unverified devices.",
|
||||
message=(f"There are unverified devices in the room "
|
||||
f"{message.room_display_name}.")
|
||||
message=(
|
||||
f"There are unverified devices in the room "
|
||||
f"{message.room_display_name}."
|
||||
),
|
||||
)
|
||||
notificaton.set_category("im")
|
||||
|
||||
def send_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.control_if.SendAnyways(
|
||||
message.pan_user,
|
||||
message.room_id
|
||||
)
|
||||
self.control_if.SendAnyways(message.pan_user, message.room_id)
|
||||
|
||||
def cancel_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.control_if.CancelSending(
|
||||
message.pan_user,
|
||||
message.room_id
|
||||
)
|
||||
self.control_if.CancelSending(message.pan_user, message.room_id)
|
||||
|
||||
if "actions" in notify2.get_server_caps():
|
||||
notificaton.add_action(
|
||||
"send",
|
||||
"Send anyways",
|
||||
send_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action(
|
||||
"cancel",
|
||||
"Cancel sending",
|
||||
cancel_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action("send", "Send anyways", send_cb, message)
|
||||
notificaton.add_action("cancel", "Cancel sending", cancel_cb, message)
|
||||
|
||||
notificaton.show()
|
||||
|
||||
def sas_invite_notification(self, message):
|
||||
notificaton = notify2.Notification(
|
||||
"Key verification invite",
|
||||
message=(f"{message.user_id} via {message.device_id} has started "
|
||||
f"a key verification process.")
|
||||
message=(
|
||||
f"{message.user_id} via {message.device_id} has started "
|
||||
f"a key verification process."
|
||||
),
|
||||
)
|
||||
notificaton.set_category("im")
|
||||
|
||||
def accept_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.device_if.AcceptKeyVerification(
|
||||
message.pan_user,
|
||||
message.user_id,
|
||||
message.device_id
|
||||
message.pan_user, message.user_id, message.device_id
|
||||
)
|
||||
|
||||
def cancel_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.device_if.CancelKeyVerification(
|
||||
message.pan_user,
|
||||
message.user_id,
|
||||
message.device_id,
|
||||
message.pan_user, message.user_id, message.device_id
|
||||
)
|
||||
|
||||
if "actions" in notify2.get_server_caps():
|
||||
notificaton.add_action(
|
||||
"accept",
|
||||
"Accept",
|
||||
accept_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action(
|
||||
"cancel",
|
||||
"Cancel",
|
||||
cancel_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action("accept", "Accept", accept_cb, message)
|
||||
notificaton.add_action("cancel", "Cancel", cancel_cb, message)
|
||||
|
||||
notificaton.show()
|
||||
|
||||
def sas_show_notification(self, message):
|
||||
emojis = [x[0] for x in message.emoji]
|
||||
|
||||
emoji_str = u" ".join(emojis)
|
||||
emoji_str = " ".join(emojis)
|
||||
|
||||
notificaton = notify2.Notification(
|
||||
"Short authentication string",
|
||||
message=(f"Short authentication string for the key verification of"
|
||||
message=(
|
||||
f"Short authentication string for the key verification of"
|
||||
f" {message.user_id} via {message.device_id}:\n"
|
||||
f"{emoji_str}")
|
||||
f"{emoji_str}"
|
||||
),
|
||||
)
|
||||
notificaton.set_category("im")
|
||||
|
||||
def confirm_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.device_if.ConfirmKeyVerification(
|
||||
message.pan_user,
|
||||
message.user_id,
|
||||
message.device_id
|
||||
message.pan_user, message.user_id, message.device_id
|
||||
)
|
||||
|
||||
def cancel_cb(notification, action_key, user_data):
|
||||
message = user_data
|
||||
self.device_if.CancelKeyVerification(
|
||||
message.pan_user,
|
||||
message.user_id,
|
||||
message.device_id,
|
||||
message.pan_user, message.user_id, message.device_id
|
||||
)
|
||||
|
||||
if "actions" in notify2.get_server_caps():
|
||||
notificaton.add_action(
|
||||
"confirm",
|
||||
"Confirm",
|
||||
confirm_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action(
|
||||
"cancel",
|
||||
"Cancel",
|
||||
cancel_cb,
|
||||
message
|
||||
)
|
||||
notificaton.add_action("confirm", "Confirm", confirm_cb, message)
|
||||
notificaton.add_action("cancel", "Cancel", cancel_cb, message)
|
||||
|
||||
notificaton.show()
|
||||
|
||||
def sas_done_notification(self, message):
|
||||
notificaton = notify2.Notification(
|
||||
"Device successfully verified.",
|
||||
message=(f"Device {message.device_id} of user {message.user_id} "
|
||||
f"successfully verified.")
|
||||
message=(
|
||||
f"Device {message.device_id} of user {message.user_id} "
|
||||
f"successfully verified."
|
||||
),
|
||||
)
|
||||
notificaton.set_category("im")
|
||||
notificaton.show()
|
||||
@ -576,9 +493,7 @@ class GlibT:
|
||||
|
||||
elif isinstance(message, UnverifiedDevicesSignal):
|
||||
self.control_if.UnverifiedDevices(
|
||||
message.pan_user,
|
||||
message.room_id,
|
||||
message.room_display_name
|
||||
message.pan_user, message.room_id, message.room_display_name
|
||||
)
|
||||
|
||||
if self.notifications:
|
||||
@ -589,7 +504,7 @@ class GlibT:
|
||||
message.pan_user,
|
||||
message.user_id,
|
||||
message.device_id,
|
||||
message.transaction_id
|
||||
message.transaction_id,
|
||||
)
|
||||
|
||||
if self.notifications:
|
||||
@ -622,10 +537,7 @@ class GlibT:
|
||||
self.control_if.Response(
|
||||
message.message_id,
|
||||
message.pan_user,
|
||||
{
|
||||
"code": message.code,
|
||||
"message": message.message
|
||||
}
|
||||
{"code": message.code, "message": message.message},
|
||||
)
|
||||
|
||||
self.receive_queue.task_done()
|
||||
@ -639,8 +551,10 @@ class GlibT:
|
||||
notify2.init("pantalaimon", mainloop=self.loop)
|
||||
self.notifications = True
|
||||
except dbus.DBusException:
|
||||
logger.error("Notifications are enabled but no notification "
|
||||
"server could be found, disabling notifications.")
|
||||
logger.error(
|
||||
"Notifications are enabled but no notification "
|
||||
"server could be found, disabling notifications."
|
||||
)
|
||||
self.notifications = False
|
||||
|
||||
GLib.timeout_add(100, self.message_callback)
|
||||
|
Loading…
Reference in New Issue
Block a user