Pass around the reactor explicitly (#3385)

This commit is contained in:
Amber Brown 2018-06-22 09:37:10 +01:00 committed by GitHub
parent c2eff937ac
commit 77ac14b960
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 141 additions and 93 deletions

View File

@ -33,6 +33,7 @@ import logging
import bcrypt import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
import attr
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -854,7 +855,11 @@ class AuthHandler(BaseHandler):
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds)) bcrypt.gensalt(self.bcrypt_rounds))
return make_deferred_yieldable(threads.deferToThread(_do_hash)) return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
),
)
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
@ -874,16 +879,21 @@ class AuthHandler(BaseHandler):
) )
if stored_hash: if stored_hash:
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash)) return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(),
self.hs.get_reactor().getThreadPool(),
_do_validate_hash,
),
)
else: else:
return defer.succeed(False) return defer.succeed(False)
class MacaroonGeneartor(object): @attr.s
def __init__(self, hs): class MacaroonGenerator(object):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name hs = attr.ib()
self.macaroon_secret_key = hs.config.macaroon_secret_key
def generate_access_token(self, user_id, extra_caveats=None): def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
@ -901,7 +911,7 @@ class MacaroonGeneartor(object):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")
now = self.clock.time_msec() now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize() return macaroon.serialize()
@ -913,9 +923,9 @@ class MacaroonGeneartor(object):
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",
key=self.macaroon_secret_key) key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon

View File

@ -806,6 +806,7 @@ class EventCreationHandler(object):
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker_app:
yield send_event_to_master( yield send_event_to_master(
self.hs.get_clock(),
self.http_client, self.http_client,
host=self.config.worker_replication_host, host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port, port=self.config.worker_replication_http_port,

View File

@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id from synapse.types import get_localpart_from_id
from six import iteritems from six import iteritems
@ -174,7 +173,7 @@ class UserDirectoryHandler(object):
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids)) logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id) yield self._handle_initial_room(room_id)
num_processed_rooms += 1 num_processed_rooms += 1
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.") logger.info("Processed all rooms.")
@ -188,7 +187,7 @@ class UserDirectoryHandler(object):
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids)) logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id) yield self._handle_local_user(user_id)
num_processed_users += 1 num_processed_users += 1
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.) yield self.clock.sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users") logger.info("Processed all users")
@ -236,7 +235,7 @@ class UserDirectoryHandler(object):
count = 0 count = 0
for user_id in user_ids: for user_id in user_ids:
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
count += 1 count += 1
@ -251,7 +250,7 @@ class UserDirectoryHandler(object):
continue continue
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.) yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1 count += 1
user_set = (user_id, other_user_id) user_set = (user_id, other_user_id)

View File

@ -98,8 +98,8 @@ class SimpleHttpClient(object):
method, uri, *args, **kwargs method, uri, *args, **kwargs
) )
add_timeout_to_deferred( add_timeout_to_deferred(
request_deferred, request_deferred, 60, self.hs.get_reactor(),
60, cancelled_to_request_timed_out_error, cancelled_to_request_timed_out_error,
) )
response = yield make_deferred_yieldable(request_deferred) response = yield make_deferred_yieldable(request_deferred)
@ -115,7 +115,7 @@ class SimpleHttpClient(object):
"Error sending request to %s %s: %s %s", "Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message method, redact_uri(uri), type(e).__name__, e.message
) )
raise e raise
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None): def post_urlencoded_get_json(self, uri, args={}, headers=None):

View File

