chore: Format code with ruff

This commit is contained in:
Hank Greenburg 2024-10-01 19:20:29 -07:00
parent adef63443e
commit 76dc74d250
10 changed files with 96 additions and 98 deletions

View file

@ -709,7 +709,6 @@ class PanClient(AsyncClient):
for share in self.get_active_key_requests( for share in self.get_active_key_requests(
message.user_id, message.device_id message.user_id, message.device_id
): ):
continued = True continued = True
if not self.continue_key_share(share): if not self.continue_key_share(share):
@ -811,8 +810,9 @@ class PanClient(AsyncClient):
if not isinstance(event, MegolmEvent): if not isinstance(event, MegolmEvent):
logger.warn( logger.warn(
"Encrypted event is not a megolm event:" "Encrypted event is not a megolm event:" "\n{}".format(
"\n{}".format(pformat(event_dict)) pformat(event_dict)
)
) )
return False return False
@ -836,9 +836,9 @@ class PanClient(AsyncClient):
decrypted_event.source["content"]["url"] = decrypted_event.url decrypted_event.source["content"]["url"] = decrypted_event.url
if decrypted_event.thumbnail_url: if decrypted_event.thumbnail_url:
decrypted_event.source["content"]["info"][ decrypted_event.source["content"]["info"]["thumbnail_url"] = (
"thumbnail_url" decrypted_event.thumbnail_url
] = decrypted_event.thumbnail_url )
event_dict.update(decrypted_event.source) event_dict.update(decrypted_event.source)
event_dict["decrypted"] = True event_dict["decrypted"] = True

View file

@ -186,7 +186,6 @@ class PanConfig:
try: try:
for section_name, section in config.items(): for section_name, section in config.items():
if section_name == "Default": if section_name == "Default":
continue continue

View file

