Initial trivial implementation of an actual 'Filtering' object; move storage of user filters into there

This commit is contained in:
Paul "LeoNerd" Evans 2015-01-27 14:28:56 +00:00
parent f9958f3404
commit 05c7cba73a
3 changed files with 58 additions and 13 deletions

41
synapse/api/filtering.py Normal file
View File

@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(paul)
_filters_for_user = {}
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.hs = hs
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
return filters[filter_id]
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return filter_id

View File

@ -28,10 +28,6 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO(paul)
_filters_for_user = {}
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
@ -39,6 +35,7 @@ class GetFilterRestServlet(RestServlet):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
@ -56,13 +53,14 @@ class GetFilterRestServlet(RestServlet):
except: except:
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
filters = _filters_for_user.get(target_user.localpart, None) try:
defer.returnValue((200, self.filtering.get_user_filter(
if not filters or filter_id >= len(filters): user_localpart=target_user.localpart,
filter_id=filter_id,
)))
except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")
defer.returnValue((200, filters[filter_id]))
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
@ -71,6 +69,7 @@ class CreateFilterRestServlet(RestServlet):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
@ -90,10 +89,10 @@ class CreateFilterRestServlet(RestServlet):
except: except:
raise SynapseError(400, "Invalid filter definition") raise SynapseError(400, "Invalid filter definition")
filters = _filters_for_user.setdefault(target_user.localpart, []) filter_id = self.filtering.add_user_filter(
user_localpart=target_user.localpart,
filter_id = len(filters) definition=content,
filters.append(content) )
defer.returnValue((200, {"filter_id": str(filter_id)})) defer.returnValue((200, {"filter_id": str(filter_id)}))

View File

@ -32,6 +32,7 @@ from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
class BaseHomeServer(object): class BaseHomeServer(object):
@ -79,6 +80,7 @@ class BaseHomeServer(object):
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',
'event_builder_factory', 'event_builder_factory',
'filtering',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -197,3 +199,6 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(), clock=self.get_clock(),
hostname=self.hostname, hostname=self.hostname,
) )
def build_filtering(self):
return Filtering(self)