@ -22,7 +22,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http import cancelled_to_request_timed_out_error from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
import synapse.metrics import synapse.metrics
from synapse.util.async import sleep, add_timeout_to_deferred from synapse.util.async import add_timeout_to_deferred
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
import synapse.util.retryutils import synapse.util.retryutils
@ -193,6 +193,7 @@ class MatrixFederationHttpClient(object):
add_timeout_to_deferred( add_timeout_to_deferred(
request_deferred, request_deferred,
timeout / 1000. if timeout else 60, timeout / 1000. if timeout else 60,
self.hs.get_reactor(),
cancelled_to_request_timed_out_error, cancelled_to_request_timed_out_error,
) )
response = yield make_deferred_yieldable( response = yield make_deferred_yieldable(
@ -234,7 +235,7 @@ class MatrixFederationHttpClient(object):
delay = min(delay, 2) delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4) delay *= random.uniform(0.8, 1.4)
yield sleep(delay) yield self.clock.sleep(delay)
retries_left -= 1 retries_left -= 1
else: else:
raise raise

View File

@ -161,6 +161,7 @@ class Notifier(object):
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.hs = hs
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] self.pending_new_room_events = []
@ -340,6 +341,7 @@ class Notifier(object):
add_timeout_to_deferred( add_timeout_to_deferred(
listener.deferred, listener.deferred,
(end_time - now) / 1000., (end_time - now) / 1000.,
self.hs.get_reactor(),
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
yield listener.deferred yield listener.deferred
@ -561,6 +563,7 @@ class Notifier(object):
add_timeout_to_deferred( add_timeout_to_deferred(
listener.deferred.addTimeout, listener.deferred.addTimeout,
(end_time - now) / 1000., (end_time - now) / 1000.,
self.hs.get_reactor(),
) )
try: try:
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -21,7 +21,6 @@ from synapse.api.errors import (
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
@ -33,11 +32,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event_to_master(client, host, port, requester, event, context, def send_event_to_master(clock, client, host, port, requester, event, context,
ratelimit, extra_users): ratelimit, extra_users):
"""Send event to be handled on the master """Send event to be handled on the master
Args: Args:
clock (synapse.util.Clock)
client (SimpleHttpClient) client (SimpleHttpClient)
host (str): host of master host (str): host of master
port (int): port on master listening for HTTP replication port (int): port on master listening for HTTP replication
@ -77,7 +77,7 @@ def send_event_to_master(client, host, port, requester, event, context,
# If we timed out we probably don't need to worry about backing # If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway. # off too much, but lets just wait a little anyway.
yield sleep(1) yield clock.sleep(1)
except MatrixCodeMessageException as e: except MatrixCodeMessageException as e:
# We convert to SynapseError as we know that it was a SynapseError # We convert to SynapseError as we know that it was a SynapseError
# on the master process that we should send to the client. (And # on the master process that we should send to the client. (And

View File

@ -58,6 +58,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object): class MediaRepository(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs) self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -94,7 +95,7 @@ class MediaRepository(object):
storage_providers.append(provider) storage_providers.append(provider)
self.media_storage = MediaStorage( self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers, self.hs, self.primary_base_path, self.filepaths, storage_providers,
) )
self.clock.looping_call( self.clock.looping_call(

View File

@ -37,13 +37,15 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources. """Responsible for storing/fetching files from local sources.
Args: Args:
hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths) filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files. used to fetch and store files.
""" """
def __init__(self, local_media_directory, filepaths, storage_providers): def __init__(self, hs, local_media_directory, filepaths, storage_providers):
self.hs = hs
self.local_media_directory = local_media_directory self.local_media_directory = local_media_directory
self.filepaths = filepaths self.filepaths = filepaths
self.storage_providers = storage_providers self.storage_providers = storage_providers
@ -175,7 +177,8 @@ class MediaStorage(object):
res = yield provider.fetch(path, file_info) res = yield provider.fetch(path, file_info)
if res: if res:
with res: with res:
consumer = BackgroundFileConsumer(open(local_path, "w")) consumer = BackgroundFileConsumer(
open(local_path, "w"), self.hs.get_reactor())
yield res.write_to_consumer(consumer) yield res.write_to_consumer(consumer)
yield consumer.wait() yield consumer.wait()
defer.returnValue(local_path) defer.returnValue(local_path)

View File

@ -40,7 +40,7 @@ from synapse.federation.transport.client import TransportLayerClient
from synapse.federation.transaction_queue import TransactionQueue from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
@ -165,15 +165,19 @@ class HomeServer(object):
'server_notices_sender', 'server_notices_sender',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, reactor=None, **kwargs):
""" """
Args: Args:
hostname : The hostname for the server. hostname : The hostname for the server.
""" """
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname self.hostname = hostname
self._building = {} self._building = {}
self.clock = Clock() self.clock = Clock(reactor)
self.distributor = Distributor() self.distributor = Distributor()
self.ratelimiter = Ratelimiter() self.ratelimiter = Ratelimiter()
@ -186,6 +190,12 @@ class HomeServer(object):
self.datastore = DataStore(self.get_db_conn(), self) self.datastore = DataStore(self.get_db_conn(), self)
logger.info("Finished setting up.") logger.info("Finished setting up.")
def get_reactor(self):
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
return self._reactor
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# X-Forwarded-For is handled by our custom request type. # X-Forwarded-For is handled by our custom request type.
return request.getClientIP() return request.getClientIP()
@ -261,7 +271,7 @@ class HomeServer(object):
return AuthHandler(self) return AuthHandler(self)
def build_macaroon_generator(self): def build_macaroon_generator(self):
return MacaroonGeneartor(self) return MacaroonGenerator(self)
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)
@ -328,6 +338,7 @@ class HomeServer(object):
return adbapi.ConnectionPool( return adbapi.ConnectionPool(
name, name,
cp_reactor=self.get_reactor(),
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.util.async
from ._base import SQLBaseStore from ._base import SQLBaseStore
from . import engines from . import engines
@ -92,7 +91,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Starting background schema updates") logger.info("Starting background schema updates")
while True: while True:
yield synapse.util.async.sleep( yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
try: try:

View File

@ -15,7 +15,7 @@
import logging import logging
from twisted.internet import defer, reactor from twisted.internet import defer
from ._base import Cache from ._base import Cache
from . import background_updates from . import background_updates
@ -70,7 +70,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._client_ip_looper = self._clock.looping_call( self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000 self._update_client_ips_batch, 5 * 1000
) )
reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None): now=None):

View File

@ -16,7 +16,6 @@
from synapse.storage._base import SQLBaseStore, LoggingTransaction from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging import logging
@ -800,7 +799,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
) )
if caught_up: if caught_up:
break break
yield sleep(5) yield self.hs.get_clock().sleep(5)
finally: finally:
self._doing_notif_rotation = False self._doing_notif_rotation = False

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from twisted.internet import defer, reactor from twisted.internet import defer
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
@ -265,7 +265,7 @@ class EventsWorkerStore(SQLBaseStore):
except Exception: except Exception:
logger.exception("Failed to callback") logger.exception("Failed to callback")
with PreserveLoggingContext(): with PreserveLoggingContext():
reactor.callFromThread(fire, event_list, row_dict) self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e: except Exception as e:
logger.exception("do_fetch") logger.exception("do_fetch")
@ -278,7 +278,7 @@ class EventsWorkerStore(SQLBaseStore):
if event_list: if event_list:
with PreserveLoggingContext(): with PreserveLoggingContext():
reactor.callFromThread(fire, event_list) self.hs.get_reactor().callFromThread(fire, event_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):

View File

@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
import time
import logging import logging
from itertools import islice from itertools import islice
import attr
from twisted.internet import defer, task
from synapse.util.logcontext import PreserveLoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,16 +30,24 @@ def unwrapFirstError(failure):
return failure.value.subFailure return failure.value.subFailure
@attr.s
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be
mocked during unit-tests.
TODO(paul): Also move the sleep() functionality into it
""" """
A Clock wraps a Twisted reactor and provides utilities on top of it.
"""
_reactor = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
with PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
def time(self): def time(self):
"""Returns the current system time in seconds since epoch.""" """Returns the current system time in seconds since epoch."""
return time.time() return self._reactor.seconds()
def time_msec(self): def time_msec(self):
"""Returns the current system time in miliseconds since epoch.""" """Returns the current system time in miliseconds since epoch."""
@ -56,6 +63,7 @@ class Clock(object):
msec(float): How long to wait between calls in milliseconds. msec(float): How long to wait between calls in milliseconds.
""" """
call = task.LoopingCall(f) call = task.LoopingCall(f)
call.clock = self._reactor
call.start(msec / 1000.0, now=False) call.start(msec / 1000.0, now=False)
return call return call
@ -73,7 +81,7 @@ class Clock(object):
callback(*args, **kwargs) callback(*args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs) return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False): def cancel_call_later(self, timer, ignore_errs=False):
try: try:

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.python import failure from twisted.python import failure
from .logcontext import ( from .logcontext import (
PreserveLoggingContext, make_deferred_yieldable, run_in_background PreserveLoggingContext, make_deferred_yieldable, run_in_background
) )
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError, Clock
from contextlib import contextmanager from contextlib import contextmanager
@ -31,15 +31,6 @@ from six.moves import range
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def sleep(seconds):
d = defer.Deferred()
with PreserveLoggingContext():
reactor.callLater(seconds, d.callback, seconds)
res = yield d
defer.returnValue(res)
class ObservableDeferred(object): class ObservableDeferred(object):
"""Wraps a deferred object so that we can add observer deferreds. These """Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original observer deferreds do not affect the callback chain of the original
@ -172,13 +163,18 @@ class Linearizer(object):
# do some work. # do some work.
""" """
def __init__(self, name=None): def __init__(self, name=None, clock=None):
if name is None: if name is None:
self.name = id(self) self.name = id(self)
else: else:
self.name = name self.name = name
self.key_to_defer = {} self.key_to_defer = {}
if not clock:
from twisted.internet import reactor
clock = Clock(reactor)
self._clock = clock
@defer.inlineCallbacks @defer.inlineCallbacks
def queue(self, key): def queue(self, key):
# If there is already a deferred in the queue, we pull it out so that # If there is already a deferred in the queue, we pull it out so that
@ -219,7 +215,7 @@ class Linearizer(object):
# the context manager, but it needs to happen while we hold the # the context manager, but it needs to happen while we hold the
# lock, and the context manager's exit code must be synchronous, # lock, and the context manager's exit code must be synchronous,
# so actually this is the only sensible place. # so actually this is the only sensible place.
yield sleep(0) yield self._clock.sleep(0)
else: else:
logger.info("Acquired uncontended linearizer lock %r for key %r", logger.info("Acquired uncontended linearizer lock %r for key %r",
@ -396,7 +392,7 @@ class DeferredTimeoutError(Exception):
""" """
def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None): def add_timeout_to_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
""" """
Add a timeout to a deferred by scheduling it to be cancelled after Add a timeout to a deferred by scheduling it to be cancelled after
timeout seconds. timeout seconds.
@ -411,6 +407,7 @@ def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
Args: Args:
deferred (defer.Deferred): deferred to be timed out deferred (defer.Deferred): deferred to be timed out
timeout (Number): seconds to time out after timeout (Number): seconds to time out after
reactor (twisted.internet.reactor): the Twisted reactor to use
on_timeout_cancel (callable): A callable which is called immediately on_timeout_cancel (callable): A callable which is called immediately
after the deferred times out, and not if this deferred is after the deferred times out, and not if this deferred is

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import threads, reactor from twisted.internet import threads
from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@ -27,6 +27,7 @@ class BackgroundFileConsumer(object):
Args: Args:
file_obj (file): The file like object to write to. Closed when file_obj (file): The file like object to write to. Closed when
finished. finished.
reactor (twisted.internet.reactor): the Twisted reactor to use
""" """
# For PushProducers pause if we have this many unwritten slices # For PushProducers pause if we have this many unwritten slices
@ -34,9 +35,11 @@ class BackgroundFileConsumer(object):
# And resume once the size of the queue is less than this # And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2 _RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj): def __init__(self, file_obj, reactor):
self._file_obj = file_obj self._file_obj = file_obj
self._reactor = reactor
# Producer we're registered with # Producer we're registered with
self._producer = None self._producer = None
@ -71,7 +74,10 @@ class BackgroundFileConsumer(object):
self._producer = producer self._producer = producer
self.streaming = streaming self.streaming = streaming
self._finished_deferred = run_in_background( self._finished_deferred = run_in_background(
threads.deferToThread, self._writer threads.deferToThreadPool,
self._reactor,
self._reactor.getThreadPool(),
self._writer,
) )
if not streaming: if not streaming:
self._producer.resumeProducing() self._producer.resumeProducing()
@ -109,7 +115,7 @@ class BackgroundFileConsumer(object):
# producer. # producer.
if self._producer and self._paused_producer: if self._producer and self._paused_producer:
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
reactor.callFromThread(self._resume_paused_producer) self._reactor.callFromThread(self._resume_paused_producer)
bytes = self._bytes_queue.get() bytes = self._bytes_queue.get()
@ -121,7 +127,7 @@ class BackgroundFileConsumer(object):
# If its a pull producer then we need to explicitly ask for # If its a pull producer then we need to explicitly ask for
# more stuff. # more stuff.
if not self.streaming and self._producer: if not self.streaming and self._producer:
reactor.callFromThread(self._producer.resumeProducing) self._reactor.callFromThread(self._producer.resumeProducing)
except Exception as e: except Exception as e:
self._write_exception = e self._write_exception = e
raise raise

View File

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
from synapse.util.logcontext import ( from synapse.util.logcontext import (
run_in_background, make_deferred_yieldable, run_in_background, make_deferred_yieldable,
PreserveLoggingContext, PreserveLoggingContext,
@ -153,7 +152,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req", "Ratelimit [%s]: sleeping req",
id(request_id), id(request_id),
) )
ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0) ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id) self.sleeping_requests.add(request_id)

View File

@ -19,10 +19,10 @@ import signedjson.sign
from mock import Mock from mock import Mock
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.util import async, logcontext from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest, utils from tests import unittest, utils
from twisted.internet import defer from twisted.internet import defer, reactor
class MockPerspectiveServer(object): class MockPerspectiveServer(object):
@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self): def test_verify_json_objects_for_server_awaits_previous_requests(self):
clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server # wait a tick for it to send the request to the perspectives server
# (it first tries the datastore) # (it first tries the datastore)
yield async.sleep(1) # XXX find out why this takes so long! yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once() self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11) self.assertIs(LoggingContext.current_context(), context_11)
@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)], [("server10", json1)],
) )
yield async.sleep(1) yield clock.sleep(1)
self.http_client.post_json.assert_not_called() self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None) res_deferreds_2[0].addBoth(self.check_context, None)

