Refactor test_filter to use real DataStore

* add tests for filter api errors
This commit is contained in:
pik 2016-10-18 12:17:38 -05:00
parent d43b63818c
commit e8b1d2a452
3 changed files with 95 additions and 58 deletions

View File

@ -74,6 +74,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -81,10 +82,9 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users") raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
user_filter=content, user_filter=content,

View File

@ -85,7 +85,6 @@ class LoggingTransaction(object):
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = self.database_engine.convert_param_style(sql) sql = self.database_engine.convert_param_style(sql)
if args: if args:
try: try:
sql_logger.debug( sql_logger.debug(

View File

@ -15,87 +15,125 @@
from twisted.internet import defer from twisted.internet import defer
from . import V2AlphaRestTestCase from tests import unittest
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
from synapse.api.errors import StoreError, Codes from synapse.api.errors import Codes
import synapse.types
from synapse.types import UserID
from ....utils import MockHttpResource, setup_test_homeserver
PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(V2AlphaRestTestCase): class FilterTestCase(unittest.TestCase):
USER_ID = "@apple:test" USER_ID = "@apple:test"
EXAMPLE_FILTER = {"type": ["m.*"]}
EXAMPLE_FILTER_JSON = '{"type": ["m.*"]}'
TO_REGISTER = [filter] TO_REGISTER = [filter]
def make_datastore_mock(self): @defer.inlineCallbacks
datastore = super(FilterTestCase, self).make_datastore_mock() def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self._user_filters = {} self.hs = yield setup_test_homeserver(
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
)
def add_user_filter(user_localpart, definition): self.auth = self.hs.get_auth()
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
def get_user_filter(user_localpart, filter_id): def get_user_by_access_token(token=None, allow_guest=False):
if user_localpart not in self._user_filters: return {
raise StoreError(404, "No user") "user": UserID.from_string(self.USER_ID),
filters = self._user_filters[user_localpart] "token_id": 1,
if filter_id >= len(filters): "is_guest": False,
raise StoreError(404, "No filter") }
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
return datastore def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.mock_resource)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger( (code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), '{"type": ["m.*"]}' "POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
) )
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response) self.assertEquals({"filter_id": "0"}, response)
filter = yield self.store.get_user_filter(
self.assertIn("apple", self._user_filters) user_localpart='apple',
self.assertEquals(len(self._user_filters["apple"]), 1) filter_id=0,
self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0]) )
self.assertEquals(filter, self.EXAMPLE_FILTER)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_add_filter_for_other_user(self):
self._user_filters["apple"] = [ (code, response) = yield self.mock_resource.trigger(
{"type": ["m.*"]} "POST", "/user/%s/filter" % ('@watermelon:test'), self.EXAMPLE_FILTER_JSON
]
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID)
) )
self.assertEquals(200, code) self.assertEquals(403, code)
self.assertEquals({"type": ["m.*"]}, response)
@defer.inlineCallbacks
def test_get_filter_no_id(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/2" % (self.USER_ID)
)
self.assertEquals(400, code)
@defer.inlineCallbacks
def test_get_filter_no_user(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID)
)
self.assertEquals(400, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN) self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter_missing_id(self): def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
(code, response) = yield self.mock_resource.trigger(
"POST", "/user/%s/filter" % (self.USER_ID), self.EXAMPLE_FILTER_JSON
)
self.hs.is_mine = _is_mine
self.assertEquals(403, code)
self.assertEquals(response['errcode'], Codes.FORBIDDEN)
@defer.inlineCallbacks
def test_get_filter(self):
filter_id = yield self.filtering.add_user_filter(
user_localpart='apple',
user_filter=self.EXAMPLE_FILTER
)
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/0" % (self.USER_ID) "/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
self.assertEquals(200, code)
self.assertEquals(self.EXAMPLE_FILTER, response)
@defer.inlineCallbacks
def test_get_filter_non_existant(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/12382148321" % (self.USER_ID)
) )
self.assertEquals(400, code) self.assertEquals(400, code)
self.assertEquals(response['errcode'], Codes.NOT_FOUND) self.assertEquals(response['errcode'], Codes.NOT_FOUND)
# Currently invalid params do not have an appropriate errcode
# in errors.py
@defer.inlineCallbacks
def test_get_filter_invalid_id(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/foobar" % (self.USER_ID)
)
self.assertEquals(400, code)
# No ID also returns an invalid_id error
@defer.inlineCallbacks
def test_get_filter_no_id(self):
(code, response) = yield self.mock_resource.trigger_get(
"/user/%s/filter/" % (self.USER_ID)
)
self.assertEquals(400, code)