diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 03f152e..788dfed 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -105,6 +105,7 @@ class ProxyDaemon: client_info = attr.ib(init=False, default=attr.Factory(dict), type=dict) default_session = attr.ib(init=False, default=None) media_info = attr.ib(init=False, default=None) + upload_info = attr.ib(init=False, default=None) database_name = "pan.db" def __attrs_post_init__(self): @@ -115,6 +116,7 @@ class ProxyDaemon: self.store = PanStore(self.data_dir) accounts = self.store.load_users(self.name) self.media_info = self.store.load_media(self.name) + self.upload_info = self.store.load_upload(self.name) for user_id, device_id in accounts: if self.conf.keyring: @@ -830,9 +832,17 @@ class ProxyDaemon: async def _map_media_upload(self, content, request, client): content_uri = content["url"] - upload = self.store.load_upload(content_uri) - if upload is None: - return await self.forward_to_web(request, token=client.access_token) + try: + upload_info = self.upload_info[content_uri] + except KeyError: + upload_info = self.store.load_upload(self.name, content_uri) + if not upload_info: + logger.info(f"No upload info found for {self.name}/{content_uri}") + + return await self.forward_to_web(request, token=client.access_token) + + self.upload_info[content_uri] = upload_info + mxc = urlparse(content_uri) mxc_server = mxc.netloc.strip("/") @@ -841,7 +851,7 @@ class ProxyDaemon: logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store") - media = MediaInfo(mxc_server, mxc_path, upload.key, upload.iv, upload.hashes) + media = MediaInfo(mxc_server, mxc_path, upload_info.key, upload_info.iv, upload_info.hashes) self.media_info[(mxc_server, mxc_path)] = media self.store.save_media(self.name, media) @@ -1119,7 +1129,7 @@ class ProxyDaemon: body=await response.transport_response.read(), ) - self.store.save_upload(response.content_uri, maybe_keys) + self.store.save_upload(self.name, response.content_uri, maybe_keys) return web.Response( status=response.transport_response.status, @@ -1133,7 +1143,7 @@ class ProxyDaemon: except SendRetryError as e: return web.Response(status=503, text=str(e)) - async def _load_media(self, server_name, media_id, file_name, request): + async def _load_media(self, server_name, media_id, file_name): try: media_info = self.media_info[(server_name, media_id)] except KeyError: diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 123f80f..fbde863 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -31,6 +31,7 @@ from cachetools import LRUCache MAX_LOADED_MEDIA = 10000 +MAX_LOADED_UPLOAD = 10000 @attr.s @@ -122,13 +123,16 @@ class PanMediaInfo(Model): class PanUploadInfo(Model): + server = ForeignKeyField( + model=Servers, column_name="server_id", backref="upload", on_delete="CASCADE" + ) content_uri = TextField() key = DictField() hashes = DictField() iv = TextField() class Meta: - constraints = [SQL("UNIQUE(content_uri)")] + constraints = [SQL("UNIQUE(server_id, content_uri)")] @attr.s @@ -182,8 +186,11 @@ class PanStore: return None @use_database - def save_upload(self, content_uri, upload): + def save_upload(self, server, content_uri, upload): + server = Servers.get(name=server) + PanUploadInfo.insert( + server=server, content_uri=content_uri, key=upload["key"], iv=upload["iv"], @@ -191,8 +198,23 @@ class PanStore: ).on_conflict_ignore().execute() @use_database - def load_upload(self, content_uri): + def load_upload(self, server, content_uri=None): + server, _ = Servers.get_or_create(name=server) + + if not content_uri: + upload_cache = LRUCache(maxsize=MAX_LOADED_UPLOAD) + + for i, u in enumerate(server.upload): + if i > MAX_LOADED_UPLOAD: + break + + upload = UploadInfo(u.content_uri, u.key, u.iv, u.hashes) + upload_cache[u.content_uri] = upload + + return upload_cache + else: u = PanUploadInfo.get_or_none( + PanUploadInfo.server == server, PanUploadInfo.content_uri == content_uri, ) diff --git a/tests/store_test.py b/tests/store_test.py index ada5737..981170f 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -179,14 +179,21 @@ class TestClass(object): assert media_info == panstore.load_media(server_name, mxc_server, mxc_path) def test_upload_storage(self, panstore): + server_name = "test" + upload_cache = panstore.load_upload(server_name) + assert not upload_cache + event = self.encrypted_media_event - assert not panstore.load_upload(event.url) + assert not panstore.load_upload(server_name, event.url) upload = UploadInfo(event.url, event.key, event.iv, event.hashes) - panstore.save_upload(event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes}) + panstore.save_upload(server_name, event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes}) - upload_info = panstore.load_upload(event.url) + upload_cache = panstore.load_upload(server_name) + assert (event.url) in upload_cache + upload_info = upload_cache[event.url] assert upload_info == upload + assert upload_info == panstore.load_upload(server_name, event.url)