View File

@ -1,9 +1,9 @@
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
from twisted.internet import defer from twisted.internet import defer, reactor
from mock import Mock, call from mock import Mock, call
from synapse.util import async from synapse.util import Clock
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from tests import unittest from tests import unittest
from tests.utils import MockClock from tests.utils import MockClock
@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_logcontexts_with_async_result(self): def test_logcontexts_with_async_result(self):
@defer.inlineCallbacks @defer.inlineCallbacks
def cb(): def cb():
yield async.sleep(0) yield Clock(reactor).sleep(0)
defer.returnValue("yay") defer.returnValue("yay")
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage
@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
self.secondary_base_path = os.path.join(self.test_dir, "secondary") self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock() hs = Mock()
hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend( storage_providers = [FileStorageProviderBackend(
@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage( self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers, hs, self.primary_base_path, self.filepaths, storage_providers,
) )
def tearDown(self): def tearDown(self):

View File

@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_pull_consumer(self): def test_pull_consumer(self):
string_file = StringIO() string_file = StringIO()
consumer = BackgroundFileConsumer(string_file) consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try: try:
producer = DummyPullProducer() producer = DummyPullProducer()
@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_consumer(self): def test_push_consumer(self):
string_file = BlockingStringWrite() string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file) consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try: try:
producer = NonCallableMock(spec_set=[]) producer = NonCallableMock(spec_set=[])
@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_producer_feedback(self): def test_push_producer_feedback(self):
string_file = BlockingStringWrite() string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file) consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try: try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])

