mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-11-03 22:34:10 -05:00 
			
		
		
		
	Convert some util functions to async (#8035)
This commit is contained in:
		
							parent
							
								
									d4a7829b12
								
							
						
					
					
						commit
						fe6cfc80ec
					
				
					 4 changed files with 40 additions and 62 deletions
				
			
		
							
								
								
									
										1
									
								
								changelog.d/8035.misc
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/8035.misc
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Convert various parts of the codebase to async/await.
 | 
			
		||||
| 
						 | 
				
			
			@ -13,14 +13,11 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import inspect
 | 
			
		||||
import logging
 | 
			
		||||
from functools import wraps
 | 
			
		||||
 | 
			
		||||
from prometheus_client import Counter
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from synapse.logging.context import LoggingContext, current_context
 | 
			
		||||
from synapse.metrics import InFlightGauge
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -62,25 +59,31 @@ in_flight = InFlightGauge(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def measure_func(name=None):
 | 
			
		||||
    """
 | 
			
		||||
    Used to decorate an async function with a `Measure` context manager.
 | 
			
		||||
 | 
			
		||||
    Usage:
 | 
			
		||||
 | 
			
		||||
    @measure_func()
 | 
			
		||||
    async def foo(...):
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    Which is analogous to:
 | 
			
		||||
 | 
			
		||||
    async def foo(...):
 | 
			
		||||
        with Measure(...):
 | 
			
		||||
            ...
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def wrapper(func):
 | 
			
		||||
        block_name = func.__name__ if name is None else name
 | 
			
		||||
 | 
			
		||||
        if inspect.iscoroutinefunction(func):
 | 
			
		||||
 | 
			
		||||
            @wraps(func)
 | 
			
		||||
            async def measured_func(self, *args, **kwargs):
 | 
			
		||||
                with Measure(self.clock, block_name):
 | 
			
		||||
                    r = await func(self, *args, **kwargs)
 | 
			
		||||
                return r
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
 | 
			
		||||
            @wraps(func)
 | 
			
		||||
            @defer.inlineCallbacks
 | 
			
		||||
            def measured_func(self, *args, **kwargs):
 | 
			
		||||
                with Measure(self.clock, block_name):
 | 
			
		||||
                    r = yield func(self, *args, **kwargs)
 | 
			
		||||
                return r
 | 
			
		||||
        @wraps(func)
 | 
			
		||||
        async def measured_func(self, *args, **kwargs):
 | 
			
		||||
            with Measure(self.clock, block_name):
 | 
			
		||||
                r = await func(self, *args, **kwargs)
 | 
			
		||||
            return r
 | 
			
		||||
 | 
			
		||||
        return measured_func
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,8 +15,6 @@
 | 
			
		|||
import logging
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
import synapse.logging.context
 | 
			
		||||
from synapse.api.errors import CodeMessageException
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,8 +52,7 @@ class NotRetryingDestination(Exception):
 | 
			
		|||
        self.destination = destination
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@defer.inlineCallbacks
 | 
			
		||||
def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
 | 
			
		||||
async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
 | 
			
		||||
    """For a given destination check if we have previously failed to
 | 
			
		||||
    send a request there and are waiting before retrying the destination.
 | 
			
		||||
    If we are not ready to retry the destination, this will raise a
 | 
			
		||||
| 
						 | 
				
			
			@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
 | 
			
		|||
    Example usage:
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            limiter = yield get_retry_limiter(destination, clock, store)
 | 
			
		||||
            limiter = await get_retry_limiter(destination, clock, store)
 | 
			
		||||
            with limiter:
 | 
			
		||||
                response = yield do_request()
 | 
			
		||||
                response = await do_request()
 | 
			
		||||
        except NotRetryingDestination:
 | 
			
		||||
            # We aren't ready to retry that destination.
 | 
			
		||||
            raise
 | 
			
		||||
| 
						 | 
				
			
			@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
 | 
			
		|||
    failure_ts = None
 | 
			
		||||
    retry_last_ts, retry_interval = (0, 0)
 | 
			
		||||
 | 
			
		||||
    retry_timings = yield store.get_destination_retry_timings(destination)
 | 
			
		||||
    retry_timings = await store.get_destination_retry_timings(destination)
 | 
			
		||||
 | 
			
		||||
    if retry_timings:
 | 
			
		||||
        failure_ts = retry_timings["failure_ts"]
 | 
			
		||||
| 
						 | 
				
			
			@ -222,10 +219,9 @@ class RetryDestinationLimiter(object):
 | 
			
		|||
            if self.failure_ts is None:
 | 
			
		||||
                self.failure_ts = retry_last_ts
 | 
			
		||||
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def store_retry_timings():
 | 
			
		||||
        async def store_retry_timings():
 | 
			
		||||
            try:
 | 
			
		||||
                yield self.store.set_destination_retry_timings(
 | 
			
		||||
                await self.store.set_destination_retry_timings(
 | 
			
		||||
                    self.destination,
 | 
			
		||||
                    self.failure_ts,
 | 
			
		||||
                    retry_last_ts,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
    def test_new_destination(self):
 | 
			
		||||
        """A happy-path case with a new destination and a successful operation"""
 | 
			
		||||
        store = self.hs.get_datastore()
 | 
			
		||||
        d = get_retry_limiter("test_dest", self.clock, store)
 | 
			
		||||
        self.pump()
 | 
			
		||||
        limiter = self.successResultOf(d)
 | 
			
		||||
        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 | 
			
		||||
 | 
			
		||||
        # advance the clock a bit before making the request
 | 
			
		||||
        self.pump(1)
 | 
			
		||||
| 
						 | 
				
			
			@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
        with limiter:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        d = store.get_destination_retry_timings("test_dest")
 | 
			
		||||
        self.pump()
 | 
			
		||||
        new_timings = self.successResultOf(d)
 | 
			
		||||
        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
 | 
			
		||||
        self.assertIsNone(new_timings)
 | 
			
		||||
 | 
			
		||||
    def test_limiter(self):
 | 
			
		||||
        """General test case which walks through the process of a failing request"""
 | 
			
		||||
        store = self.hs.get_datastore()
 | 
			
		||||
 | 
			
		||||
        d = get_retry_limiter("test_dest", self.clock, store)
 | 
			
		||||
        self.pump()
 | 
			
		||||
        limiter = self.successResultOf(d)
 | 
			
		||||
        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 | 
			
		||||
 | 
			
		||||
        self.pump(1)
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
        except AssertionError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        # wait for the update to land
 | 
			
		||||
        self.pump()
 | 
			
		||||
 | 
			
		||||
        d = store.get_destination_retry_timings("test_dest")
 | 
			
		||||
        self.pump()
 | 
			
		||||
        new_timings = self.successResultOf(d)
 | 
			
		||||
        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
 | 
			
		||||
        self.assertEqual(new_timings["failure_ts"], failure_ts)
 | 
			
		||||
        self.assertEqual(new_timings["retry_last_ts"], failure_ts)
 | 
			
		||||
        self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
 | 
			
		||||
 | 
			
		||||
        # now if we try again we should get a failure
 | 
			
		||||
        d = get_retry_limiter("test_dest", self.clock, store)
 | 
			
		||||
        self.pump()
 | 
			
		||||
        self.failureResultOf(d, NotRetryingDestination)
 | 
			
		||||
        self.get_failure(
 | 
			
		||||
            get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        #
 | 
			
		||||
        # advance the clock and try again
 | 
			
		||||
        #
 | 
			
		||||
 | 
			
		||||
        self.pump(MIN_RETRY_INTERVAL)
 | 
			
		||||
        d = get_retry_limiter("test_dest", self.clock, store)
 | 
			
		||||
        self.pump()
 | 
			
		||||
        limiter = self.successResultOf(d)
 | 
			
		||||
        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 | 
			
		||||
 | 
			
		||||
        self.pump(1)
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
        except AssertionError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        # wait for the update to land
 | 
			
		||||
        self.pump()
 | 
			
		||||
 | 
			
		||||
        d = store.get_destination_retry_timings("test_dest")
 | 
			
		||||
        self.pump()
 | 
			
		||||
        new_timings = self.successResultOf(d)
 | 
			
		||||
        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
 | 
			
		||||
        self.assertEqual(new_timings["failure_ts"], failure_ts)
 | 
			
		||||
        self.assertEqual(new_timings["retry_last_ts"], retry_ts)
 | 
			
		||||
        self.assertGreaterEqual(
 | 
			
		||||
| 
						 | 
				
			
			@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
        # one more go, with success
 | 
			
		||||
        #
 | 
			
		||||
        self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
 | 
			
		||||
        d = get_retry_limiter("test_dest", self.clock, store)
 | 
			
		||||
        self.pump()
 | 
			
		||||
        limiter = self.successResultOf(d)
 | 
			
		||||
        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 | 
			
		||||
 | 
			
		||||
        self.pump(1)
 | 
			
		||||
        with limiter:
 | 
			
		||||
| 
						 | 
				
			
			@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
 | 
			
		|||
        # wait for the update to land
 | 
			
		||||
        self.pump()
 | 
			
		||||
 | 
			
		||||
        d = store.get_destination_retry_timings("test_dest")
 | 
			
		||||
        self.pump()
 | 
			
		||||
        new_timings = self.successResultOf(d)
 | 
			
		||||
        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
 | 
			
		||||
        self.assertIsNone(new_timings)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue