Move storage of user filters into real datastore layer; now have to mock it out in the REST-level tests

This commit is contained in:
Paul "LeoNerd" Evans 2015-01-27 17:48:13 +00:00
parent 059651efa1
commit 54e513b4e6
5 changed files with 79 additions and 27 deletions

View File

@ -16,37 +16,18 @@
from twisted.internet import defer from twisted.internet import defer
# TODO(paul)
_filters_for_user = {}
class Filtering(object): class Filtering(object):
def __init__(self, hs): def __init__(self, hs):
super(Filtering, self).__init__() super(Filtering, self).__init__()
self.hs = hs self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None) return self.store.get_user_filter(user_localpart, filter_id)
if not filters or filter_id >= len(filters):
raise KeyError()
# trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition): def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, []) # TODO(paul): implement sanity checking of the definition
return self.store.add_user_filter(user_localpart, definition)
filter_id = len(filters)
filters.append(definition)
# trivial yield, see above
yield
defer.returnValue(filter_id)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for

View File

@ -30,9 +30,9 @@ from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -82,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
FilteringStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
# TODO(paul)
_filters_for_user = {}
class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters):
raise KeyError()
# trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
# trivial yield, see above
yield
defer.returnValue(filter_id)

View File

@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
datastore=Mock(spec=[ datastore=self.make_datastore_mock(),
"insert_client_ip",
]),
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
@ -58,3 +56,8 @@ class V2AlphaRestTestCase(unittest.TestCase):
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self):
return Mock(spec=[
"insert_client_ip",
])

View File

@ -15,6 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from mock import Mock
from . import V2AlphaRestTestCase from . import V2AlphaRestTestCase
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
@ -24,6 +26,25 @@ class FilterTestCase(V2AlphaRestTestCase):
USER_ID = "@apple:test" USER_ID = "@apple:test"
TO_REGISTER = [filter] TO_REGISTER = [filter]
def make_datastore_mock(self):
datastore = super(FilterTestCase, self).make_datastore_mock()
self._user_filters = {}
def add_user_filter(user_localpart, definition):
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
def get_user_filter(user_localpart, filter_id):
filters = self._user_filters[user_localpart]
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
return datastore
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter(self): def test_filter(self):
(code, response) = yield self.mock_resource.trigger("POST", (code, response) = yield self.mock_resource.trigger("POST",