@ -227,7 +227,8 @@ class ProxyDaemon:
if ret: if ret:
msg = ( msg = (
f"Device {device.id} of user " f"{device.user_id} successfully verified." f"Device {device.id} of user "
f"{device.user_id} successfully verified."
) )
await client.send_update_device(device) await client.send_update_device(device)
else: else:
@ -309,7 +310,6 @@ class ProxyDaemon:
DeviceUnblacklistMessage, DeviceUnblacklistMessage,
), ),
): ):
device = client.device_store[message.user_id].get(message.device_id, None) device = client.device_store[message.user_id].get(message.device_id, None)
if not device: if not device:
@ -616,7 +616,9 @@ class ProxyDaemon:
await pan_client.close() await pan_client.close()
return return
logger.info(f"Successfully started new background sync client for " f"{user_id}") logger.info(
f"Successfully started new background sync client for " f"{user_id}"
)
await self.send_ui_message( await self.send_ui_message(
UpdateUsersMessage(self.name, user_id, pan_client.device_id) UpdateUsersMessage(self.name, user_id, pan_client.device_id)
@ -733,7 +735,7 @@ class ProxyDaemon:
return decryption_method(body, ignore_failures=False) return decryption_method(body, ignore_failures=False)
except EncryptionError: except EncryptionError:
logger.info("Error decrypting sync, waiting for next pan " "sync") logger.info("Error decrypting sync, waiting for next pan " "sync")
await client.synced.wait(), (await client.synced.wait(),)
logger.info("Pan synced, retrying decryption.") logger.info("Pan synced, retrying decryption.")
try: try:
@ -1273,7 +1275,9 @@ class ProxyDaemon:
client = next(iter(self.pan_clients.values())) client = next(iter(self.pan_clients.values()))
try: try:
response = await client.download(server_name=server_name, media_id=media_id, filename=file_name) response = await client.download(
server_name=server_name, media_id=media_id, filename=file_name
)
except ClientConnectionError as e: except ClientConnectionError as e:
raise e raise e

View file

@ -230,7 +230,6 @@ if False:
) )
for message in query: for message in query:
event = message.event event = message.event
event_dict = { event_dict = {

View file

@ -431,7 +431,6 @@ class PanStore:
device_store = defaultdict(dict) device_store = defaultdict(dict)
for d in account.device_keys: for d in account.device_keys:
if d.deleted: if d.deleted:
continue continue

View file

@ -11,8 +11,7 @@ setup(
url="https://github.com/matrix-org/pantalaimon", url="https://github.com/matrix-org/pantalaimon",
author="The Matrix.org Team", author="The Matrix.org Team",
author_email="poljar@termina.org.uk", author_email="poljar@termina.org.uk",
description=("A Matrix proxy daemon that adds E2E encryption " description=("A Matrix proxy daemon that adds E2E encryption " "capabilities."),
"capabilities."),
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
license="Apache License, Version 2.0", license="Apache License, Version 2.0",
@ -29,7 +28,7 @@ setup(
"cachetools >= 3.0.0", "cachetools >= 3.0.0",
"prompt_toolkit > 2, < 4", "prompt_toolkit > 2, < 4",
"typing;python_version<'3.5'", "typing;python_version<'3.5'",
"matrix-nio[e2e] >= 0.20, < 0.21" "matrix-nio[e2e] >= 0.20, < 0.21",
], ],
extras_require={ extras_require={
"ui": [ "ui": [
@ -40,8 +39,10 @@ setup(
] ]
}, },
entry_points={ entry_points={
"console_scripts": ["pantalaimon=pantalaimon.main:main", "console_scripts": [
"panctl=pantalaimon.panctl:main"], "pantalaimon=pantalaimon.main:main",
"panctl=pantalaimon.panctl:main",
],
}, },
zip_safe=False zip_safe=False,
) )

View file

@ -34,11 +34,9 @@ class Provider(BaseProvider):
def client(self): def client(self):
return ClientInfo(faker.mx_id(), faker.access_token()) return ClientInfo(faker.mx_id(), faker.access_token())
def avatar_url(self): def avatar_url(self):
return "mxc://{}/{}#auto".format( return "mxc://{}/{}#auto".format(
faker.hostname(), faker.hostname(), "".join(choices(ascii_letters) for i in range(24))
"".join(choices(ascii_letters) for i in range(24))
) )
def olm_key_pair(self): def olm_key_pair(self):
@ -56,7 +54,6 @@ class Provider(BaseProvider):
) )
faker.add_provider(Provider) faker.add_provider(Provider)
@ -80,13 +77,7 @@ def tempdir():
@pytest.fixture @pytest.fixture
def panstore(tempdir): def panstore(tempdir):
for _ in range(10): for _ in range(10):
store = SqliteStore( store = SqliteStore(faker.mx_id(), faker.device_id(), tempdir, "", "pan.db")
faker.mx_id(),
faker.device_id(),
tempdir,
"",
"pan.db"
)
account = OlmAccount() account = OlmAccount()
store.save_account(account) store.save_account(account)
@ -130,21 +121,23 @@ async def pan_proxy_server(tempdir, aiohttp_server):
recv_queue=ui_queue.async_q, recv_queue=ui_queue.async_q,
proxy=None, proxy=None,
ssl=False, ssl=False,
client_store_class=SqliteStore client_store_class=SqliteStore,
) )
app.add_routes([ app.add_routes(
[
web.post("/_matrix/client/r0/login", proxy.login), web.post("/_matrix/client/r0/login", proxy.login),
web.get("/_matrix/client/r0/sync", proxy.sync), web.get("/_matrix/client/r0/sync", proxy.sync),
web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages), web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
web.put( web.put(
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}", r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message proxy.send_message,
), ),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter), web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
web.post("/_matrix/client/r0/search", proxy.search), web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts), web.options("/_matrix/client/r0/search", proxy.search_opts),
]) ]
)
server = await aiohttp_server(app) server = await aiohttp_server(app)
@ -161,7 +154,7 @@ async def running_proxy(pan_proxy_server, aioresponse, aiohttp_client):
"access_token": "abc123", "access_token": "abc123",
"device_id": "GHTYAJCE", "device_id": "GHTYAJCE",
"home_server": "example.org", "home_server": "example.org",
"user_id": "@example:example.org" "user_id": "@example:example.org",
} }
aioclient = await aiohttp_client(server) aioclient = await aiohttp_client(server)
@ -170,7 +163,7 @@ async def running_proxy(pan_proxy_server, aioresponse, aiohttp_client):
"https://example.org/_matrix/client/r0/login", "https://example.org/_matrix/client/r0/login",
status=200, status=200,
payload=login_response, payload=login_response,
repeat=True repeat=True,
) )
await aioclient.post( await aioclient.post(
@ -179,7 +172,7 @@ async def running_proxy(pan_proxy_server, aioresponse, aiohttp_client):
"type": "m.login.password", "type": "m.login.password",
"user": "example", "user": "example",
"password": "wordpass", "password": "wordpass",
} },
) )
yield server, aioclient, proxy, queues yield server, aioclient, proxy, queues

