Fix tests on postgresql (#3740)

This commit is contained in:
Amber Brown 2018-09-04 02:21:48 +10:00 committed by GitHub
parent 567363e497
commit 77055dba92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 355 additions and 340 deletions

View File

@ -35,10 +35,6 @@ matrix:
- python: 3.6 - python: 3.6
env: TOX_ENV=check-newsfragment env: TOX_ENV=check-newsfragment
allow_failures:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
install: install:
- pip install tox - pip install tox

1
changelog.d/3740.misc Normal file
View File

@ -0,0 +1 @@
The test suite now passes on PostgreSQL.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,79 +14,79 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
import synapse.api.errors import synapse.api.errors
import synapse.handlers.device import synapse.handlers.device
import synapse.storage import synapse.storage
from tests import unittest, utils from tests import unittest
user1 = "@boris:aaa" user1 = "@boris:aaa"
user2 = "@theresa:bbb" user2 = "@theresa:bbb"
class DeviceTestCase(unittest.TestCase): class DeviceTestCase(unittest.HomeserverTestCase):
def __init__(self, *args, **kwargs): def make_homeserver(self, reactor, clock):
super(DeviceTestCase, self).__init__(*args, **kwargs) hs = self.setup_test_homeserver("server", http_client=None)
self.store = None # type: synapse.storage.DataStore
self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(self.addCleanup)
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() return hs
def prepare(self, reactor, clock, hs):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self): def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered( res = self.get_success(
user_id="@boris:foo", self.handler.check_device_registered(
device_id="fco", user_id="@boris:foo",
initial_device_display_name="display name", device_id="fco",
initial_device_display_name="display name",
)
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self): def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered( res1 = self.get_success(
user_id="@boris:foo", self.handler.check_device_registered(
device_id="fco", user_id="@boris:foo",
initial_device_display_name="display name", device_id="fco",
initial_device_display_name="display name",
)
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = self.get_success(
user_id="@boris:foo", self.handler.check_device_registered(
device_id="fco", user_id="@boris:foo",
initial_device_display_name="new display name", device_id="fco",
initial_device_display_name="new display name",
)
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco") dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self): def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered( device_id = self.get_success(
user_id="@theresa:foo", self.handler.check_device_registered(
device_id=None, user_id="@theresa:foo",
initial_device_display_name="display", device_id=None,
initial_device_display_name="display",
)
) )
dev = yield self.handler.store.get_device("@theresa:foo", device_id) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks
def test_get_devices_by_user(self): def test_get_devices_by_user(self):
yield self._record_users() self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res} device_map = {d["device_id"]: d for d in res}
self.assertDictContainsSubset( self.assertDictContainsSubset(
@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
device_map["abc"], device_map["abc"],
) )
@defer.inlineCallbacks
def test_get_device(self): def test_get_device(self):
yield self._record_users() self._record_users()
res = yield self.handler.get_device(user1, "abc") res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"user_id": user1, "user_id": user1,
@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
res, res,
) )
@defer.inlineCallbacks
def test_delete_device(self): def test_delete_device(self):
yield self._record_users() self._record_users()
# delete the device # delete the device
yield self.handler.delete_device(user1, "abc") self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted # check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError): res = self.handler.get_device(user1, "abc")
yield self.handler.get_device(user1, "abc") self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a # we'd like to check the access token was invalidated, but that's a
# bit of a PITA. # bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield self._record_users() self._record_users()
update = {"display_name": "new display"} update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update) self.get_success(self.handler.update_device(user1, "abc", update))
res = yield self.handler.get_device(user1, "abc") res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display") self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self): def test_update_unknown_device(self):
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError): res = self.handler.update_device("user_id", "unknown_device_id", update)
yield self.handler.update_device("user_id", "unknown_device_id", update) self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
@defer.inlineCallbacks
def _record_users(self): def _record_users(self):
# check this works for both devices which have a recorded client_ip, # check this works for both devices which have a recorded client_ip,
# and those which don't. # and those which don't.
yield self._record_user(user1, "xyz", "display 0") self._record_user(user1, "xyz", "display 0")
yield self._record_user(user1, "fco", "display 1", "token1", "ip1") self._record_user(user1, "fco", "display 1", "token1", "ip1")
yield self._record_user(user1, "abc", "display 2", "token2", "ip2") self._record_user(user1, "abc", "display 2", "token2", "ip2")
yield self._record_user(user1, "abc", "display 2", "token3", "ip3") self._record_user(user1, "abc", "display 2", "token3", "ip3")
yield self._record_user(user2, "def", "dispkay", "token4", "ip4") self._record_user(user2, "def", "dispkay", "token4", "ip4")
self.reactor.advance(10000)
@defer.inlineCallbacks
def _record_user( def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None self, user_id, device_id, display_name, access_token=None, ip=None
): ):
device_id = yield self.handler.check_device_registered( device_id = self.get_success(
user_id=user_id, self.handler.check_device_registered(
device_id=device_id, user_id=user_id,
initial_device_display_name=display_name, device_id=device_id,
initial_device_display_name=display_name,
)
) )
if ip is not None: if ip is not None:
yield self.store.insert_client_ip( self.get_success(
user_id, access_token, ip, "user_agent", device_id self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
)
) )
self.clock.advance_time(1000) self.reactor.advance(1000)

