diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 5caab57..dcf3935 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -8,6 +8,7 @@ from functools import partial from ipaddress import ip_address from json import JSONDecodeError from urllib.parse import urlparse +from typing import Any, Dict import aiohttp import attr @@ -197,10 +198,37 @@ class ProxyDaemon: return access_token + def sanitize_filter(self, sync_filter): + # type: (Dict[Any, Any]) -> Dict[Any, Any] + """Make sure that a filter isn't filtering encrypted messages.""" + sync_filter = dict(sync_filter) + room_filter = sync_filter.get("room", None) + + if room_filter: + timeline_filter = room_filter.get("timeline", None) + + if timeline_filter: + types_filter = timeline_filter.get("types", None) + + if types_filter: + if "m.room.encrypted" not in types_filter: + types_filter.append("m.room.encrypted") + + not_types_filter = timeline_filter.get("not_types", None) + + if not_types_filter: + try: + not_types_filter.remove("m.room.encrypted") + except ValueError: + pass + + return sync_filter + async def forward_request( self, request, # type: aiohttp.web.BaseRequest params=None, # type: CIMultiDict + data=None, # type: Dict[Any, Any] session=None, # type: aiohttp.ClientSession token=None # type: str ): @@ -212,6 +240,7 @@ class ProxyDaemon: forwarded. params (CIMultiDict, optional): The query parameters for the request. + data (Dict, optional): Data for the request. session (aiohttp.ClientSession, optional): The client session that should be used to forward the request. token (str, optional): The access token that should be used for the @@ -238,7 +267,11 @@ class ProxyDaemon: if "access_token" in params: params["access_token"] = token - data = await request.text() + if data: + data = data or await request.text() + headers.pop("Content-Length", None) + else: + data = await request.text() return await session.request( method, @@ -254,6 +287,7 @@ class ProxyDaemon: self, request, params=None, + data=None, session=None, token=None ): @@ -268,6 +302,7 @@ class ProxyDaemon: forwarded. params (CIMultiDict, optional): The query parameters for the request. + data (Dict, optional): Data for the request. session (aiohttp.ClientSession, optional): The client session that should be used to forward the request. token (str, optional): The access token that should be used for the @@ -276,9 +311,10 @@ class ProxyDaemon: try: response = await self.forward_request( request, - params, - session, - token + params=params, + data=data, + session=session, + token=token ) return web.Response( status=response.status, @@ -440,32 +476,23 @@ class ProxyDaemon: return self._unknown_token sync_filter = request.query.get("filter", None) - - try: - sync_filter = json.loads(sync_filter) - except (JSONDecodeError, TypeError): - pass - - if isinstance(sync_filter, int): - sync_filter = None - - # TODO edit the sync filter to not filter encrypted messages - # TODO do the same with an uploaded filter - - # room_filter = sync_filter.get("room", None) - - # if room_filter: - # timeline_filter = room_filter.get("timeline", None) - # if timeline_filter: - # types_filter = timeline_filter.get("types", None) - query = CIMultiDict(request.query) - query.pop("filter", None) + + if sync_filter: + try: + sync_filter = json.loads(sync_filter) + except (JSONDecodeError, TypeError): + pass + + if isinstance(sync_filter, dict): + sync_filter = self.sanitize_filter(sync_filter) + + query["filter"] = sync_filter try: response = await self.forward_request( request, - query, + params=query, token=client.access_token ) except (ClientProxyConnectionError, @@ -568,10 +595,28 @@ class ProxyDaemon: text=await response.transport_response.text() ) + async def filter(self, request): + access_token = self.get_access_token(request) + + if not access_token: + return self._missing_token + + try: + content = await request.json() + except (JSONDecodeError, ContentTypeError): + return self._not_json + + sanitized_content = self.sanitize_filter(content) + + return await self.forward_to_web( + request, + data=json.dumps(sanitized_content) + ) + async def shutdown(self, app): """Shut the daemon down closing all the client sessions it has. - This method is called when we shut the whole app down + This method is called when we shut the whole app down. """ for client in self.pan_clients.values(): @@ -612,6 +657,7 @@ async def init(homeserver, http_proxy, ssl, send_queue, recv_queue): r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}", proxy.send_message ), + web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter), ]) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) app.on_shutdown.append(proxy.shutdown)