View file

@ -380,7 +380,9 @@ class TestClass(object):
) )
aioresponse.get( aioresponse.get(
sync_url, status=200, payload=self.initial_sync_response, sync_url,
status=200,
payload=self.initial_sync_response,
) )
aioresponse.get(sync_url, status=200, payload=self.empty_sync, repeat=True) aioresponse.get(sync_url, status=200, payload=self.empty_sync, repeat=True)
@ -454,7 +456,9 @@ class TestClass(object):
) )
aioresponse.get( aioresponse.get(
sync_url, status=200, payload=self.initial_sync_response, sync_url,
status=200,
payload=self.initial_sync_response,
) )
aioresponse.get(sync_url, status=200, payload=self.empty_sync, repeat=True) aioresponse.get(sync_url, status=200, payload=self.empty_sync, repeat=True)

View file

@ -27,7 +27,7 @@ class TestClass(object):
"access_token": "abc123", "access_token": "abc123",
"device_id": "GHTYAJCE", "device_id": "GHTYAJCE",
"home_server": "example.org", "home_server": "example.org",
"user_id": "@example:example.org" "user_id": "@example:example.org",
} }
@property @property
@ -36,12 +36,7 @@ class TestClass(object):
@property @property
def keys_upload_response(self): def keys_upload_response(self):
return { return {"one_time_key_counts": {"curve25519": 10, "signed_curve25519": 20}}
"one_time_key_counts": {
"curve25519": 10,
"signed_curve25519": 20
}
}
@property @property
def example_devices(self): def example_devices(self):
@ -52,10 +47,7 @@ class TestClass(object):
devices[device.user_id][device.id] = device devices[device.user_id][device.id] = device
bob_device = OlmDevice( bob_device = OlmDevice(
BOB_ID, BOB_ID, BOB_DEVICE, {"ed25519": BOB_ONETIME, "curve25519": BOB_CURVE}
BOB_DEVICE,
{"ed25519": BOB_ONETIME,
"curve25519": BOB_CURVE}
) )
devices[BOB_ID][BOB_DEVICE] = bob_device devices[BOB_ID][BOB_DEVICE] = bob_device
@ -71,7 +63,7 @@ class TestClass(object):
"https://example.org/_matrix/client/r0/login", "https://example.org/_matrix/client/r0/login",
status=200, status=200,
payload=self.login_response, payload=self.login_response,
repeat=True repeat=True,
) )
assert not daemon.pan_clients assert not daemon.pan_clients
@ -82,7 +74,7 @@ class TestClass(object):
"type": "m.login.password", "type": "m.login.password",
"user": "example", "user": "example",
"password": "wordpass", "password": "wordpass",
} },
) )
assert resp.status == 200 assert resp.status == 200
@ -105,11 +97,11 @@ class TestClass(object):
"https://example.org/_matrix/client/r0/login", "https://example.org/_matrix/client/r0/login",
status=200, status=200,
payload=self.login_response, payload=self.login_response,
repeat=True repeat=True,
) )
sync_url = re.compile( sync_url = re.compile(
r'^https://example\.org/_matrix/client/r0/sync\?access_token=.*' r"^https://example\.org/_matrix/client/r0/sync\?access_token=.*"
) )
aioresponse.get( aioresponse.get(
@ -124,14 +116,16 @@ class TestClass(object):
"type": "m.login.password", "type": "m.login.password",
"user": "example", "user": "example",
"password": "wordpass", "password": "wordpass",
} },
) )
# Check that the pan client started to sync after logging in. # Check that the pan client started to sync after logging in.
pan_client = list(daemon.pan_clients.values())[0] pan_client = list(daemon.pan_clients.values())[0]
assert len(pan_client.rooms) == 1 assert len(pan_client.rooms) == 1
async def test_pan_client_keys_upload(self, pan_proxy_server, aiohttp_client, aioresponse): async def test_pan_client_keys_upload(
self, pan_proxy_server, aiohttp_client, aioresponse
):
server, daemon, _ = pan_proxy_server server, daemon, _ = pan_proxy_server
client = await aiohttp_client(server) client = await aiohttp_client(server)
@ -140,11 +134,11 @@ class TestClass(object):
"https://example.org/_matrix/client/r0/login", "https://example.org/_matrix/client/r0/login",
status=200, status=200,
payload=self.login_response, payload=self.login_response,
repeat=True repeat=True,
) )
sync_url = re.compile( sync_url = re.compile(
r'^https://example\.org/_matrix/client/r0/sync\?access_token=.*' r"^https://example\.org/_matrix/client/r0/sync\?access_token=.*"
) )
aioresponse.get( aioresponse.get(
@ -169,7 +163,7 @@ class TestClass(object):
"type": "m.login.password", "type": "m.login.password",
"user": "example", "user": "example",
"password": "wordpass", "password": "wordpass",
} },
) )
pan_client = list(daemon.pan_clients.values())[0] pan_client = list(daemon.pan_clients.values())[0]

