Convert some util functions to async (#8035)

This commit is contained in:
Patrick Cloke 2020-08-06 08:39:35 -04:00 committed by GitHub
parent d4a7829b12
commit fe6cfc80ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 40 additions and 62 deletions

1
changelog.d/8035.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
from functools import wraps from functools import wraps
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from synapse.logging.context import LoggingContext, current_context from synapse.logging.context import LoggingContext, current_context
from synapse.metrics import InFlightGauge from synapse.metrics import InFlightGauge
@ -62,26 +59,32 @@ in_flight = InFlightGauge(
def measure_func(name=None): def measure_func(name=None):
"""
Used to decorate an async function with a `Measure` context manager.
Usage:
@measure_func()
async def foo(...):
...
Which is analogous to:
async def foo(...):
with Measure(...):
...
"""
def wrapper(func): def wrapper(func):
block_name = func.__name__ if name is None else name block_name = func.__name__ if name is None else name
if inspect.iscoroutinefunction(func):
@wraps(func) @wraps(func)
async def measured_func(self, *args, **kwargs): async def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name): with Measure(self.clock, block_name):
r = await func(self, *args, **kwargs) r = await func(self, *args, **kwargs)
return r return r
else:
@wraps(func)
@defer.inlineCallbacks
def measured_func(self, *args, **kwargs):
with Measure(self.clock, block_name):
r = yield func(self, *args, **kwargs)
return r
return measured_func return measured_func
return wrapper return wrapper

View File

@ -15,8 +15,6 @@
import logging import logging
import random import random
from twisted.internet import defer
import synapse.logging.context import synapse.logging.context
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
@ -54,8 +52,7 @@ class NotRetryingDestination(Exception):
self.destination = destination self.destination = destination
@defer.inlineCallbacks async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
Example usage: Example usage:
try: try:
limiter = yield get_retry_limiter(destination, clock, store) limiter = await get_retry_limiter(destination, clock, store)
with limiter: with limiter:
response = yield do_request() response = await do_request()
except NotRetryingDestination: except NotRetryingDestination:
# We aren't ready to retry that destination. # We aren't ready to retry that destination.
raise raise
@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
failure_ts = None failure_ts = None
retry_last_ts, retry_interval = (0, 0) retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(destination) retry_timings = await store.get_destination_retry_timings(destination)
if retry_timings: if retry_timings:
failure_ts = retry_timings["failure_ts"] failure_ts = retry_timings["failure_ts"]
@ -222,10 +219,9 @@ class RetryDestinationLimiter(object):
if self.failure_ts is None: if self.failure_ts is None:
self.failure_ts = retry_last_ts self.failure_ts = retry_last_ts
@defer.inlineCallbacks async def store_retry_timings():
def store_retry_timings():
try: try:
yield self.store.set_destination_retry_timings( await self.store.set_destination_retry_timings(
self.destination, self.destination,
self.failure_ts, self.failure_ts,
retry_last_ts, retry_last_ts,

View File

@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self): def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation""" """A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore() store = self.hs.get_datastore()
d = get_retry_limiter("test_dest", self.clock, store) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump()
limiter = self.successResultOf(d)
# advance the clock a bit before making the request # advance the clock a bit before making the request
self.pump(1) self.pump(1)
@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
with limiter: with limiter:
pass pass
d = store.get_destination_retry_timings("test_dest") new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.pump()
new_timings = self.successResultOf(d)
self.assertIsNone(new_timings) self.assertIsNone(new_timings)
def test_limiter(self): def test_limiter(self):
"""General test case which walks through the process of a failing request""" """General test case which walks through the process of a failing request"""
store = self.hs.get_datastore() store = self.hs.get_datastore()
d = get_retry_limiter("test_dest", self.clock, store) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump()
limiter = self.successResultOf(d)
self.pump(1) self.pump(1)
try: try:
@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError: except AssertionError:
pass pass
# wait for the update to land new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.pump()
d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts) self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL) self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
# now if we try again we should get a failure # now if we try again we should get a failure
d = get_retry_limiter("test_dest", self.clock, store) self.get_failure(
self.pump() get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
self.failureResultOf(d, NotRetryingDestination) )
# #
# advance the clock and try again # advance the clock and try again
# #
self.pump(MIN_RETRY_INTERVAL) self.pump(MIN_RETRY_INTERVAL)
d = get_retry_limiter("test_dest", self.clock, store) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump()
limiter = self.successResultOf(d)
self.pump(1) self.pump(1)
try: try:
@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError: except AssertionError:
pass pass
# wait for the update to land new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.pump()
d = store.get_destination_retry_timings("test_dest")
self.pump()
new_timings = self.successResultOf(d)
self.assertEqual(new_timings["failure_ts"], failure_ts) self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts) self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual( self.assertGreaterEqual(
@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
# one more go, with success # one more go, with success
# #
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
d = get_retry_limiter("test_dest", self.clock, store) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump()
limiter = self.successResultOf(d)
self.pump(1) self.pump(1)
with limiter: with limiter:
@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
# wait for the update to land # wait for the update to land
self.pump() self.pump()
d = store.get_destination_retry_timings("test_dest") new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.pump()
new_timings = self.successResultOf(d)
self.assertIsNone(new_timings) self.assertIsNone(new_timings)