View File

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -11,89 +12,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import tempfile
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor import attr
from twisted.internet.defer import Deferred
from synapse.replication.tcp.client import ( from synapse.replication.tcp.client import (
ReplicationClientFactory, ReplicationClientFactory,
ReplicationClientHandler, ReplicationClientHandler,
) )
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler): class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
"""Overrides on_rdata so that we can wait for it to happen""" def make_homeserver(self, reactor, clock):
def __init__(self, store): hs = self.setup_test_homeserver(
super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = []
def await_replication(self):
d = Deferred()
self._rdata_awaiters.append(d)
return make_deferred_yieldable(d)
def on_rdata(self, stream_name, token, rows):
awaiters = self._rdata_awaiters
self._rdata_awaiters = []
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
with PreserveLoggingContext():
for a in awaiters:
a.callback(None)
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
self.addCleanup,
"blue", "blue",
http_client=None,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
hs.get_ratelimiter().send_message.return_value = (True, 0)
return hs
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore() self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs) self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0 self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs) server_factory = ReplicationStreamProtocolFactory(self.hs)
# XXX: mktemp is unsafe and should never be used. but we're just a test.
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer self.streamer = server_factory.streamer
self.replication_handler = TestReplicationClientHandler(self.slaved_store) self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory( client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler self.hs, "client_name", self.replication_handler
) )
client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying) server = server_factory.buildProtocol(None)
self.addCleanup(client_connector.disconnect) client = client_factory.buildProtocol(None)
@attr.s
class FakeTransport(object):
other = attr.ib()
disconnecting = False
buffer = attr.ib(default=b'')
def registerProducer(self, producer, streaming):
self.producer = producer
def _produce():
self.producer.resumeProducing()
reactor.callLater(0.1, _produce)
reactor.callLater(0.0, _produce)
def write(self, byt):
self.buffer = self.buffer + byt
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
def writeSequence(self, seq):
for x in seq:
self.write(x)
client.makeConnection(FakeTransport(server))
server.makeConnection(FakeTransport(client))
def replicate(self): def replicate(self):
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
wait for the replication to occur. wait for the replication to occur.
""" """
# xxx: should we be more specific in what we wait for?
d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
return d self.pump(0.1)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None): def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args) master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = yield getattr(self.slaved_store, method)(*args) slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None: if expected_result is not None:
self.assertEqual(master_result, expected_result) self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result) self.assertEqual(slaved_result, expected_result)

View File

