Add type hints to groups code. (#9393)

This commit is contained in:
Patrick Cloke 2021-02-17 08:41:47 -05:00 committed by GitHub
parent e1071fd625
commit d2f0ec12d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 341 additions and 124 deletions

View file

@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
from typing import Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from signedjson.sign import sign_json
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.types import JsonDict, get_domain_from_id
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -63,15 +66,19 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning:
"""Creates and verifies group attestations."""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.signing_key = hs.signing_key
async def verify_attestation(
self, attestation, group_id, user_id, server_name=None
):
self,
attestation: JsonDict,
group_id: str,
user_id: str,
server_name: Optional[str] = None,
) -> None:
"""Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's
@ -100,16 +107,18 @@ class GroupAttestationSigning:
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
assert server_name is not None
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
)
def create_attestation(self, group_id, user_id):
def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
"""Create an attestation for the group_id and user_id with default
validity length.
"""
validity_period = DEFAULT_ATTESTATION_LENGTH_MS
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
*DEFAULT_ATTESTATION_JITTER
)
valid_until_ms = int(self.clock.time_msec() + validity_period)
return sign_json(
@ -126,7 +135,7 @@ class GroupAttestationSigning:
class GroupAttestionRenewer:
"""Responsible for sending and receiving attestation updates."""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.assestations = hs.get_groups_attestation_signing()
@ -139,7 +148,9 @@ class GroupAttestionRenewer:
self._start_renew_attestations, 30 * 60 * 1000
)
async def on_renew_attestation(self, group_id, user_id, content):
async def on_renew_attestation(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""When a remote updates an attestation"""
attestation = content["attestation"]
@ -154,10 +165,10 @@ class GroupAttestionRenewer:
return {}
def _start_renew_attestations(self):
def _start_renew_attestations(self) -> None:
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self):
async def _renew_attestations(self) -> None:
"""Called periodically to check if we need to update any of our attestations"""
now = self.clock.time_msec()
@ -166,7 +177,7 @@ class GroupAttestionRenewer:
now + UPDATE_ATTESTATION_TIME_MS
)
async def _renew_attestation(group_user: Tuple[str, str]):
async def _renew_attestation(group_user: Tuple[str, str]) -> None:
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):