Sanitize filters

This commit is contained in:
Erik Johnston 2016-01-22 10:41:30 +00:00
parent 297eded261
commit 975903ae17
3 changed files with 40 additions and 34 deletions

View File

@ -28,14 +28,14 @@ class Filtering(object):
return result return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter) self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter) return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for
# them however # them however
def _check_valid_filter(self, user_filter_json): def check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid. """Check if the provided filter is valid.
This inspects all definitions contained within the filter. This inspects all definitions contained within the filter.
@ -129,52 +129,55 @@ class Filtering(object):
class FilterCollection(object): class FilterCollection(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self._filter_json = filter_json
room_filter_json = self.filter_json.get("room", {}) room_filter_json = self._filter_json.get("room", {})
self.room_filter = Filter({ self._room_filter = Filter({
k: v for k, v in room_filter_json.items() k: v for k, v in room_filter_json.items()
if k in ("rooms", "not_rooms") if k in ("rooms", "not_rooms")
}) })
self.room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self.room_state_filter = Filter(room_filter_json.get("state", {})) self._room_state_filter = Filter(room_filter_json.get("state", {}))
self.room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self.room_account_data = Filter(room_filter_json.get("account_data", {})) self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self.presence_filter = Filter(self.filter_json.get("presence", {})) self._presence_filter = Filter(filter_json.get("presence", {}))
self.account_data = Filter(self.filter_json.get("account_data", {})) self._account_data = Filter(filter_json.get("account_data", {}))
self.include_leave = self.filter_json.get("room", {}).get( self.include_leave = filter_json.get("room", {}).get(
"include_leave", False "include_leave", False
) )
def get_filter_json(self):
return self._filter_json
def timeline_limit(self): def timeline_limit(self):
return self.room_timeline_filter.limit() return self._room_timeline_filter.limit()
def presence_limit(self): def presence_limit(self):
return self.presence_filter.limit() return self._presence_filter.limit()
def ephemeral_limit(self): def ephemeral_limit(self):
return self.room_ephemeral_filter.limit() return self._room_ephemeral_filter.limit()
def filter_presence(self, events): def filter_presence(self, events):
return self.presence_filter.filter(events) return self._presence_filter.filter(events)
def filter_account_data(self, events): def filter_account_data(self, events):
return self.account_data.filter(events) return self._account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events):
return self.room_state_filter.filter(self.room_filter.filter(events)) return self._room_state_filter.filter(self._room_filter.filter(events))
def filter_room_timeline(self, events): def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(self.room_filter.filter(events)) return self._room_timeline_filter.filter(self._room_filter.filter(events))
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(self.room_filter.filter(events)) return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
def filter_room_account_data(self, events): def filter_room_account_data(self, events):
return self.room_account_data.filter(self.room_filter.filter(events)) return self._room_account_data.filter(self._room_filter.filter(events))
class Filter(object): class Filter(object):
@ -258,3 +261,6 @@ def _matches_wildcard(actual_value, filter_value):
return actual_value.startswith(type_prefix) return actual_value.startswith(type_prefix)
else: else:
return actual_value == filter_value return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})

View File

@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
filter_id=filter_id, filter_id=filter_id,
) )
defer.returnValue((200, filter.filter_json)) defer.returnValue((200, filter.get_filter_json()))
except KeyError: except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")

View File

@ -24,7 +24,7 @@ from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -113,20 +113,20 @@ class SyncRestServlet(RestServlet):
) )
) )
if filter_id and filter_id.startswith('{'): if filter_id:
try: if filter_id.startswith('{'):
filter_object = json.loads(filter_id) try:
except: filter_object = json.loads(filter_id)
raise SynapseError(400, "Invalid filter JSON") except:
self.filtering._check_valid_filter(filter_object) raise SynapseError(400, "Invalid filter JSON")
filter = FilterCollection(filter_object) self.filtering.check_valid_filter(filter_object)
else: filter = FilterCollection(filter_object)
try: else:
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user.localpart, filter_id user.localpart, filter_id
) )
except: else:
filter = FilterCollection({}) filter = DEFAULT_FILTER_COLLECTION
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,