@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from ._base import BaseSlavedStoreTestCase from ._base import BaseSlavedStoreTestCase
@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore STORE_TYPE = SlavedAccountDataStore
@defer.inlineCallbacks
def test_user_account_data(self): def test_user_account_data(self):
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) self.get_success(
yield self.replicate() self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
yield self.check( )
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
) )
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) self.get_success(
yield self.replicate() self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
yield self.check( )
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
) )

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self): def tearDown(self):
[unpatch() for unpatch in self.unpatches] [unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) create = self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate() self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
join = yield self.persist( join = self.persist(
type="m.room.member", type="m.room.member",
key=USER_ID, key=USER_ID,
membership="join", membership="join",
prev_events=[(create.event_id, {})], prev_events=[(create.event_id, {})],
) )
yield self.replicate() self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
@defer.inlineCallbacks
def test_redactions(self): def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate() self.replicate()
yield self.check("get_event", [msg.event_id], msg) self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id) redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
yield self.replicate() self.replicate()
msg_dict = msg.get_dict() msg_dict = msg.get_dict()
msg_dict["content"] = {} msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted) self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self): def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello") msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate() self.replicate()
yield self.check("get_event", [msg.event_id], msg) self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist( redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True type="m.room.redaction", redacts=msg.event_id, backfill=True
) )
yield self.replicate() self.replicate()
msg_dict = msg.get_dict() msg_dict = msg.get_dict()
msg_dict["content"] = {} msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted) self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_invites(self): def test_invites(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.check("get_invited_rooms_for_user", [USER_ID_2], []) self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist( event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
type="m.room.member", key=USER_ID_2, membership="invite"
) self.replicate()
yield self.replicate()
yield self.check( self.check(
"get_invited_rooms_for_user", "get_invited_rooms_for_user",
[USER_ID_2], [USER_ID_2],
[ [
@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
], ],
) )
@defer.inlineCallbacks
def test_push_actions_for_user(self): def test_push_actions_for_user(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.join", key=USER_ID, membership="join") self.persist(type="m.room.join", key=USER_ID, membership="join")
yield self.persist( self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
) )
event1 = yield self.persist( event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
type="m.room.message", msgtype="m.text", body="hello" self.replicate()
) self.check(
yield self.replicate()
yield self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0}, {"highlight_count": 0, "notify_count": 0},
) )
yield self.persist( self.persist(
type="m.room.message", type="m.room.message",
msgtype="m.text", msgtype="m.text",
body="world", body="world",
push_actions=[(USER_ID_2, ["notify"])], push_actions=[(USER_ID_2, ["notify"])],
) )
yield self.replicate() self.replicate()
yield self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1}, {"highlight_count": 0, "notify_count": 1},
) )
yield self.persist( self.persist(
type="m.room.message", type="m.room.message",
msgtype="m.text", msgtype="m.text",
body="world", body="world",
@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}]) (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
], ],
) )
yield self.replicate() self.replicate()
yield self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2}, {"highlight_count": 1, "notify_count": 2},
@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0 event_id = 0
@defer.inlineCallbacks
def persist( def persist(
self, self,
sender=USER_ID, sender=USER_ID,
@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
depth = self.event_id depth = self.event_id
if not prev_events: if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room( latest_event_ids = self.get_success(
room_id self.master_store.get_latest_event_ids_in_room(room_id)
) )
prev_events = [(ev_id, {}) for ev_id in latest_event_ids] prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
) )
else: else:
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event) context = self.get_success(state_handler.compute_event_context(event))
yield self.master_store.add_push_actions_to_staging( self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions} event.event_id, {user_id: actions for user_id, actions in push_actions}
) )
ordering = None ordering = None
if backfill: if backfill:
yield self.master_store.persist_events([(event, context)], backfilled=True) self.get_success(
self.master_store.persist_events([(event, context)], backfilled=True)
)
else: else:
ordering, _ = yield self.master_store.persist_event(event, context) ordering, _ = self.get_success(
self.master_store.persist_event(event, context)
)
if ordering: if ordering:
event.internal_metadata.stream_ordering = ordering event.internal_metadata.stream_ordering = ordering
defer.returnValue(event) return event

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase from ._base import BaseSlavedStoreTestCase
@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore STORE_TYPE = SlavedReceiptsStore
@defer.inlineCallbacks
def test_receipt(self): def test_receipt(self):
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
yield self.master_store.insert_receipt( self.get_success(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {} self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
)
yield self.replicate()
yield self.check(
"get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
) )
self.replicate()
self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})

View File

@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
clock.threadpool = ThreadPool() clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool() pool.threadpool = ThreadPool()
pool.running = True
return d return d

View File

