mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-08 22:28:01 -05:00
daemon: Overide the filter API endpoint.
This commit is contained in:
parent
90a87460e8
commit
da552973ff
@ -8,6 +8,7 @@ from functools import partial
|
||||
from ipaddress import ip_address
|
||||
from json import JSONDecodeError
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
@ -197,10 +198,37 @@ class ProxyDaemon:
|
||||
|
||||
return access_token
|
||||
|
||||
def sanitize_filter(self, sync_filter):
|
||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||
"""Make sure that a filter isn't filtering encrypted messages."""
|
||||
sync_filter = dict(sync_filter)
|
||||
room_filter = sync_filter.get("room", None)
|
||||
|
||||
if room_filter:
|
||||
timeline_filter = room_filter.get("timeline", None)
|
||||
|
||||
if timeline_filter:
|
||||
types_filter = timeline_filter.get("types", None)
|
||||
|
||||
if types_filter:
|
||||
if "m.room.encrypted" not in types_filter:
|
||||
types_filter.append("m.room.encrypted")
|
||||
|
||||
not_types_filter = timeline_filter.get("not_types", None)
|
||||
|
||||
if not_types_filter:
|
||||
try:
|
||||
not_types_filter.remove("m.room.encrypted")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return sync_filter
|
||||
|
||||
async def forward_request(
|
||||
self,
|
||||
request, # type: aiohttp.web.BaseRequest
|
||||
params=None, # type: CIMultiDict
|
||||
data=None, # type: Dict[Any, Any]
|
||||
session=None, # type: aiohttp.ClientSession
|
||||
token=None # type: str
|
||||
):
|
||||
@ -212,6 +240,7 @@ class ProxyDaemon:
|
||||
forwarded.
|
||||
params (CIMultiDict, optional): The query parameters for the
|
||||
request.
|
||||
data (Dict, optional): Data for the request.
|
||||
session (aiohttp.ClientSession, optional): The client session that
|
||||
should be used to forward the request.
|
||||
token (str, optional): The access token that should be used for the
|
||||
@ -238,7 +267,11 @@ class ProxyDaemon:
|
||||
if "access_token" in params:
|
||||
params["access_token"] = token
|
||||
|
||||
data = await request.text()
|
||||
if data:
|
||||
data = data or await request.text()
|
||||
headers.pop("Content-Length", None)
|
||||
else:
|
||||
data = await request.text()
|
||||
|
||||
return await session.request(
|
||||
method,
|
||||
@ -254,6 +287,7 @@ class ProxyDaemon:
|
||||
self,
|
||||
request,
|
||||
params=None,
|
||||
data=None,
|
||||
session=None,
|
||||
token=None
|
||||
):
|
||||
@ -268,6 +302,7 @@ class ProxyDaemon:
|
||||
forwarded.
|
||||
params (CIMultiDict, optional): The query parameters for the
|
||||
request.
|
||||
data (Dict, optional): Data for the request.
|
||||
session (aiohttp.ClientSession, optional): The client session that
|
||||
should be used to forward the request.
|
||||
token (str, optional): The access token that should be used for the
|
||||
@ -276,9 +311,10 @@ class ProxyDaemon:
|
||||
try:
|
||||
response = await self.forward_request(
|
||||
request,
|
||||
params,
|
||||
session,
|
||||
token
|
||||
params=params,
|
||||
data=data,
|
||||
session=session,
|
||||
token=token
|
||||
)
|
||||
return web.Response(
|
||||
status=response.status,
|
||||
@ -440,32 +476,23 @@ class ProxyDaemon:
|
||||
return self._unknown_token
|
||||
|
||||
sync_filter = request.query.get("filter", None)
|
||||
|
||||
try:
|
||||
sync_filter = json.loads(sync_filter)
|
||||
except (JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if isinstance(sync_filter, int):
|
||||
sync_filter = None
|
||||
|
||||
# TODO edit the sync filter to not filter encrypted messages
|
||||
# TODO do the same with an uploaded filter
|
||||
|
||||
# room_filter = sync_filter.get("room", None)
|
||||
|
||||
# if room_filter:
|
||||
# timeline_filter = room_filter.get("timeline", None)
|
||||
# if timeline_filter:
|
||||
# types_filter = timeline_filter.get("types", None)
|
||||
|
||||
query = CIMultiDict(request.query)
|
||||
query.pop("filter", None)
|
||||
|
||||
if sync_filter:
|
||||
try:
|
||||
sync_filter = json.loads(sync_filter)
|
||||
except (JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if isinstance(sync_filter, dict):
|
||||
sync_filter = self.sanitize_filter(sync_filter)
|
||||
|
||||
query["filter"] = sync_filter
|
||||
|
||||
try:
|
||||
response = await self.forward_request(
|
||||
request,
|
||||
query,
|
||||
params=query,
|
||||
token=client.access_token
|
||||
)
|
||||
except (ClientProxyConnectionError,
|
||||
@ -568,10 +595,28 @@ class ProxyDaemon:
|
||||
text=await response.transport_response.text()
|
||||
)
|
||||
|
||||
async def filter(self, request):
|
||||
access_token = self.get_access_token(request)
|
||||
|
||||
if not access_token:
|
||||
return self._missing_token
|
||||
|
||||
try:
|
||||
content = await request.json()
|
||||
except (JSONDecodeError, ContentTypeError):
|
||||
return self._not_json
|
||||
|
||||
sanitized_content = self.sanitize_filter(content)
|
||||
|
||||
return await self.forward_to_web(
|
||||
request,
|
||||
data=json.dumps(sanitized_content)
|
||||
)
|
||||
|
||||
async def shutdown(self, app):
|
||||
"""Shut the daemon down closing all the client sessions it has.
|
||||
|
||||
This method is called when we shut the whole app down
|
||||
This method is called when we shut the whole app down.
|
||||
"""
|
||||
|
||||
for client in self.pan_clients.values():
|
||||
@ -612,6 +657,7 @@ async def init(homeserver, http_proxy, ssl, send_queue, recv_queue):
|
||||
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
|
||||
proxy.send_message
|
||||
),
|
||||
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
|
||||
])
|
||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||
app.on_shutdown.append(proxy.shutdown)
|
||||
|
Loading…
Reference in New Issue
Block a user