Merge branch 'fractal-scrollback'

This commit is contained in:
Damir Jelić 2020-07-20 15:18:35 +02:00
commit 2c0bd21e8c

View File

@ -411,29 +411,34 @@ class ProxyDaemon:
return access_token return access_token
def sanitize_subfilter(self, request_filter: Dict[Any, Any]):
types_filter = request_filter.get("types", None)
if types_filter:
if "m.room.encrypted" not in types_filter:
types_filter.append("m.room.encrypted")
not_types_filter = request_filter.get("not_types", None)
if not_types_filter:
try:
not_types_filter.remove("m.room.encrypted")
except ValueError:
pass
def sanitize_filter(self, sync_filter): def sanitize_filter(self, sync_filter):
# type: (Dict[Any, Any]) -> Dict[Any, Any] # type: (Dict[Any, Any]) -> Dict[Any, Any]
"""Make sure that a filter isn't filtering encrypted messages.""" """Make sure that a filter isn't filtering encrypted messages."""
sync_filter = dict(sync_filter) sync_filter = dict(sync_filter)
room_filter = sync_filter.get("room", None) room_filter = sync_filter.get("room", None)
self.sanitize_subfilter(sync_filter)
if room_filter: if room_filter:
timeline_filter = room_filter.get("timeline", None) timeline_filter = room_filter.get("timeline", None)
if timeline_filter: if timeline_filter:
types_filter = timeline_filter.get("types", None) self.sanitize_subfilter(timeline_filter)
if types_filter:
if "m.room.encrypted" not in types_filter:
types_filter.append("m.room.encrypted")
not_types_filter = timeline_filter.get("not_types", None)
if not_types_filter:
try:
not_types_filter.remove("m.room.encrypted")
except ValueError:
pass
return sync_filter return sync_filter
@ -763,8 +768,22 @@ class ProxyDaemon:
if not client: if not client:
return self._unknown_token return self._unknown_token
request_filter = request.query.get("filter", None)
query = CIMultiDict(request.query)
if request_filter:
try:
request_filter = json.loads(request_filter)
except (JSONDecodeError, TypeError):
pass
if isinstance(request_filter, dict):
request_filter = json.dumps(self.sanitize_filter(request_filter))
query["filter"] = request_filter
try: try:
response = await self.forward_request(request) response = await self.forward_request(request, params=query)
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))