Fix tests on postgresql (#3740)

This commit is contained in:
Amber Brown 2018-09-04 02:21:48 +10:00 committed by GitHub
parent 567363e497
commit 77055dba92
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 355 additions and 340 deletions

View file

@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,89 +12,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
import attr
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler):
"""Overrides on_rdata so that we can wait for it to happen"""
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def __init__(self, store):
super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = []
def await_replication(self):
d = Deferred()
self._rdata_awaiters.append(d)
return make_deferred_yieldable(d)
def on_rdata(self, stream_name, token, rows):
awaiters = self._rdata_awaiters
self._rdata_awaiters = []
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
with PreserveLoggingContext():
for a in awaiters:
a.callback(None)
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
self.addCleanup,
hs = self.setup_test_homeserver(
"blue",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
hs.get_ratelimiter().send_message.return_value = (True, 0)
return hs
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
# XXX: mktemp is unsafe and should never be used. but we're just a test.
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
@attr.s
class FakeTransport(object):
other = attr.ib()
disconnecting = False
buffer = attr.ib(default=b'')
def registerProducer(self, producer, streaming):
self.producer = producer
def _produce():
self.producer.resumeProducing()
reactor.callLater(0.1, _produce)
reactor.callLater(0.0, _produce)
def write(self, byt):
self.buffer = self.buffer + byt
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
def writeSequence(self, seq):
for x in seq:
self.write(x)
client.makeConnection(FakeTransport(server))
server.makeConnection(FakeTransport(client))
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
# xxx: should we be more specific in what we wait for?
d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
return d
self.pump(0.1)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)