daemon: Overide the filter API endpoint.

This commit is contained in:
Damir Jelić 2019-04-30 13:07:11 +02:00
parent 90a87460e8
commit da552973ff

View File

@ -8,6 +8,7 @@ from functools import partial
from ipaddress import ip_address from ipaddress import ip_address
from json import JSONDecodeError from json import JSONDecodeError
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Any, Dict
import aiohttp import aiohttp
import attr import attr
@ -197,10 +198,37 @@ class ProxyDaemon:
return access_token return access_token
def sanitize_filter(self, sync_filter):
# type: (Dict[Any, Any]) -> Dict[Any, Any]
"""Make sure that a filter isn't filtering encrypted messages."""
sync_filter = dict(sync_filter)
room_filter = sync_filter.get("room", None)
if room_filter:
timeline_filter = room_filter.get("timeline", None)
if timeline_filter:
types_filter = timeline_filter.get("types", None)
if types_filter:
if "m.room.encrypted" not in types_filter:
types_filter.append("m.room.encrypted")
not_types_filter = timeline_filter.get("not_types", None)
if not_types_filter:
try:
not_types_filter.remove("m.room.encrypted")
except ValueError:
pass
return sync_filter
async def forward_request( async def forward_request(
self, self,
request, # type: aiohttp.web.BaseRequest request, # type: aiohttp.web.BaseRequest
params=None, # type: CIMultiDict params=None, # type: CIMultiDict
data=None, # type: Dict[Any, Any]
session=None, # type: aiohttp.ClientSession session=None, # type: aiohttp.ClientSession
token=None # type: str token=None # type: str
): ):
@ -212,6 +240,7 @@ class ProxyDaemon:
forwarded. forwarded.
params (CIMultiDict, optional): The query parameters for the params (CIMultiDict, optional): The query parameters for the
request. request.
data (Dict, optional): Data for the request.
session (aiohttp.ClientSession, optional): The client session that session (aiohttp.ClientSession, optional): The client session that
should be used to forward the request. should be used to forward the request.
token (str, optional): The access token that should be used for the token (str, optional): The access token that should be used for the
@ -238,7 +267,11 @@ class ProxyDaemon:
if "access_token" in params: if "access_token" in params:
params["access_token"] = token params["access_token"] = token
data = await request.text() if data:
data = data or await request.text()
headers.pop("Content-Length", None)
else:
data = await request.text()
return await session.request( return await session.request(
method, method,
@ -254,6 +287,7 @@ class ProxyDaemon:
self, self,
request, request,
params=None, params=None,
data=None,
session=None, session=None,
token=None token=None
): ):
@ -268,6 +302,7 @@ class ProxyDaemon:
forwarded. forwarded.
params (CIMultiDict, optional): The query parameters for the params (CIMultiDict, optional): The query parameters for the
request. request.
data (Dict, optional): Data for the request.
session (aiohttp.ClientSession, optional): The client session that session (aiohttp.ClientSession, optional): The client session that
should be used to forward the request. should be used to forward the request.
token (str, optional): The access token that should be used for the token (str, optional): The access token that should be used for the
@ -276,9 +311,10 @@ class ProxyDaemon:
try: try:
response = await self.forward_request( response = await self.forward_request(
request, request,
params, params=params,
session, data=data,
token session=session,
token=token
) )
return web.Response( return web.Response(
status=response.status, status=response.status,
@ -440,32 +476,23 @@ class ProxyDaemon:
return self._unknown_token return self._unknown_token
sync_filter = request.query.get("filter", None) sync_filter = request.query.get("filter", None)
try:
sync_filter = json.loads(sync_filter)
except (JSONDecodeError, TypeError):
pass
if isinstance(sync_filter, int):
sync_filter = None
# TODO edit the sync filter to not filter encrypted messages
# TODO do the same with an uploaded filter
# room_filter = sync_filter.get("room", None)
# if room_filter:
# timeline_filter = room_filter.get("timeline", None)
# if timeline_filter:
# types_filter = timeline_filter.get("types", None)
query = CIMultiDict(request.query) query = CIMultiDict(request.query)
query.pop("filter", None)
if sync_filter:
try:
sync_filter = json.loads(sync_filter)
except (JSONDecodeError, TypeError):
pass
if isinstance(sync_filter, dict):
sync_filter = self.sanitize_filter(sync_filter)
query["filter"] = sync_filter
try: try:
response = await self.forward_request( response = await self.forward_request(
request, request,
query, params=query,
token=client.access_token token=client.access_token
) )
except (ClientProxyConnectionError, except (ClientProxyConnectionError,
@ -568,10 +595,28 @@ class ProxyDaemon:
text=await response.transport_response.text() text=await response.transport_response.text()
) )
async def filter(self, request):
access_token = self.get_access_token(request)
if not access_token:
return self._missing_token
try:
content = await request.json()
except (JSONDecodeError, ContentTypeError):
return self._not_json
sanitized_content = self.sanitize_filter(content)
return await self.forward_to_web(
request,
data=json.dumps(sanitized_content)
)
async def shutdown(self, app): async def shutdown(self, app):
"""Shut the daemon down closing all the client sessions it has. """Shut the daemon down closing all the client sessions it has.
This method is called when we shut the whole app down This method is called when we shut the whole app down.
""" """
for client in self.pan_clients.values(): for client in self.pan_clients.values():
@ -612,6 +657,7 @@ async def init(homeserver, http_proxy, ssl, send_queue, recv_queue):
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}", r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message proxy.send_message
), ),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
]) ])
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
app.on_shutdown.append(proxy.shutdown) app.on_shutdown.append(proxy.shutdown)