View file

@ -27,7 +27,7 @@ class TestClass(object):
"type": "m.room.message", "type": "m.room.message",
"unsigned": {"age": 43289803095}, "unsigned": {"age": 43289803095},
"user_id": "@example2:localhost", "user_id": "@example2:localhost",
"age": 43289803095 "age": 43289803095,
} }
) )
@ -43,13 +43,14 @@ class TestClass(object):
"type": "m.room.message", "type": "m.room.message",
"unsigned": {"age": 43289803095}, "unsigned": {"age": 43289803095},
"user_id": "@example2:localhost", "user_id": "@example2:localhost",
"age": 43289803095 "age": 43289803095,
} }
) )
@property @property
def encrypted_media_event(self): def encrypted_media_event(self):
return RoomEncryptedMedia.from_dict({ return RoomEncryptedMedia.from_dict(
{
"room_id": "!testroom:localhost", "room_id": "!testroom:localhost",
"event_id": "$15163622445EBvZK:localhost", "event_id": "$15163622445EBvZK:localhost",
"origin_server_ts": 1516362244030, "origin_server_ts": 1516362244030,
@ -65,17 +66,18 @@ class TestClass(object):
"ext": True, "ext": True,
"k": "yx0QvkgYlasdWEsdalkejaHBzCkKEBAp3tB7dGtWgrs", "k": "yx0QvkgYlasdWEsdalkejaHBzCkKEBAp3tB7dGtWgrs",
"key_ops": ["encrypt", "decrypt"], "key_ops": ["encrypt", "decrypt"],
"kty": "oct" "kty": "oct",
}, },
"iv": "0pglXX7fspIBBBBAEERLFd", "iv": "0pglXX7fspIBBBBAEERLFd",
"hashes": { "hashes": {
"sha256": "eXRDFvh+aXsQRj8a+5ZVVWUQ9Y6u9DYiz4tq1NvbLu8" "sha256": "eXRDFvh+aXsQRj8a+5ZVVWUQ9Y6u9DYiz4tq1NvbLu8"
}, },
"url": "mxc://localhost/maDtasSiPFjROFMnlwxIhhyW", "url": "mxc://localhost/maDtasSiPFjROFMnlwxIhhyW",
"mimetype": "image/jpeg" "mimetype": "image/jpeg",
},
},
} }
} )
})
def test_account_loading(self, panstore): def test_account_loading(self, panstore):
accounts = panstore.load_all_users() accounts = panstore.load_all_users()
@ -131,6 +133,7 @@ class TestClass(object):
pytest.skip("Indexing needs to be enabled to test this") pytest.skip("Indexing needs to be enabled to test this")
from pantalaimon.index import Index, IndexStore from pantalaimon.index import Index, IndexStore
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
store = IndexStore("example", tempdir) store = IndexStore("example", tempdir)
@ -148,8 +151,10 @@ class TestClass(object):
assert len(result["results"]) == 1 assert len(result["results"]) == 1
assert result["count"] == 1 assert result["count"] == 1
assert result["results"][0]["result"] == self.test_event.source assert result["results"][0]["result"] == self.test_event.source
assert (result["results"][0]["context"]["events_after"][0] assert (
== self.another_event.source) result["results"][0]["context"]["events_after"][0]
== self.another_event.source
)
def test_media_storage(self, panstore): def test_media_storage(self, panstore):
server_name = "test" server_name = "test"