Additional type hints for client REST servlets (part 5) (#10736)

Additionally this enforce type hints on all function signatures inside
of the synapse.rest.client package.
This commit is contained in:
Patrick Cloke 2021-09-03 09:22:22 -04:00 committed by GitHub
parent f58d202e3f
commit ecbfa4fe4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 146 additions and 68 deletions

View file

@ -15,28 +15,37 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Tuple
from twisted.python.failure import Failure
from twisted.web.server import Request
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict
from synapse.util.async_helpers import ObservableDeferred
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
self.transactions: Dict[
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
] = {}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
def _get_transaction_key(self, request: Request) -> str:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
@ -45,15 +54,21 @@ class HttpTransactionCache:
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
request: The incoming request. Must contain an access_token.
Returns:
str: A transaction key
A transaction key
"""
assert request.path is not None
token = self.auth.get_access_token_from_request(request)
return request.path.decode("utf8") + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
def fetch_or_execute_request(
self,
request: Request,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@ -64,15 +79,20 @@ class HttpTransactionCache:
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
def fetch_or_execute(
self,
txn_key: str,
fn: Callable[..., Awaitable[Tuple[int, JsonDict]]],
*args: Any,
**kwargs: Any,
) -> Awaitable[Tuple[int, JsonDict]]:
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict).
txn_key: A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
@ -90,7 +110,7 @@ class HttpTransactionCache:
# if the request fails with an exception, remove it
# from the transaction map. This is done to ensure that we don't
# cache transient errors like rate-limiting errors, etc.
def remove_from_map(err):
def remove_from_map(err: Failure) -> None:
self.transactions.pop(txn_key, None)
# we deliberately do not propagate the error any further, as we
# expect the observers to have reported it.
@ -99,7 +119,7 @@ class HttpTransactionCache:
return make_deferred_yieldable(observable.observe())
def _cleanup(self):
def _cleanup(self) -> None:
now = self.clock.time_msec()
for key in list(self.transactions):
ts = self.transactions[key][1]