daemon: Overide the filter API endpoint.

This commit is contained in:
Damir Jelić 2019-04-30 13:07:11 +02:00
parent 90a87460e8
commit da552973ff

View File

@ -8,6 +8,7 @@ from functools import partial
from ipaddress import ip_address
from json import JSONDecodeError
from urllib.parse import urlparse
from typing import Any, Dict
import aiohttp
import attr
@ -197,10 +198,37 @@ class ProxyDaemon:
return access_token
def sanitize_filter(self, sync_filter):
# type: (Dict[Any, Any]) -> Dict[Any, Any]
"""Make sure that a filter isn't filtering encrypted messages."""
sync_filter = dict(sync_filter)
room_filter = sync_filter.get("room", None)
if room_filter:
timeline_filter = room_filter.get("timeline", None)
if timeline_filter:
types_filter = timeline_filter.get("types", None)
if types_filter:
if "m.room.encrypted" not in types_filter:
types_filter.append("m.room.encrypted")
not_types_filter = timeline_filter.get("not_types", None)
if not_types_filter:
try:
not_types_filter.remove("m.room.encrypted")
except ValueError:
pass
return sync_filter
async def forward_request(
self,
request, # type: aiohttp.web.BaseRequest
params=None, # type: CIMultiDict
data=None, # type: Dict[Any, Any]
session=None, # type: aiohttp.ClientSession
token=None # type: str
):
@ -212,6 +240,7 @@ class ProxyDaemon:
forwarded.
params (CIMultiDict, optional): The query parameters for the
request.
data (Dict, optional): Data for the request.
session (aiohttp.ClientSession, optional): The client session that
should be used to forward the request.
token (str, optional): The access token that should be used for the
@ -238,6 +267,10 @@ class ProxyDaemon:
if "access_token" in params:
params["access_token"] = token
if data:
data = data or await request.text()
headers.pop("Content-Length", None)
else:
data = await request.text()
return await session.request(
@ -254,6 +287,7 @@ class ProxyDaemon:
self,
request,
params=None,
data=None,
session=None,
token=None
):
@ -268,6 +302,7 @@ class ProxyDaemon:
forwarded.
params (CIMultiDict, optional): The query parameters for the
request.
data (Dict, optional): Data for the request.
session (aiohttp.ClientSession, optional): The client session that
should be used to forward the request.
token (str, optional): The access token that should be used for the
@ -276,9 +311,10 @@ class ProxyDaemon:
try:
response = await self.forward_request(
request,
params,
session,
token
params=params,
data=data,
session=session,
token=token
)
return web.Response(
status=response.status,
@ -440,32 +476,23 @@ class ProxyDaemon:
return self._unknown_token
sync_filter = request.query.get("filter", None)
query = CIMultiDict(request.query)
if sync_filter:
try:
sync_filter = json.loads(sync_filter)
except (JSONDecodeError, TypeError):
pass
if isinstance(sync_filter, int):
sync_filter = None
if isinstance(sync_filter, dict):
sync_filter = self.sanitize_filter(sync_filter)
# TODO edit the sync filter to not filter encrypted messages
# TODO do the same with an uploaded filter
# room_filter = sync_filter.get("room", None)
# if room_filter:
# timeline_filter = room_filter.get("timeline", None)
# if timeline_filter:
# types_filter = timeline_filter.get("types", None)
query = CIMultiDict(request.query)
query.pop("filter", None)
query["filter"] = sync_filter
try:
response = await self.forward_request(
request,
query,
params=query,
token=client.access_token
)
except (ClientProxyConnectionError,
@ -568,10 +595,28 @@ class ProxyDaemon:
text=await response.transport_response.text()
)
async def filter(self, request):
access_token = self.get_access_token(request)
if not access_token:
return self._missing_token
try:
content = await request.json()
except (JSONDecodeError, ContentTypeError):
return self._not_json
sanitized_content = self.sanitize_filter(content)
return await self.forward_to_web(
request,
data=json.dumps(sanitized_content)
)
async def shutdown(self, app):
"""Shut the daemon down closing all the client sessions it has.
This method is called when we shut the whole app down
This method is called when we shut the whole app down.
"""
for client in self.pan_clients.values():
@ -612,6 +657,7 @@ async def init(homeserver, http_proxy, ssl, send_queue, recv_queue):
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
proxy.send_message
),
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
])
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
app.on_shutdown.append(proxy.shutdown)