mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-09 14:39:34 -05:00
daemon: Overide the filter API endpoint.
This commit is contained in:
parent
90a87460e8
commit
da552973ff
@ -8,6 +8,7 @@ from functools import partial
|
|||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import attr
|
import attr
|
||||||
@ -197,10 +198,37 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
|
def sanitize_filter(self, sync_filter):
|
||||||
|
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||||
|
"""Make sure that a filter isn't filtering encrypted messages."""
|
||||||
|
sync_filter = dict(sync_filter)
|
||||||
|
room_filter = sync_filter.get("room", None)
|
||||||
|
|
||||||
|
if room_filter:
|
||||||
|
timeline_filter = room_filter.get("timeline", None)
|
||||||
|
|
||||||
|
if timeline_filter:
|
||||||
|
types_filter = timeline_filter.get("types", None)
|
||||||
|
|
||||||
|
if types_filter:
|
||||||
|
if "m.room.encrypted" not in types_filter:
|
||||||
|
types_filter.append("m.room.encrypted")
|
||||||
|
|
||||||
|
not_types_filter = timeline_filter.get("not_types", None)
|
||||||
|
|
||||||
|
if not_types_filter:
|
||||||
|
try:
|
||||||
|
not_types_filter.remove("m.room.encrypted")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return sync_filter
|
||||||
|
|
||||||
async def forward_request(
|
async def forward_request(
|
||||||
self,
|
self,
|
||||||
request, # type: aiohttp.web.BaseRequest
|
request, # type: aiohttp.web.BaseRequest
|
||||||
params=None, # type: CIMultiDict
|
params=None, # type: CIMultiDict
|
||||||
|
data=None, # type: Dict[Any, Any]
|
||||||
session=None, # type: aiohttp.ClientSession
|
session=None, # type: aiohttp.ClientSession
|
||||||
token=None # type: str
|
token=None # type: str
|
||||||
):
|
):
|
||||||
@ -212,6 +240,7 @@ class ProxyDaemon:
|
|||||||
forwarded.
|
forwarded.
|
||||||
params (CIMultiDict, optional): The query parameters for the
|
params (CIMultiDict, optional): The query parameters for the
|
||||||
request.
|
request.
|
||||||
|
data (Dict, optional): Data for the request.
|
||||||
session (aiohttp.ClientSession, optional): The client session that
|
session (aiohttp.ClientSession, optional): The client session that
|
||||||
should be used to forward the request.
|
should be used to forward the request.
|
||||||
token (str, optional): The access token that should be used for the
|
token (str, optional): The access token that should be used for the
|
||||||
@ -238,6 +267,10 @@ class ProxyDaemon:
|
|||||||
if "access_token" in params:
|
if "access_token" in params:
|
||||||
params["access_token"] = token
|
params["access_token"] = token
|
||||||
|
|
||||||
|
if data:
|
||||||
|
data = data or await request.text()
|
||||||
|
headers.pop("Content-Length", None)
|
||||||
|
else:
|
||||||
data = await request.text()
|
data = await request.text()
|
||||||
|
|
||||||
return await session.request(
|
return await session.request(
|
||||||
@ -254,6 +287,7 @@ class ProxyDaemon:
|
|||||||
self,
|
self,
|
||||||
request,
|
request,
|
||||||
params=None,
|
params=None,
|
||||||
|
data=None,
|
||||||
session=None,
|
session=None,
|
||||||
token=None
|
token=None
|
||||||
):
|
):
|
||||||
@ -268,6 +302,7 @@ class ProxyDaemon:
|
|||||||
forwarded.
|
forwarded.
|
||||||
params (CIMultiDict, optional): The query parameters for the
|
params (CIMultiDict, optional): The query parameters for the
|
||||||
request.
|
request.
|
||||||
|
data (Dict, optional): Data for the request.
|
||||||
session (aiohttp.ClientSession, optional): The client session that
|
session (aiohttp.ClientSession, optional): The client session that
|
||||||
should be used to forward the request.
|
should be used to forward the request.
|
||||||
token (str, optional): The access token that should be used for the
|
token (str, optional): The access token that should be used for the
|
||||||
@ -276,9 +311,10 @@ class ProxyDaemon:
|
|||||||
try:
|
try:
|
||||||
response = await self.forward_request(
|
response = await self.forward_request(
|
||||||
request,
|
request,
|
||||||
params,
|
params=params,
|
||||||
session,
|
data=data,
|
||||||
token
|
session=session,
|
||||||
|
token=token
|
||||||
)
|
)
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=response.status,
|
status=response.status,
|
||||||
@ -440,32 +476,23 @@ class ProxyDaemon:
|
|||||||
return self._unknown_token
|
return self._unknown_token
|
||||||
|
|
||||||
sync_filter = request.query.get("filter", None)
|
sync_filter = request.query.get("filter", None)
|
||||||
|
query = CIMultiDict(request.query)
|
||||||
|
|
||||||
|
if sync_filter:
|
||||||
try:
|
try:
|
||||||
sync_filter = json.loads(sync_filter)
|
sync_filter = json.loads(sync_filter)
|
||||||
except (JSONDecodeError, TypeError):
|
except (JSONDecodeError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if isinstance(sync_filter, int):
|
if isinstance(sync_filter, dict):
|
||||||
sync_filter = None
|
sync_filter = self.sanitize_filter(sync_filter)
|
||||||
|
|
||||||
# TODO edit the sync filter to not filter encrypted messages
|
query["filter"] = sync_filter
|
||||||
# TODO do the same with an uploaded filter
|
|
||||||
|
|
||||||
# room_filter = sync_filter.get("room", None)
|
|
||||||
|
|
||||||
# if room_filter:
|
|
||||||
# timeline_filter = room_filter.get("timeline", None)
|
|
||||||
# if timeline_filter:
|
|
||||||
# types_filter = timeline_filter.get("types", None)
|
|
||||||
|
|
||||||
query = CIMultiDict(request.query)
|
|
||||||
query.pop("filter", None)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self.forward_request(
|
response = await self.forward_request(
|
||||||
request,
|
request,
|
||||||
query,
|
params=query,
|
||||||
token=client.access_token
|
token=client.access_token
|
||||||
)
|
)
|
||||||
except (ClientProxyConnectionError,
|
except (ClientProxyConnectionError,
|
||||||
@ -568,10 +595,28 @@ class ProxyDaemon:
|
|||||||
text=await response.transport_response.text()
|
text=await response.transport_response.text()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def filter(self, request):
|
||||||
|
access_token = self.get_access_token(request)
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
return self._missing_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = await request.json()
|
||||||
|
except (JSONDecodeError, ContentTypeError):
|
||||||
|
return self._not_json
|
||||||
|
|
||||||
|
sanitized_content = self.sanitize_filter(content)
|
||||||
|
|
||||||
|
return await self.forward_to_web(
|
||||||
|
request,
|
||||||
|
data=json.dumps(sanitized_content)
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown(self, app):
|
async def shutdown(self, app):
|
||||||
"""Shut the daemon down closing all the client sessions it has.
|
"""Shut the daemon down closing all the client sessions it has.
|
||||||
|
|
||||||
This method is called when we shut the whole app down
|
This method is called when we shut the whole app down.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for client in self.pan_clients.values():
|
for client in self.pan_clients.values():
|
||||||
@ -612,6 +657,7 @@ async def init(homeserver, http_proxy, ssl, send_queue, recv_queue):
|
|||||||
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
|
r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
|
||||||
proxy.send_message
|
proxy.send_message
|
||||||
),
|
),
|
||||||
|
web.post("/_matrix/client/r0/user/{user_id}/filter", proxy.filter),
|
||||||
])
|
])
|
||||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||||
app.on_shutdown.append(proxy.shutdown)
|
app.on_shutdown.append(proxy.shutdown)
|
||||||
|
Loading…
Reference in New Issue
Block a user