@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.as_yaml_files = [] self.as_yaml_files = []
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup, federation_sender=Mock(), federation_client=Mock()
config=config,
federation_sender=Mock(),
federation_client=Mock(),
) )
hs.config.app_service_config_files = self.as_yaml_files
hs.config.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
self.as_id = "as1" self.as_id = "as1"
@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts # must be done after inserts
self.store = ApplicationServiceStore(None, hs) self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
def tearDown(self): def tearDown(self):
# TODO: suboptimal that we need to create files for tests! # TODO: suboptimal that we need to create files for tests!
@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.as_yaml_files = [] self.as_yaml_files = []
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup, federation_sender=Mock(), federation_client=Mock()
config=config,
federation_sender=Mock(),
federation_client=Mock(),
) )
hs.config.app_service_config_files = self.as_yaml_files
hs.config.event_cache_size = 1
hs.config.password_providers = []
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.engine = hs.database_engine
self.as_list = [ self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = [] self.as_yaml_files = []
self.store = TestTransactionStore(None, hs) self.store = TestTransactionStore(hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id): def _add_service(self, url, as_token, id):
as_yaml = dict( as_yaml = dict(
@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files.append(as_token) self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None): def _set_state(self, id, state, txn=None):
return self.db_pool.runQuery( return self.db_pool.runOperation(
"INSERT INTO application_services_state(as_id, state, last_txn) " self.engine.convert_param_style(
"VALUES(?,?,?)", "INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)"
),
(id, state, txn), (id, state, txn),
) )
def _insert_txn(self, as_id, txn_id, events): def _insert_txn(self, as_id, txn_id, events):
return self.db_pool.runQuery( return self.db_pool.runOperation(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) " self.engine.convert_param_style(
"VALUES(?,?,?)", "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)"
),
(as_id, txn_id, json.dumps([e.event_id for e in events])), (as_id, txn_id, json.dumps([e.event_id for e in events])),
) )
def _set_last_txn(self, as_id, txn_id): def _set_last_txn(self, as_id, txn_id):
return self.db_pool.runQuery( return self.db_pool.runOperation(
"INSERT INTO application_services_state(as_id, last_txn, state) " self.engine.convert_param_style(
"VALUES(?,?,?)", "INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)"
),
(as_id, txn_id, ApplicationServiceState.UP), (as_id, txn_id, ApplicationServiceState.UP),
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservice_state_none(self): def test_get_appservice_state_none(self):
service = Mock(id=999) service = Mock(id="999")
state = yield self.store.get_appservice_state(service) state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state) self.assertEquals(None, state)
@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?", self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
),
(ApplicationServiceState.DOWN,), (ApplicationServiceState.DOWN,),
) )
self.assertEquals(service.id, rows[0][0]) self.assertEquals(service.id, rows[0][0])
@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP) yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery( rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?", self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
),
(ApplicationServiceState.UP,), (ApplicationServiceState.UP,),
) )
self.assertEquals(service.id, rows[0][0]) self.assertEquals(service.id, rows[0][0])
@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT last_txn FROM application_services_state WHERE as_id=?", self.engine.convert_param_style(
"SELECT last_txn FROM application_services_state WHERE as_id=?"
),
(service.id,), (service.id,),
) )
self.assertEquals(1, len(res)) self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0]) self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
) )
self.assertEquals(0, len(res)) self.assertEquals(0, len(res))
@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service) yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT last_txn, state FROM application_services_state WHERE " "as_id=?", self.engine.convert_param_style(
"SELECT last_txn, state FROM application_services_state WHERE as_id=?"
),
(service.id,), (service.id,),
) )
self.assertEquals(1, len(res)) self.assertEquals(1, len(res))
@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(ApplicationServiceState.UP, res[0][1]) self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery( res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,) self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
) )
self.assertEquals(0, len(res)) self.assertEquals(0, len(res))
@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1") f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2") f2 = self._write_config(suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup, federation_sender=Mock(), federation_client=Mock()
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
) )
ApplicationServiceStore(None, hs) hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
ApplicationServiceStore(hs.get_db_conn(), hs)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_duplicate_ids(self): def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1") f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2") f2 = self._write_config(id="id", suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup, federation_sender=Mock(), federation_client=Mock()
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
) )
hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(None, hs) ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception e = cm.exception
self.assertIn(f1, str(e)) self.assertIn(f1, str(e))
@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1") f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2") f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup, federation_sender=Mock(), federation_client=Mock()
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
) )
hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(None, hs) ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception e = cm.exception
self.assertIn(f1, str(e)) self.assertIn(f1, str(e))

View File

@ -16,7 +16,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID from synapse.types import RoomAlias, RoomID
from tests import unittest from tests import unittest
@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup) hs = yield setup_test_homeserver(self.addCleanup)
self.store = DirectoryStore(None, hs) self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test") self.alias = RoomAlias.from_string("#my-room:test")

View File