View File

@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util import async, logcontext
from synapse.util import logcontext, Clock
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from six.moves import range from six.moves import range
@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual( self.assertEqual(
logcontext.LoggingContext.current_context(), lc) logcontext.LoggingContext.current_context(), lc)
if sleep: if sleep:
yield async.sleep(0) yield Clock(reactor).sleep(0)
self.assertEqual( self.assertEqual(
logcontext.LoggingContext.current_context(), lc) logcontext.LoggingContext.current_context(), lc)

View File

@ -3,8 +3,7 @@ from twisted.internet import defer
from twisted.internet import reactor from twisted.internet import reactor
from .. import unittest from .. import unittest
from synapse.util.async import sleep from synapse.util import logcontext, Clock
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_sleep(self): def test_sleep(self):
clock = Clock(reactor)
@defer.inlineCallbacks @defer.inlineCallbacks
def competing_callback(): def competing_callback():
with LoggingContext() as competing_context: with LoggingContext() as competing_context:
competing_context.request = "competing" competing_context.request = "competing"
yield sleep(0) yield clock.sleep(0)
self._check_test_key("competing") self._check_test_key("competing")
reactor.callLater(0, competing_callback) reactor.callLater(0, competing_callback)
with LoggingContext() as context_one: with LoggingContext() as context_one:
context_one.request = "one" context_one.request = "one"
yield sleep(0) yield clock.sleep(0)
self._check_test_key("one") self._check_test_key("one")
def _test_run_in_background(self, function): def _test_run_in_background(self, function):
@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self): def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks @defer.inlineCallbacks
def blocking_function(): def blocking_function():
yield sleep(0) yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function) return self._test_run_in_background(blocking_function)

View File

@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks @defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
**kargs):
"""Setup a homeserver suitable for running tests against. Keyword arguments """Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS. datastore backed by an in-memory sqlite db will be given to the HS.
""" """
if reactor is None:
from twisted.internet import reactor
if config is None: if config is None:
config = Mock() config = Mock()
config.signing_key = [MockKey()] config.signing_key = [MockKey()]
@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
database_engine=db_engine, database_engine=db_engine,
room_list_handler=object(), room_list_handler=object(),
tls_server_context_factory=Mock(), tls_server_context_factory=Mock(),
reactor=reactor,
**kargs **kargs
) )
db_conn = hs.get_db_conn() db_conn = hs.get_db_conn()