@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
( (
"INSERT INTO events (" "INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering," " room_id, event_id, type, depth, topological_ordering,"
" content, processed, outlier) " " content, processed, outlier, stream_ordering) "
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)" "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
), ),
(room_id, event_id, i, i, True, False), (room_id, event_id, i, i, True, False, i),
) )
txn.execute( txn.execute(

View File

@ -13,25 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from tests.unittest import HomeserverTestCase
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
FORTY_DAYS = 40 * 24 * 60 * 60 FORTY_DAYS = 40 * 24 * 60 * 60
class MonthlyActiveUsersTestCase(tests.unittest.TestCase): class MonthlyActiveUsersTestCase(HomeserverTestCase):
def __init__(self, *args, **kwargs): def make_homeserver(self, reactor, clock):
super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
@defer.inlineCallbacks hs = self.setup_test_homeserver()
def setUp(self): self.store = hs.get_datastore()
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore() # Advance the clock a bit
reactor.advance(FORTY_DAYS)
return hs
@defer.inlineCallbacks
def test_initialise_reserved_users(self): def test_initialise_reserved_users(self):
self.hs.config.max_mau_value = 5 self.hs.config.max_mau_value = 5
user1 = "@user1:server" user1 = "@user1:server"
@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
] ]
user_num = len(threepids) user_num = len(threepids)
yield self.store.register(user_id=user1, token="123", password_hash=None) self.store.register(user_id=user1, token="123", password_hash=None)
self.store.register(user_id=user2, token="456", password_hash=None)
yield self.store.register(user_id=user2, token="456", password_hash=None) self.pump()
now = int(self.hs.get_clock().time_msec()) now = int(self.hs.get_clock().time_msec())
yield self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user1, "email", user1_email, now, now)
yield self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now)
yield self.store.initialise_reserved_users(threepids) self.store.initialise_reserved_users(threepids)
self.pump()
active_count = yield self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
# Test total counts # Test total counts
self.assertEquals(active_count, user_num) self.assertEquals(self.get_success(active_count), user_num)
# Test user is marked as active # Test user is marked as active
timestamp = self.store.user_last_seen_monthly_active(user1)
timestamp = yield self.store.user_last_seen_monthly_active(user1) self.assertTrue(self.get_success(timestamp))
self.assertTrue(timestamp) timestamp = self.store.user_last_seen_monthly_active(user2)
timestamp = yield self.store.user_last_seen_monthly_active(user2) self.assertTrue(self.get_success(timestamp))
self.assertTrue(timestamp)
# Test that users are never removed from the db. # Test that users are never removed from the db.
self.hs.config.max_mau_value = 0 self.hs.config.max_mau_value = 0
self.hs.get_clock().advance_time(FORTY_DAYS) self.reactor.advance(FORTY_DAYS)
yield self.store.reap_monthly_active_users() self.store.reap_monthly_active_users()
self.pump()
active_count = yield self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num) self.assertEquals(self.get_success(active_count), user_num)
# Test that regalar users are removed from the db # Test that regalar users are removed from the db
ru_count = 2 ru_count = 2
yield self.store.upsert_monthly_active_user("@ru1:server") self.store.upsert_monthly_active_user("@ru1:server")
yield self.store.upsert_monthly_active_user("@ru2:server") self.store.upsert_monthly_active_user("@ru2:server")
active_count = yield self.store.get_monthly_active_count() self.pump()
self.assertEqual(active_count, user_num + ru_count) active_count = self.store.get_monthly_active_count()
self.assertEqual(self.get_success(active_count), user_num + ru_count)
self.hs.config.max_mau_value = user_num self.hs.config.max_mau_value = user_num
yield self.store.reap_monthly_active_users() self.store.reap_monthly_active_users()
self.pump()
active_count = yield self.store.get_monthly_active_count() active_count = self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num) self.assertEquals(self.get_success(active_count), user_num)
@defer.inlineCallbacks
def test_can_insert_and_count_mau(self): def test_can_insert_and_count_mau(self):
count = yield self.store.get_monthly_active_count() count = self.store.get_monthly_active_count()
self.assertEqual(0, count) self.assertEqual(0, self.get_success(count))
yield self.store.upsert_monthly_active_user("@user:server") self.store.upsert_monthly_active_user("@user:server")
count = yield self.store.get_monthly_active_count() self.pump()
self.assertEqual(1, count) count = self.store.get_monthly_active_count()
self.assertEqual(1, self.get_success(count))
@defer.inlineCallbacks
def test_user_last_seen_monthly_active(self): def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server" user_id1 = "@user1:server"
user_id2 = "@user2:server" user_id2 = "@user2:server"
user_id3 = "@user3:server" user_id3 = "@user3:server"
result = yield self.store.user_last_seen_monthly_active(user_id1) result = self.store.user_last_seen_monthly_active(user_id1)
self.assertFalse(result == 0) self.assertFalse(self.get_success(result) == 0)
yield self.store.upsert_monthly_active_user(user_id1)
yield self.store.upsert_monthly_active_user(user_id2) self.store.upsert_monthly_active_user(user_id1)
result = yield self.store.user_last_seen_monthly_active(user_id1) self.store.upsert_monthly_active_user(user_id2)
self.assertTrue(result > 0) self.pump()
result = yield self.store.user_last_seen_monthly_active(user_id3)
self.assertFalse(result == 0) result = self.store.user_last_seen_monthly_active(user_id1)
self.assertGreater(self.get_success(result), 0)
result = self.store.user_last_seen_monthly_active(user_id3)
self.assertNotEqual(self.get_success(result), 0)
@defer.inlineCallbacks
def test_reap_monthly_active_users(self): def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5 self.hs.config.max_mau_value = 5
initial_users = 10 initial_users = 10
for i in range(initial_users): for i in range(initial_users):
yield self.store.upsert_monthly_active_user("@user%d:server" % i) self.store.upsert_monthly_active_user("@user%d:server" % i)
count = yield self.store.get_monthly_active_count() self.pump()
self.assertTrue(count, initial_users)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
self.hs.get_clock().advance_time(FORTY_DAYS) count = self.store.get_monthly_active_count()
yield self.store.reap_monthly_active_users() self.assertTrue(self.get_success(count), initial_users)
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, 0) self.store.reap_monthly_active_users()
self.pump()
count = self.store.get_monthly_active_count()
self.assertEquals(
self.get_success(count), initial_users - self.hs.config.max_mau_value
)
self.reactor.advance(FORTY_DAYS)
self.store.reap_monthly_active_users()
self.pump()
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)

View File

@ -16,19 +16,18 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.presence import PresenceStore
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.utils import MockClock, setup_test_homeserver from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase): class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock()) hs = yield setup_test_homeserver(self.addCleanup)
self.store = PresenceStore(None, hs) self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test") self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test") self.u_banana = UserID.from_string("@banana:test")

View File

@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup) hs = yield setup_test_homeserver(self.addCleanup)
self.store = ProfileStore(None, hs) self.store = ProfileStore(hs.get_db_conn(), hs)
self.u_frank = UserID.from_string("@frank:test") self.u_frank = UserID.from_string("@frank:test")

View File

@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup) self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = UserDirectoryStore(None, self.hs) self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
# alice and bob are both in !room_id. bobby is not but shares # alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice. # a homeserver with alice.

View File

@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt) events_to_filter.append(evt)
# the erasey user gets erased # the erasey user gets erased
self.hs.get_datastore().mark_user_erased("@erased:local_hs") yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens. # ... and the filtering happens.
filtered = yield filter_events_for_server( filtered = yield filter_events_for_server(

View File

@ -22,6 +22,7 @@ from canonicaljson import json
import twisted import twisted
import twisted.logger import twisted.logger
from twisted.internet.defer import Deferred
from twisted.trial import unittest from twisted.trial import unittest
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -281,12 +282,14 @@ class HomeserverTestCase(TestCase):
kwargs.update(self._hs_args) kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs) return setup_test_homeserver(self.addCleanup, *args, **kwargs)
def pump(self): def pump(self, by=0.0):
""" """
Pump the reactor enough that Deferreds will fire. Pump the reactor enough that Deferreds will fire.
""" """
self.reactor.pump([0.0] * 100) self.reactor.pump([by] * 100)
def get_success(self, d): def get_success(self, d):
if not isinstance(d, Deferred):
return d
self.pump() self.pump()
return self.successResultOf(d) return self.successResultOf(d)

View File

@ -30,8 +30,8 @@ from synapse.config.server import ServerConfig
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore, PostgresEngine from synapse.storage import DataStore
from synapse.storage.engines import create_engine from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import ( from synapse.storage.prepare_database import (
_get_or_create_schema_state, _get_or_create_schema_state,
_setup_new_database, _setup_new_database,
@ -42,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite. # set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False) USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres") POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
@ -244,8 +245,9 @@ def setup_test_homeserver(
cur.close() cur.close()
db_conn.close() db_conn.close()
# Register the cleanup hook if not LEAVE_DB:
cleanup_func(cleanup) # Register the cleanup hook
cleanup_func(cleanup)
hs.setup() hs.setup()
else: else: