Use inline type hints in various other places (in synapse/) (#10380)

This commit is contained in:
Jonathan de Jong 2021-07-15 12:02:43 +02:00 committed by GitHub
parent c7603af1d0
commit bf72d10dbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 329 additions and 336 deletions

1
changelog.d/10380.misc Normal file
View File

@ -0,0 +1 @@
Convert internal type variable syntax to reflect wider ecosystem use.

View File

@ -63,9 +63,9 @@ class Auth:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.token_cache = LruCache( self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache" 10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]] )
self._auth_blocking = AuthBlocking(self.hs) self._auth_blocking = AuthBlocking(self.hs)

View File

@ -118,7 +118,7 @@ class RedirectException(CodeMessageException):
super().__init__(code=http_code, msg=msg) super().__init__(code=http_code, msg=msg)
self.location = location self.location = location
self.cookies = [] # type: List[bytes] self.cookies: List[bytes] = []
class SynapseError(CodeMessageException): class SynapseError(CodeMessageException):
@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError):
): ):
super().__init__(code, msg, errcode) super().__init__(code, msg, errcode)
if additional_fields is None: if additional_fields is None:
self._additional_fields = {} # type: Dict self._additional_fields: Dict = {}
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)

View File

@ -289,7 +289,7 @@ class Filter:
room_id = None room_id = None
ev_type = "m.presence" ev_type = "m.presence"
contains_url = False contains_url = False
labels = [] # type: List[str] labels: List[str] = []
else: else:
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:

View File

@ -46,9 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time # * How many times an action has occurred since a point in time
# * The point in time # * The point in time
# * The rate_hz of this particular entry. This can vary per request # * The rate_hz of this particular entry. This can vary per request
self.actions = ( self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
async def can_do_action( async def can_do_action(
self, self,

View File

@ -195,7 +195,7 @@ class RoomVersions:
) )
KNOWN_ROOM_VERSIONS = { KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
v.identifier: v v.identifier: v
for v in ( for v in (
RoomVersions.V1, RoomVersions.V1,
@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V7, RoomVersions.V7,
) )
# Note that we do not include MSC2043 here unless it is enabled in the config. # Note that we do not include MSC2043 here unless it is enabled in the config.
} # type: Dict[str, RoomVersion] }

View File

@ -270,7 +270,7 @@ class GenericWorkerServer(HomeServer):
site_tag = port site_tag = port
# We always include a health resource. # We always include a health resource.
resources = {"/health": HealthResource()} # type: Dict[str, IResource] resources: Dict[str, IResource] = {"/health": HealthResource()}
for res in listener_config.http_options.resources: for res in listener_config.http_options.resources:
for name in res.names: for name in res.names:

View File

@ -88,9 +88,9 @@ class ApplicationServiceApi(SimpleHttpClient):
super().__init__(hs) super().__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache( self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]] )
async def query_user(self, service, user_id): async def query_user(self, service, user_id):
if service.url is None: if service.url is None:

View File

@ -57,8 +57,8 @@ def load_appservices(hostname, config_files):
return [] return []
# Dicts of value -> filename # Dicts of value -> filename
seen_as_tokens = {} # type: Dict[str, str] seen_as_tokens: Dict[str, str] = {}
seen_ids = {} # type: Dict[str, str] seen_ids: Dict[str, str] = {}
appservices = [] appservices = []

View File

@ -25,7 +25,7 @@ from ._base import Config, ConfigError
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache. # Map from canonicalised cache name to cache.
_CACHES = {} # type: Dict[str, Callable[[float], None]] _CACHES: Dict[str, Callable[[float], None]] = {}
# a lock on the contents of _CACHES # a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock() _CACHES_LOCK = threading.Lock()
@ -157,7 +157,7 @@ class CacheConfig(Config):
self.event_cache_size = self.parse_size( self.event_cache_size = self.parse_size(
config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE)
) )
self.cache_factors = {} # type: Dict[str, float] self.cache_factors: Dict[str, float] = {}
cache_config = config.get("caches") or {} cache_config = config.get("caches") or {}
self.global_factor = cache_config.get( self.global_factor = cache_config.get(

View File

@ -134,9 +134,9 @@ class EmailConfig(Config):
# trusted_third_party_id_servers does not contain a scheme whereas # trusted_third_party_id_servers does not contain a scheme whereas
# account_threepid_delegate_email is expected to. Presume https # account_threepid_delegate_email is expected to. Presume https
self.account_threepid_delegate_email = ( self.account_threepid_delegate_email: Optional[str] = (
"https://" + first_trusted_identity_server "https://" + first_trusted_identity_server
) # type: Optional[str] )
self.using_identity_server_from_trusted_list = True self.using_identity_server_from_trusted_list = True
else: else:
raise ConfigError( raise ConfigError(

View File

@ -25,10 +25,10 @@ class ExperimentalConfig(Config):
experimental = config.get("experimental_features") or {} experimental = config.get("experimental_features") or {}
# MSC2858 (multiple SSO identity providers) # MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool self.msc2858_enabled: bool = experimental.get("msc2858_enabled", False)
# MSC3026 (busy presence state) # MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
# MSC2716 (backfill existing history) # MSC2716 (backfill existing history)
self.msc2716_enabled = experimental.get("msc2716_enabled", False) # type: bool self.msc2716_enabled: bool = experimental.get("msc2716_enabled", False)

View File

@ -22,7 +22,7 @@ class FederationConfig(Config):
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None # type: Optional[dict] self.federation_domain_whitelist: Optional[dict] = None
federation_domain_whitelist = config.get("federation_domain_whitelist", None) federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None: if federation_domain_whitelist is not None:

View File

@ -460,7 +460,7 @@ def _parse_oidc_config_dict(
) from e ) from e
client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] client_secret_jwt_key: Optional[OidcProviderClientSecretJwtKey] = None
if client_secret_jwt_key_config is not None: if client_secret_jwt_key_config is not None:
keyfile = client_secret_jwt_key_config.get("key_file") keyfile = client_secret_jwt_key_config.get("key_file")
if keyfile: if keyfile:

View File

@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
section = "authproviders" section = "authproviders"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.password_providers = [] # type: List[Any] self.password_providers: List[Any] = []
providers = [] providers = []
# We want to be backwards compatible with the old `ldap_config` # We want to be backwards compatible with the old `ldap_config`

View File

@ -62,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of Dictionary mapping from media type string to list of
ThumbnailRequirement tuples. ThumbnailRequirement tuples.
""" """
requirements = {} # type: Dict[str, List] requirements: Dict[str, List] = {}
for size in thumbnail_sizes: for size in thumbnail_sizes:
width = size["width"] width = size["width"]
height = size["height"] height = size["height"]
@ -141,7 +141,7 @@ class ContentRepositoryConfig(Config):
# #
# We don't create the storage providers here as not all workers need # We don't create the storage providers here as not all workers need
# them to be started. # them to be started.
self.media_storage_providers = [] # type: List[tuple] self.media_storage_providers: List[tuple] = []
for i, provider_config in enumerate(storage_providers): for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to # We special case the module "file_system" so as not to need to

View File

@ -505,7 +505,7 @@ class ServerConfig(Config):
" greater than 'allowed_lifetime_max'" " greater than 'allowed_lifetime_max'"
) )
self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]] self.retention_purge_jobs: List[Dict[str, Optional[int]]] = []
for purge_job_config in retention_config.get("purge_jobs", []): for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval") interval_config = purge_job_config.get("interval")
@ -688,23 +688,21 @@ class ServerConfig(Config):
# not included in the sample configuration file on purpose as it's a temporary # not included in the sample configuration file on purpose as it's a temporary
# hack, so that some users can trial the new defaults without impacting every # hack, so that some users can trial the new defaults without impacting every
# user on the homeserver. # user on the homeserver.
users_new_default_push_rules = ( users_new_default_push_rules: list = (
config.get("users_new_default_push_rules") or [] config.get("users_new_default_push_rules") or []
) # type: list )
if not isinstance(users_new_default_push_rules, list): if not isinstance(users_new_default_push_rules, list):
raise ConfigError("'users_new_default_push_rules' must be a list") raise ConfigError("'users_new_default_push_rules' must be a list")
# Turn the list into a set to improve lookup speed. # Turn the list into a set to improve lookup speed.
self.users_new_default_push_rules = set( self.users_new_default_push_rules: set = set(users_new_default_push_rules)
users_new_default_push_rules
) # type: set
# Whitelist of domain names that given next_link parameters must have # Whitelist of domain names that given next_link parameters must have
next_link_domain_whitelist = config.get( next_link_domain_whitelist: Optional[List[str]] = config.get(
"next_link_domain_whitelist" "next_link_domain_whitelist"
) # type: Optional[List[str]] )
self.next_link_domain_whitelist = None # type: Optional[Set[str]] self.next_link_domain_whitelist: Optional[Set[str]] = None
if next_link_domain_whitelist is not None: if next_link_domain_whitelist is not None:
if not isinstance(next_link_domain_whitelist, list): if not isinstance(next_link_domain_whitelist, list):
raise ConfigError("'next_link_domain_whitelist' must be a list") raise ConfigError("'next_link_domain_whitelist' must be a list")

View File

@ -34,7 +34,7 @@ class SpamCheckerConfig(Config):
section = "spamchecker" section = "spamchecker"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.spam_checkers = [] # type: List[Tuple[Any, Dict]] self.spam_checkers: List[Tuple[Any, Dict]] = []
spam_checkers = config.get("spam_checker") or [] spam_checkers = config.get("spam_checker") or []
if isinstance(spam_checkers, dict): if isinstance(spam_checkers, dict):

View File

@ -39,7 +39,7 @@ class SSOConfig(Config):
section = "sso" section = "sso"
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
sso_config = config.get("sso") or {} # type: Dict[str, Any] sso_config: Dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir # The sso-specific template_dir
self.sso_template_dir = sso_config.get("template_dir") self.sso_template_dir = sso_config.get("template_dir")

View File

@ -80,7 +80,7 @@ class TlsConfig(Config):
fed_whitelist_entries = [] fed_whitelist_entries = []
# Support globs (*) in whitelist values # Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist = [] # type: List[Pattern] self.federation_certificate_verification_whitelist: List[Pattern] = []
for entry in fed_whitelist_entries: for entry in fed_whitelist_entries:
try: try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii")) entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
@ -132,8 +132,8 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use" "use_insecure_ssl_client_just_for_testing_do_not_use"
) )
self.tls_certificate = None # type: Optional[crypto.X509] self.tls_certificate: Optional[crypto.X509] = None
self.tls_private_key = None # type: Optional[crypto.PKey] self.tls_private_key: Optional[crypto.PKey] = None
def is_disk_cert_valid(self, allow_self_signed=True): def is_disk_cert_valid(self, allow_self_signed=True):
""" """

View File

@ -170,11 +170,13 @@ class Keyring:
) )
self._key_fetchers = key_fetchers self._key_fetchers = key_fetchers
self._server_queue = BatchingQueue( self._server_queue: BatchingQueue[
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
] = BatchingQueue(
"keyring_server", "keyring_server",
clock=hs.get_clock(), clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests, process_batch_callback=self._inner_fetch_key_requests,
) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] )
async def verify_json_for_server( async def verify_json_for_server(
self, self,
@ -330,7 +332,7 @@ class Keyring:
# First we need to deduplicate requests for the same key. We do this by # First we need to deduplicate requests for the same key. We do this by
# taking the *maximum* requested `minimum_valid_until_ts` for each pair # taking the *maximum* requested `minimum_valid_until_ts` for each pair
# of server name/key ID. # of server name/key ID.
server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]] server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
for request in requests: for request in requests:
by_server = server_to_key_to_ts.setdefault(request.server_name, {}) by_server = server_to_key_to_ts.setdefault(request.server_name, {})
for key_id in request.key_ids: for key_id in request.key_ids:
@ -355,7 +357,7 @@ class Keyring:
# We now convert the returned list of results into a map from server # We now convert the returned list of results into a map from server
# name to key ID to FetchKeyResult, to return. # name to key ID to FetchKeyResult, to return.
to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
for (request, results) in zip(deduped_requests, results_per_request): for (request, results) in zip(deduped_requests, results_per_request):
to_return_by_server = to_return.setdefault(request.server_name, {}) to_return_by_server = to_return.setdefault(request.server_name, {})
for key_id, key_result in results.items(): for key_id, key_result in results.items():
@ -455,7 +457,7 @@ class StoreKeyFetcher(KeyFetcher):
) )
res = await self.store.get_server_verify_keys(key_ids_to_fetch) res = await self.store.get_server_verify_keys(key_ids_to_fetch)
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
return keys return keys
@ -603,7 +605,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
) )
union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
for result in results: for result in results:
for server_name, keys in result.items(): for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys) union_of_keys.setdefault(server_name, {}).update(keys)
@ -656,8 +658,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e: except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,)) raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys = {} # type: Dict[str, Dict[str, FetchKeyResult]] keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]] added_keys: List[Tuple[str, str, FetchKeyResult]] = []
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
@ -805,7 +807,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Raises: Raises:
KeyLookupError if there was a problem making the lookup KeyLookupError if there was a problem making the lookup
""" """
keys = {} # type: Dict[str, FetchKeyResult] keys: Dict[str, FetchKeyResult] = {}
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another. # we may have found this key as a side-effect of asking for another.

View File

@ -531,7 +531,7 @@ def _check_power_levels(
user_level = get_user_power_level(event.user_id, auth_events) user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels: # Check other levels:
levels_to_check = [ levels_to_check: List[Tuple[str, Optional[str]]] = [
("users_default", None), ("users_default", None),
("events_default", None), ("events_default", None),
("state_default", None), ("state_default", None),
@ -539,7 +539,7 @@ def _check_power_levels(
("redact", None), ("redact", None),
("kick", None), ("kick", None),
("invite", None), ("invite", None),
] # type: List[Tuple[str, Optional[str]]] ]
old_list = current_state.content.get("users", {}) old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)): for user in set(list(old_list) + list(user_list)):
@ -569,12 +569,12 @@ def _check_power_levels(
new_loc = new_loc.get(dir, {}) new_loc = new_loc.get(dir, {})
if level_to_check in old_loc: if level_to_check in old_loc:
old_level = int(old_loc[level_to_check]) # type: Optional[int] old_level: Optional[int] = int(old_loc[level_to_check])
else: else:
old_level = None old_level = None
if level_to_check in new_loc: if level_to_check in new_loc:
new_level = int(new_loc[level_to_check]) # type: Optional[int] new_level: Optional[int] = int(new_loc[level_to_check])
else: else:
new_level = None new_level = None

View File

@ -105,28 +105,28 @@ class _EventInternalMetadata:
self._dict = dict(internal_metadata_dict) self._dict = dict(internal_metadata_dict)
# the stream ordering of this event. None, until it has been persisted. # the stream ordering of this event. None, until it has been persisted.
self.stream_ordering = None # type: Optional[int] self.stream_ordering: Optional[int] = None
# whether this event is an outlier (ie, whether we have the state at that point # whether this event is an outlier (ie, whether we have the state at that point
# in the DAG) # in the DAG)
self.outlier = False self.outlier = False
out_of_band_membership = DictProperty("out_of_band_membership") # type: bool out_of_band_membership: bool = DictProperty("out_of_band_membership")
send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str send_on_behalf_of: str = DictProperty("send_on_behalf_of")
recheck_redaction = DictProperty("recheck_redaction") # type: bool recheck_redaction: bool = DictProperty("recheck_redaction")
soft_failed = DictProperty("soft_failed") # type: bool soft_failed: bool = DictProperty("soft_failed")
proactively_send = DictProperty("proactively_send") # type: bool proactively_send: bool = DictProperty("proactively_send")
redacted = DictProperty("redacted") # type: bool redacted: bool = DictProperty("redacted")
txn_id = DictProperty("txn_id") # type: str txn_id: str = DictProperty("txn_id")
token_id = DictProperty("token_id") # type: int token_id: int = DictProperty("token_id")
historical = DictProperty("historical") # type: bool historical: bool = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after. # XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't # I'm pretty sure that these are never persisted to the database, so shouldn't
# be here # be here
before = DictProperty("before") # type: RoomStreamToken before: RoomStreamToken = DictProperty("before")
after = DictProperty("after") # type: RoomStreamToken after: RoomStreamToken = DictProperty("after")
order = DictProperty("order") # type: Tuple[int, int] order: Tuple[int, int] = DictProperty("order")
def get_dict(self) -> JsonDict: def get_dict(self) -> JsonDict:
return dict(self._dict) return dict(self._dict)

View File

@ -132,12 +132,12 @@ class EventBuilder:
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes( auth_events: Union[
auth_event_ids List[str], List[Tuple[str, Dict[str, str]]]
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ] = await self._store.add_event_hashes(auth_event_ids)
prev_events = await self._store.add_event_hashes( prev_events: Union[
prev_event_ids List[str], List[Tuple[str, Dict[str, str]]]
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ] = await self._store.add_event_hashes(prev_event_ids)
else: else:
auth_events = auth_event_ids auth_events = auth_event_ids
prev_events = prev_event_ids prev_events = prev_event_ids
@ -156,7 +156,7 @@ class EventBuilder:
# the db) # the db)
depth = min(depth, MAX_DEPTH) depth = min(depth, MAX_DEPTH)
event_dict = { event_dict: Dict[str, Any] = {
"auth_events": auth_events, "auth_events": auth_events,
"prev_events": prev_events, "prev_events": prev_events,
"type": self.type, "type": self.type,
@ -166,7 +166,7 @@ class EventBuilder:
"unsigned": self.unsigned, "unsigned": self.unsigned,
"depth": depth, "depth": depth,
"prev_state": [], "prev_state": [],
} # type: Dict[str, Any] }
if self.is_state(): if self.is_state():
event_dict["state_key"] = self._state_key event_dict["state_key"] = self._state_key

View File

@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"""Wrapper that loads spam checkers configured using the old configuration, and """Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement. registers the spam checker hooks they implement.
""" """
spam_checkers = [] # type: List[Any] spam_checkers: List[Any] = []
api = hs.get_module_api() api = hs.get_module_api()
for module, config in hs.config.spam_checkers: for module, config in hs.config.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we # Older spam checkers don't accept the `api` argument, so we
@ -239,7 +239,7 @@ class SpamChecker:
will be used as the error message returned to the user. will be used as the error message returned to the user.
""" """
for callback in self._check_event_for_spam_callbacks: for callback in self._check_event_for_spam_callbacks:
res = await callback(event) # type: Union[bool, str] res: Union[bool, str] = await callback(event)
if res: if res:
return res return res

View File

@ -86,7 +86,7 @@ class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self.pdu_destination_tried: Dict[str, Dict[str, int]] = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
@ -94,13 +94,13 @@ class FederationClient(FederationBase):
self.hostname = hs.hostname self.hostname = hs.hostname
self.signing_key = hs.signing_key self.signing_key = hs.signing_key
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
clock=self._clock, clock=self._clock,
max_len=1000, max_len=1000,
expiry_ms=120 * 1000, expiry_ms=120 * 1000,
reset_expiry_on_get=False, reset_expiry_on_get=False,
) # type: ExpiringCache[str, EventBase] )
def _clear_tried_cache(self): def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache""" """Clear pdu_destination_tried cache"""
@ -293,10 +293,10 @@ class FederationClient(FederationBase):
transaction_data, transaction_data,
) )
pdu_list = [ pdu_list: List[EventBase] = [
event_from_pdu_json(p, room_version, outlier=outlier) event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] # type: List[EventBase] ]
if pdu_list and pdu_list[0]: if pdu_list and pdu_list[0]:
pdu = pdu_list[0] pdu = pdu_list[0]

View File

@ -122,12 +122,12 @@ class FederationServer(FederationBase):
# origins that we are currently processing a transaction from. # origins that we are currently processing a transaction from.
# a dict from origin to txn id. # a dict from origin to txn id.
self._active_transactions = {} # type: Dict[str, str] self._active_transactions: Dict[str, str] = {}
# We cache results for transaction with the same ID # We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache( self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "fed_txn_handler", timeout_ms=30000 hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]] )
self.transaction_actions = TransactionActions(self.store) self.transaction_actions = TransactionActions(self.store)
@ -135,12 +135,12 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often # We cache responses to state queries, as they take a while and often
# come in waves. # come in waves.
self._state_resp_cache = ResponseCache( self._state_resp_cache: ResponseCache[
hs.get_clock(), "state_resp", timeout_ms=30000 Tuple[str, Optional[str]]
) # type: ResponseCache[Tuple[str, Optional[str]]] ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
self._state_ids_resp_cache = ResponseCache( self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
hs.get_clock(), "state_ids_resp", timeout_ms=30000 hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]] )
self._federation_metrics_domains = ( self._federation_metrics_domains = (
hs.config.federation.federation_metrics_domains hs.config.federation.federation_metrics_domains
@ -337,7 +337,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
pdus_by_room = {} # type: Dict[str, List[EventBase]] pdus_by_room: Dict[str, List[EventBase]] = {}
newest_pdu_ts = 0 newest_pdu_ts = 0
@ -516,9 +516,9 @@ class FederationServer(FederationBase):
self, room_id: str, event_id: Optional[str] self, room_id: str, event_id: Optional[str]
) -> Dict[str, list]: ) -> Dict[str, list]:
if event_id: if event_id:
pdus = await self.handler.get_state_for_pdu( pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
room_id, event_id room_id, event_id
) # type: Iterable[EventBase] )
else: else:
pdus = (await self.state.get_current_state(room_id)).values() pdus = (await self.state.get_current_state(room_id)).values()
@ -791,7 +791,7 @@ class FederationServer(FederationBase):
log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = await self.store.claim_e2e_one_time_keys(query) results = await self.store.claim_e2e_one_time_keys(query)
json_result = {} # type: Dict[str, Dict[str, dict]] json_result: Dict[str, Dict[str, dict]] = {}
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, keys in device_keys.items(): for device_id, keys in device_keys.items():
for key_id, json_str in keys.items(): for key_id, json_str in keys.items():
@ -1119,17 +1119,13 @@ class FederationHandlerRegistry:
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs) self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs) self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
self.edu_handlers = ( self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
{} self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = (
{}
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to. # Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new # We randomly choose one instance from the list to route to for each new
# EDU received. # EDU received.
self._edu_type_to_instance = {} # type: Dict[str, List[str]] self._edu_type_to_instance: Dict[str, List[str]] = {}
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]

View File

@ -71,34 +71,32 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# We may have multiple federation sender instances, so we need to track # We may have multiple federation sender instances, so we need to track
# their positions separately. # their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances self._sender_instances = hs.config.worker.federation_shard_config.instances
self._sender_positions = {} # type: Dict[str, int] self._sender_positions: Dict[str, int] = {}
# Pending presence map user_id -> UserPresenceState # Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState] self.presence_map: Dict[str, UserPresenceState] = {}
# Stores the destinations we need to explicitly send presence to about a # Stores the destinations we need to explicitly send presence to about a
# given user. # given user.
# Stream position -> (user_id, destinations) # Stream position -> (user_id, destinations)
self.presence_destinations = ( self.presence_destinations: SortedDict[
SortedDict() int, Tuple[str, Iterable[str]]
) # type: SortedDict[int, Tuple[str, Iterable[str]]] ] = SortedDict()
# (destination, key) -> EDU # (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu] self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}
# stream position -> (destination, key) # stream position -> (destination, key)
self.keyed_edu_changed = ( self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict()
SortedDict()
) # type: SortedDict[int, Tuple[str, tuple]]
self.edus = SortedDict() # type: SortedDict[int, Edu] self.edus: SortedDict[int, Edu] = SortedDict()
# stream ID for the next entry into keyed_edu_changed/edus. # stream ID for the next entry into keyed_edu_changed/edus.
self.pos = 1 self.pos = 1
# map from stream ID to the time that stream entry was generated, so that we # map from stream ID to the time that stream entry was generated, so that we
# can clear out entries after a while # can clear out entries after a while
self.pos_time = SortedDict() # type: SortedDict[int, int] self.pos_time: SortedDict[int, int] = SortedDict()
# EVERYTHING IS SAD. In particular, python only makes new scopes when # EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner # we make a new function, so we need to make a new function so the inner
@ -291,7 +289,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# list of tuple(int, BaseFederationRow), where the first is the position # list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream. # of the federation stream.
rows = [] # type: List[Tuple[int, BaseFederationRow]] rows: List[Tuple[int, BaseFederationRow]] = []
# Fetch presence to send to destinations # Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token) i = self.presence_destinations.bisect_right(from_token)
@ -445,11 +443,11 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu) buff.edus.setdefault(self.edu.destination, []).append(self.edu)
_rowtypes = ( _rowtypes: Tuple[Type[BaseFederationRow], ...] = (
PresenceDestinationsRow, PresenceDestinationsRow,
KeyedEduRow, KeyedEduRow,
EduRow, EduRow,
) # type: Tuple[Type[BaseFederationRow], ...] )
TypeToRow = {Row.TypeId: Row for Row in _rowtypes} TypeToRow = {Row.TypeId: Row for Row in _rowtypes}

View File

@ -148,14 +148,14 @@ class FederationSender(AbstractFederationSender):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self._presence_router = None # type: Optional[PresenceRouter] self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs) self._transaction_manager = TransactionManager(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config self._federation_shard_config = hs.config.worker.federation_shard_config
# map from destination to PerDestinationQueue # map from destination to PerDestinationQueue
self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue] self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
LaterGauge( LaterGauge(
"synapse_federation_transaction_queue_pending_destinations", "synapse_federation_transaction_queue_pending_destinations",
@ -192,9 +192,7 @@ class FederationSender(AbstractFederationSender):
# awaiting a call to flush_read_receipts_for_room. The presence of an entry # awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room, # here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system. # and that there is a pending call to _flush_rrs_for_room in the system.
self._queues_awaiting_rr_flush_by_room = ( self._queues_awaiting_rr_flush_by_room: Dict[str, Set[PerDestinationQueue]] = {}
{}
) # type: Dict[str, Set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = ( self._rr_txn_interval_per_room_ms = (
1000.0 / hs.config.federation_rr_transactions_per_room_per_second 1000.0 / hs.config.federation_rr_transactions_per_room_per_second
@ -265,7 +263,7 @@ class FederationSender(AbstractFederationSender):
if not event.internal_metadata.should_proactively_send(): if not event.internal_metadata.should_proactively_send():
return return
destinations = None # type: Optional[Set[str]] destinations: Optional[Set[str]] = None
if not event.prev_event_ids(): if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty # If there are no prev event IDs then the state is empty
# and so no remote servers in the room # and so no remote servers in the room
@ -331,7 +329,7 @@ class FederationSender(AbstractFederationSender):
for event in events: for event in events:
await handle_event(event) await handle_event(event)
events_by_room = {} # type: Dict[str, List[EventBase]] events_by_room: Dict[str, List[EventBase]] = {}
for event in events: for event in events:
events_by_room.setdefault(event.room_id, []).append(event) events_by_room.setdefault(event.room_id, []).append(event)
@ -628,7 +626,7 @@ class FederationSender(AbstractFederationSender):
In order to reduce load spikes, adds a delay between each destination. In order to reduce load spikes, adds a delay between each destination.
""" """
last_processed = None # type: Optional[str] last_processed: Optional[str] = None
while True: while True:
destinations_to_wake = ( destinations_to_wake = (

View File

@ -105,34 +105,34 @@ class PerDestinationQueue:
# catch-up at startup. # catch-up at startup.
# New events will only be sent once this is finished, at which point # New events will only be sent once this is finished, at which point
# _catching_up is flipped to False. # _catching_up is flipped to False.
self._catching_up = True # type: bool self._catching_up: bool = True
# The stream_ordering of the most recent PDU that was discarded due to # The stream_ordering of the most recent PDU that was discarded due to
# being in catch-up mode. # being in catch-up mode.
self._catchup_last_skipped = 0 # type: int self._catchup_last_skipped: int = 0
# Cache of the last successfully-transmitted stream ordering for this # Cache of the last successfully-transmitted stream ordering for this
# destination (we are the only updater so this is safe) # destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int] self._last_successful_stream_ordering: Optional[int] = None
# a queue of pending PDUs # a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase] self._pending_pdus: List[EventBase] = []
# XXX this is never actually used: see # XXX this is never actually used: see
# https://github.com/matrix-org/synapse/issues/7549 # https://github.com/matrix-org/synapse/issues/7549
self._pending_edus = [] # type: List[Edu] self._pending_edus: List[Edu] = []
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id) # based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu # Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu] self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {}
# Map of user_id -> UserPresenceState of pending presence to be sent to this # Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination # destination
self._pending_presence = {} # type: Dict[str, UserPresenceState] self._pending_presence: Dict[str, UserPresenceState] = {}
# room_id -> receipt_type -> user_id -> receipt_dict # room_id -> receipt_type -> user_id -> receipt_dict
self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]] self._pending_rrs: Dict[str, Dict[str, Dict[str, dict]]] = {}
self._rrs_pending_flush = False self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message. # stream_id of last successfully sent to-device message.
@ -243,7 +243,7 @@ class PerDestinationQueue:
) )
async def _transaction_transmission_loop(self) -> None: async def _transaction_transmission_loop(self) -> None:
pending_pdus = [] # type: List[EventBase] pending_pdus: List[EventBase] = []
try: try:
self.transmission_loop_running = True self.transmission_loop_running = True

View File

@ -395,9 +395,9 @@ class TransportLayerClient:
# this uses MSC2197 (Search Filtering over Federation) # this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms") path = _create_v1_path("/publicRooms")
data = { data: Dict[str, Any] = {
"include_all_networks": "true" if include_all_networks else "false" "include_all_networks": "true" if include_all_networks else "false"
} # type: Dict[str, Any] }
if third_party_instance_id: if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id data["third_party_instance_id"] = third_party_instance_id
if limit: if limit:
@ -423,9 +423,9 @@ class TransportLayerClient:
else: else:
path = _create_v1_path("/publicRooms") path = _create_v1_path("/publicRooms")
args = { args: Dict[str, Any] = {
"include_all_networks": "true" if include_all_networks else "false" "include_all_networks": "true" if include_all_networks else "false"
} # type: Dict[str, Any] }
if third_party_instance_id: if third_party_instance_id:
args["third_party_instance_id"] = (third_party_instance_id,) args["third_party_instance_id"] = (third_party_instance_id,)
if limit: if limit:

View File

@ -1013,7 +1013,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access: if not self.allow_access:
raise FederationDeniedError(origin) raise FederationDeniedError(origin)
limit = int(content.get("limit", 100)) # type: Optional[int] limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
@ -1991,7 +1991,7 @@ class RoomComplexityServlet(BaseFederationServlet):
return 200, complexity return 200, complexity
FEDERATION_SERVLET_CLASSES = ( FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet, FederationSendServlet,
FederationEventServlet, FederationEventServlet,
FederationStateV1Servlet, FederationStateV1Servlet,
@ -2019,15 +2019,13 @@ FEDERATION_SERVLET_CLASSES = (
FederationSpaceSummaryServlet, FederationSpaceSummaryServlet,
FederationV1SendKnockServlet, FederationV1SendKnockServlet,
FederationMakeKnockServlet, FederationMakeKnockServlet,
) # type: Tuple[Type[BaseFederationServlet], ...] )
OPENID_SERVLET_CLASSES = ( OPENID_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (OpenIdUserInfo,)
OpenIdUserInfo,
) # type: Tuple[Type[BaseFederationServlet], ...]
ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...] ROOM_LIST_CLASSES: Tuple[Type[PublicRoomList], ...] = (PublicRoomList,)
GROUP_SERVER_SERVLET_CLASSES = ( GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsProfileServlet, FederationGroupsProfileServlet,
FederationGroupsSummaryServlet, FederationGroupsSummaryServlet,
FederationGroupsRoomsServlet, FederationGroupsRoomsServlet,
@ -2046,19 +2044,19 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsAddRoomsServlet, FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet, FederationGroupsAddRoomsConfigServlet,
FederationGroupsSettingJoinPolicyServlet, FederationGroupsSettingJoinPolicyServlet,
) # type: Tuple[Type[BaseFederationServlet], ...] )
GROUP_LOCAL_SERVLET_CLASSES = ( GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsLocalInviteServlet, FederationGroupsLocalInviteServlet,
FederationGroupsRemoveLocalUserServlet, FederationGroupsRemoveLocalUserServlet,
FederationGroupsBulkPublicisedServlet, FederationGroupsBulkPublicisedServlet,
) # type: Tuple[Type[BaseFederationServlet], ...] )
GROUP_ATTESTATION_SERVLET_CLASSES = ( GROUP_ATTESTATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationGroupsRenewAttestaionServlet, FederationGroupsRenewAttestaionServlet,
) # type: Tuple[Type[BaseFederationServlet], ...] )
DEFAULT_SERVLET_GROUPS = ( DEFAULT_SERVLET_GROUPS = (

View File

@ -707,9 +707,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
See accept_invite, join_group. See accept_invite, join_group.
""" """
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
local_attestation = self.attestations.create_attestation( local_attestation: Optional[
group_id, user_id JsonDict
) # type: Optional[JsonDict] ] = self.attestations.create_attestation(group_id, user_id)
remote_attestation = content["attestation"] remote_attestation = content["attestation"]
@ -868,9 +868,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
remote_attestation, user_id=requester_user_id, group_id=group_id remote_attestation, user_id=requester_user_id, group_id=group_id
) )
local_attestation = self.attestations.create_attestation( local_attestation: Optional[
group_id, requester_user_id JsonDict
) # type: Optional[JsonDict] ] = self.attestations.create_attestation(group_id, requester_user_id)
else: else:
local_attestation = None local_attestation = None
remote_attestation = None remote_attestation = None

View File

@ -69,7 +69,7 @@ def _get_requested_host(request: IRequest) -> bytes:
return hostname return hostname
# no Host header, use the address/port that the request arrived on # no Host header, use the address/port that the request arrived on
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] host: Union[address.IPv4Address, address.IPv6Address] = request.getHost()
hostname = host.host.encode("ascii") hostname = host.host.encode("ascii")

View File

@ -160,7 +160,7 @@ class _IPBlacklistingResolver:
def resolveHostName( def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver: ) -> IResolutionReceiver:
addresses = [] # type: List[IAddress] addresses: List[IAddress] = []
def _callback() -> None: def _callback() -> None:
has_bad_ip = False has_bad_ip = False
@ -333,9 +333,9 @@ class SimpleHttpClient:
if self._ip_blacklist: if self._ip_blacklist:
# If we have an IP blacklist, we need to use a DNS resolver which # If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding. # filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper( self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
) # type: ISynapseReactor )
else: else:
self.reactor = hs.get_reactor() self.reactor = hs.get_reactor()
@ -349,14 +349,14 @@ class SimpleHttpClient:
pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
pool.cachedConnectionTimeout = 2 * 60 pool.cachedConnectionTimeout = 2 * 60
self.agent = ProxyAgent( self.agent: IAgent = ProxyAgent(
self.reactor, self.reactor,
hs.get_reactor(), hs.get_reactor(),
connectTimeout=15, connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(), contextFactory=self.hs.get_http_client_context_factory(),
pool=pool, pool=pool,
use_proxy=use_proxy, use_proxy=use_proxy,
) # type: IAgent )
if self._ip_blacklist: if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent # If we have an IP blacklist, we then install the blacklisting Agent
@ -411,7 +411,7 @@ class SimpleHttpClient:
cooperator=self._cooperator, cooperator=self._cooperator,
) )
request_deferred = treq.request( request_deferred: defer.Deferred = treq.request(
method, method,
uri, uri,
agent=self.agent, agent=self.agent,
@ -421,7 +421,7 @@ class SimpleHttpClient:
# response bodies. # response bodies.
unbuffered=True, unbuffered=True,
**self._extra_treq_args, **self._extra_treq_args,
) # type: defer.Deferred )
# we use our own timeout mechanism rather than treq's as a workaround # we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534. # for https://twistedmatrix.com/trac/ticket/9534.
@ -772,7 +772,7 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data.""" """A protocol which immediately errors upon receiving data."""
transport = None # type: Optional[ITCPTransport] transport: Optional[ITCPTransport] = None
def __init__(self, deferred: defer.Deferred): def __init__(self, deferred: defer.Deferred):
self.deferred = deferred self.deferred = deferred
@ -798,7 +798,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
transport = None # type: Optional[ITCPTransport] transport: Optional[ITCPTransport] = None
def __init__( def __init__(
self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int] self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]

View File

@ -106,7 +106,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
the parsed data. the parsed data.
""" """
CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore CONTENT_TYPE: str = abc.abstractproperty() # type: ignore
"""The expected content type of the response, e.g. `application/json`. If """The expected content type of the response, e.g. `application/json`. If
the content type doesn't match we fail the request. the content type doesn't match we fail the request.
""" """
@ -327,11 +327,11 @@ class MatrixFederationHttpClient:
# We need to use a DNS resolver which filters out blacklisted IP # We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding. # addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper( self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
hs.get_reactor(), hs.get_reactor(),
hs.config.federation_ip_range_whitelist, hs.config.federation_ip_range_whitelist,
hs.config.federation_ip_range_blacklist, hs.config.federation_ip_range_blacklist,
) # type: ISynapseReactor )
user_agent = hs.version_string user_agent = hs.version_string
if hs.config.user_agent_suffix: if hs.config.user_agent_suffix:
@ -504,7 +504,7 @@ class MatrixFederationHttpClient:
) )
# Inject the span into the headers # Inject the span into the headers
headers_dict = {} # type: Dict[bytes, List[bytes]] headers_dict: Dict[bytes, List[bytes]] = {}
opentracing.inject_header_dict(headers_dict, request.destination) opentracing.inject_header_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes] headers_dict[b"User-Agent"] = [self.version_string_bytes]
@ -533,9 +533,9 @@ class MatrixFederationHttpClient:
destination_bytes, method_bytes, url_to_sign_bytes, json destination_bytes, method_bytes, url_to_sign_bytes, json
) )
data = encode_canonical_json(json) data = encode_canonical_json(json)
producer = QuieterFileBodyProducer( producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator BytesIO(data), cooperator=self._cooperator
) # type: Optional[IBodyProducer] )
else: else:
producer = None producer = None
auth_headers = self.build_auth_headers( auth_headers = self.build_auth_headers(

View File

@ -81,7 +81,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
if f.check(SynapseError): if f.check(SynapseError):
# mypy doesn't understand that f.check asserts the type. # mypy doesn't understand that f.check asserts the type.
exc = f.value # type: SynapseError # type: ignore exc: SynapseError = f.value # type: ignore
error_code = exc.code error_code = exc.code
error_dict = exc.error_dict() error_dict = exc.error_dict()
@ -132,7 +132,7 @@ def return_html_error(
""" """
if f.check(CodeMessageException): if f.check(CodeMessageException):
# mypy doesn't understand that f.check asserts the type. # mypy doesn't understand that f.check asserts the type.
cme = f.value # type: CodeMessageException # type: ignore cme: CodeMessageException = f.value # type: ignore
code = cme.code code = cme.code
msg = cme.msg msg = cme.msg
@ -404,7 +404,7 @@ class JsonResource(DirectServeJsonResource):
key word arguments to pass to the callback key word arguments to pass to the callback
""" """
# At this point the path must be bytes. # At this point the path must be bytes.
request_path_bytes = request.path # type: bytes # type: ignore request_path_bytes: bytes = request.path # type: ignore
request_path = request_path_bytes.decode("ascii") request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests. # Treat HEAD requests as GET requests.
request_method = request.method request_method = request.method
@ -557,7 +557,7 @@ class _ByteProducer:
request: Request, request: Request,
iterator: Iterator[bytes], iterator: Iterator[bytes],
): ):
self._request = request # type: Optional[Request] self._request: Optional[Request] = request
self._iterator = iterator self._iterator = iterator
self._paused = False self._paused = False

View File

@ -205,7 +205,7 @@ def parse_string(
parameter is present, must be one of a list of allowed values and parameter is present, must be one of a list of allowed values and
is not one of those allowed values. is not one of those allowed values.
""" """
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore args: Dict[bytes, List[bytes]] = request.args # type: ignore
return parse_string_from_args( return parse_string_from_args(
args, args,
name, name,

View File

@ -64,16 +64,16 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, max_request_body_size=1024, **kw): def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self._max_request_body_size = max_request_body_size self._max_request_body_size = max_request_body_size
self.site = channel.site # type: SynapseSite self.site: SynapseSite = channel.site
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.start_time = 0.0 self.start_time = 0.0
# The requester, if authenticated. For federation requests this is the # The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object. # server name, for client requests this is the Requester object.
self._requester = None # type: Optional[Union[Requester, str]] self._requester: Optional[Union[Requester, str]] = None
# we can't yet create the logcontext, as we don't know the method. # we can't yet create the logcontext, as we don't know the method.
self.logcontext = None # type: Optional[LoggingContext] self.logcontext: Optional[LoggingContext] = None
global _next_request_seq global _next_request_seq
self.request_seq = _next_request_seq self.request_seq = _next_request_seq
@ -152,7 +152,7 @@ class SynapseRequest(Request):
Returns: Returns:
The redacted URI as a string. The redacted URI as a string.
""" """
uri = self.uri # type: Union[bytes, str] uri: Union[bytes, str] = self.uri
if isinstance(uri, bytes): if isinstance(uri, bytes):
uri = uri.decode("ascii", errors="replace") uri = uri.decode("ascii", errors="replace")
return redact_uri(uri) return redact_uri(uri)
@ -167,7 +167,7 @@ class SynapseRequest(Request):
Returns: Returns:
The request method as a string. The request method as a string.
""" """
method = self.method # type: Union[bytes, str] method: Union[bytes, str] = self.method
if isinstance(method, bytes): if isinstance(method, bytes):
return self.method.decode("ascii") return self.method.decode("ascii")
return method return method
@ -434,8 +434,8 @@ class XForwardedForRequest(SynapseRequest):
""" """
# the client IP and ssl flag, as extracted from the headers. # the client IP and ssl flag, as extracted from the headers.
_forwarded_for = None # type: Optional[_XForwardedForAddress] _forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https = False # type: bool _forwarded_https: bool = False
def requestReceived(self, command, path, version): def requestReceived(self, command, path, version):
# this method is called by the Channel once the full request has been # this method is called by the Channel once the full request has been

View File

@ -110,9 +110,9 @@ class RemoteHandler(logging.Handler):
self.port = port self.port = port
self.maximum_buffer = maximum_buffer self.maximum_buffer = maximum_buffer
self._buffer = deque() # type: Deque[logging.LogRecord] self._buffer: Deque[logging.LogRecord] = deque()
self._connection_waiter = None # type: Optional[Deferred] self._connection_waiter: Optional[Deferred] = None
self._producer = None # type: Optional[LogProducer] self._producer: Optional[LogProducer] = None
# Connect without DNS lookups if it's a direct IP. # Connect without DNS lookups if it's a direct IP.
if _reactor is None: if _reactor is None:
@ -123,9 +123,9 @@ class RemoteHandler(logging.Handler):
try: try:
ip = ip_address(self.host) ip = ip_address(self.host)
if isinstance(ip, IPv4Address): if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint( endpoint: IStreamClientEndpoint = TCP4ClientEndpoint(
_reactor, self.host, self.port _reactor, self.host, self.port
) # type: IStreamClientEndpoint )
elif isinstance(ip, IPv6Address): elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else: else:
@ -165,7 +165,7 @@ class RemoteHandler(logging.Handler):
def writer(result: Protocol) -> None: def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more # Force recognising transport as a Connection and not the more
# generic ITransport. # generic ITransport.
transport = result.transport # type: Connection # type: ignore transport: Connection = result.transport # type: ignore
# We have a connection. If we already have a producer, and its # We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing. # transport is the same, just trigger a resumeProducing.
@ -188,7 +188,7 @@ class RemoteHandler(logging.Handler):
self._producer.resumeProducing() self._producer.resumeProducing()
self._connection_waiter = None self._connection_waiter = None
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred deferred: Deferred = self._service.whenConnected(failAfterFailures=1)
deferred.addCallbacks(writer, fail) deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred self._connection_waiter = deferred

View File

@ -63,7 +63,7 @@ def parse_drain_configs(
DrainType.CONSOLE_JSON, DrainType.CONSOLE_JSON,
DrainType.FILE_JSON, DrainType.FILE_JSON,
): ):
formatter = "json" # type: Optional[str] formatter: Optional[str] = "json"
elif logging_type in ( elif logging_type in (
DrainType.CONSOLE_JSON_TERSE, DrainType.CONSOLE_JSON_TERSE,
DrainType.NETWORK_JSON_TERSE, DrainType.NETWORK_JSON_TERSE,

View File

@ -113,13 +113,13 @@ class ContextResourceUsage:
self.reset() self.reset()
else: else:
# FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
self.ru_utime = copy_from.ru_utime # type: float self.ru_utime: float = copy_from.ru_utime
self.ru_stime = copy_from.ru_stime # type: float self.ru_stime: float = copy_from.ru_stime
self.db_txn_count = copy_from.db_txn_count # type: int self.db_txn_count: int = copy_from.db_txn_count
self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float self.db_txn_duration_sec: float = copy_from.db_txn_duration_sec
self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float self.db_sched_duration_sec: float = copy_from.db_sched_duration_sec
self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int self.evt_db_fetch_count: int = copy_from.evt_db_fetch_count
def copy(self) -> "ContextResourceUsage": def copy(self) -> "ContextResourceUsage":
return ContextResourceUsage(copy_from=self) return ContextResourceUsage(copy_from=self)
@ -289,12 +289,12 @@ class LoggingContext:
# The thread resource usage when the logcontext became active. None # The thread resource usage when the logcontext became active. None
# if the context is not currently active. # if the context is not currently active.
self.usage_start = None # type: Optional[resource._RUsage] self.usage_start: Optional[resource._RUsage] = None
self.main_thread = get_thread_id() self.main_thread = get_thread_id()
self.request = None self.request = None
self.tag = "" self.tag = ""
self.scope = None # type: Optional[_LogContextScope] self.scope: Optional["_LogContextScope"] = None
# keep track of whether we have hit the __exit__ block for this context # keep track of whether we have hit the __exit__ block for this context
# (suggesting that the the thing that created the context thinks it should # (suggesting that the the thing that created the context thinks it should

View File

@ -251,7 +251,7 @@ try:
except Exception: except Exception:
logger.exception("Failed to report span") logger.exception("Failed to report span")
RustReporter = _WrappedRustReporter # type: Optional[Type[_WrappedRustReporter]] RustReporter: Optional[Type[_WrappedRustReporter]] = _WrappedRustReporter
except ImportError: except ImportError:
RustReporter = None RustReporter = None
@ -286,7 +286,7 @@ class SynapseBaggage:
# Block everything by default # Block everything by default
# A regex which matches the server_names to expose traces for. # A regex which matches the server_names to expose traces for.
# None means 'block everything'. # None means 'block everything'.
_homeserver_whitelist = None # type: Optional[Pattern[str]] _homeserver_whitelist: Optional[Pattern[str]] = None
# Util methods # Util methods
@ -662,7 +662,7 @@ def inject_header_dict(
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} # type: Dict[str, str] carrier: Dict[str, str] = {}
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -704,7 +704,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination): if destination and not whitelisted_homeserver(destination):
return {} return {}
carrier = {} # type: Dict[str, str] carrier: Dict[str, str] = {}
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier
) )
@ -718,7 +718,7 @@ def active_span_context_as_string():
Returns: Returns:
The active span context encoded as a string. The active span context encoded as a string.
""" """
carrier = {} # type: Dict[str, str] carrier: Dict[str, str] = {}
if opentracing: if opentracing:
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span.context, opentracing.Format.TEXT_MAP, carrier

View File

@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics" METRICS_PREFIX = "/_synapse/metrics"
running_on_pypy = platform.python_implementation() == "PyPy" running_on_pypy = platform.python_implementation() == "PyPy"
all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge]] all_gauges: "Dict[str, Union[LaterGauge, InFlightGauge]]" = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat") HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
@ -130,7 +130,7 @@ class InFlightGauge:
) )
# Counts number of in flight blocks for a given set of label values # Counts number of in flight blocks for a given set of label values
self._registrations = {} # type: Dict self._registrations: Dict = {}
# Protects access to _registrations # Protects access to _registrations
self._lock = threading.Lock() self._lock = threading.Lock()
@ -248,7 +248,7 @@ class GaugeBucketCollector:
# We initially set this to None. We won't report metrics until # We initially set this to None. We won't report metrics until
# this has been initialised after a successful data update # this has been initialised after a successful data update
self._metric = None # type: Optional[GaugeHistogramMetricFamily] self._metric: Optional[GaugeHistogramMetricFamily] = None
registry.register(self) registry.register(self)

View File

@ -125,7 +125,7 @@ def generate_latest(registry, emit_help=False):
) )
output.append("# TYPE {0} {1}\n".format(mname, mtype)) output.append("# TYPE {0} {1}\n".format(mname, mtype))
om_samples = {} # type: Dict[str, List[str]] om_samples: Dict[str, List[str]] = {}
for s in metric.samples: for s in metric.samples:
for suffix in ["_created", "_gsum", "_gcount"]: for suffix in ["_created", "_gsum", "_gcount"]:
if s.name == metric.name + suffix: if s.name == metric.name + suffix:

View File

@ -93,7 +93,7 @@ _background_process_db_sched_duration = Counter(
# map from description to a counter, so that we can name our logcontexts # map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but # incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.) # it's much simpler to do so than to try to combine them.)
_background_process_counts = {} # type: Dict[str, int] _background_process_counts: Dict[str, int] = {}
# Set of all running background processes that became active active since the # Set of all running background processes that became active active since the
# last time metrics were scraped (i.e. background processes that performed some # last time metrics were scraped (i.e. background processes that performed some
@ -103,7 +103,7 @@ _background_process_counts = {} # type: Dict[str, int]
# background processes stacking up behind a lock or linearizer, where we then # background processes stacking up behind a lock or linearizer, where we then
# only need to iterate over and update metrics for the process that have # only need to iterate over and update metrics for the process that have
# actually been active and can ignore the idle ones. # actually been active and can ignore the idle ones.
_background_processes_active_since_last_scrape = set() # type: Set[_BackgroundProcess] _background_processes_active_since_last_scrape: "Set[_BackgroundProcess]" = set()
# A lock that covers the above set and dict # A lock that covers the above set and dict
_bg_metrics_lock = threading.Lock() _bg_metrics_lock = threading.Lock()

View File

@ -54,7 +54,7 @@ class ModuleApi:
self._state = hs.get_state_handler() self._state = hs.get_state_handler()
# We expose these as properties below in order to attach a helpful docstring. # We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._http_client: SimpleHttpClient = hs.get_simple_http_client()
self._public_room_list_manager = PublicRoomListManager(hs) self._public_room_list_manager = PublicRoomListManager(hs)
self._spam_checker = hs.get_spam_checker() self._spam_checker = hs.get_spam_checker()

View File

@ -203,21 +203,21 @@ class Notifier:
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"): def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream] self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]] self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
self.hs = hs self.hs = hs
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pending_new_room_events = [] # type: List[_PendingRoomEventEntry] self.pending_new_room_events: List[_PendingRoomEventEntry] = []
# Called when there are new things to stream over replication # Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]] self.replication_callbacks: List[Callable[[], None]] = []
# Called when remote servers have come back online after having been # Called when remote servers have come back online after having been
# down. # down.
self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]] self.remote_server_up_callbacks: List[Callable[[str], None]] = []
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@ -237,7 +237,7 @@ class Notifier:
# when rendering the metrics page, which is likely once per minute at # when rendering the metrics page, which is likely once per minute at
# most when scraping it. # most when scraping it.
def count_listeners(): def count_listeners():
all_user_streams = set() # type: Set[_NotifierUserStream] all_user_streams: Set[_NotifierUserStream] = set()
for streams in list(self.room_to_user_streams.values()): for streams in list(self.room_to_user_streams.values()):
all_user_streams |= streams all_user_streams |= streams
@ -329,8 +329,8 @@ class Notifier:
pending = self.pending_new_room_events pending = self.pending_new_room_events
self.pending_new_room_events = [] self.pending_new_room_events = []
users = set() # type: Set[UserID] users: Set[UserID] = set()
rooms = set() # type: Set[str] rooms: Set[str] = set()
for entry in pending: for entry in pending:
if entry.event_pos.persisted_after(max_room_stream_token): if entry.event_pos.persisted_after(max_room_stream_token):
@ -580,7 +580,7 @@ class Notifier:
if after_token == before_token: if after_token == before_token:
return EventStreamResult([], (from_token, from_token)) return EventStreamResult([], (from_token, from_token))
events = [] # type: List[EventBase] events: List[EventBase] = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():

View File

@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context) count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]] actions_by_user: Dict[str, List[Union[dict, str]]] = {}
room_members = await self.store.get_joined_users_from_context(event, context) room_members = await self.store.get_joined_users_from_context(event, context)
@ -207,7 +207,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels event, len(room_members), sender_power_level, power_levels
) )
condition_cache = {} # type: Dict[str, bool] condition_cache: Dict[str, bool] = {}
# If the event is not a state event check if any users ignore the sender. # If the event is not a state event check if any users ignore the sender.
if not event.is_state(): if not event.is_state():

View File

@ -26,10 +26,10 @@ def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, l
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist) ruleslist = copy.deepcopy(ruleslist)
rules = { rules: Dict[str, Dict[str, List[Dict[str, Any]]]] = {
"global": {}, "global": {},
"device": {}, "device": {},
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]] }
rules["global"] = _add_empty_priority_class_arrays(rules["global"]) rules["global"] = _add_empty_priority_class_arrays(rules["global"])

View File

@ -66,8 +66,8 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[IDelayedCall] self.timed_call: Optional[IDelayedCall] = None
self.throttle_params = {} # type: Dict[str, ThrottleParams] self.throttle_params: Dict[str, ThrottleParams] = {}
self._inited = False self._inited = False
self._is_processing = False self._is_processing = False
@ -168,7 +168,7 @@ class EmailPusher(Pusher):
) )
) )
soonest_due_at = None # type: Optional[int] soonest_due_at: Optional[int] = None
if not unprocessed: if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering) await self.save_last_stream_ordering_and_success(self.max_stream_ordering)

View File

@ -71,7 +71,7 @@ class HttpPusher(Pusher):
self.data = pusher_config.data self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since self.failing_since = pusher_config.failing_since
self.timed_call = None # type: Optional[IDelayedCall] self.timed_call: Optional[IDelayedCall] = None
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool() self._pusherpool = hs.get_pusherpool()

View File

@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.app_name = app_name self.app_name = app_name
self.email_subjects = hs.config.email_subjects # type: EmailSubjectConfig self.email_subjects: EmailSubjectConfig = hs.config.email_subjects
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
@ -230,7 +230,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions] [pa["event_id"] for pa in push_actions]
) )
notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]] notifs_by_room: Dict[str, List[Dict[str, Any]]] = {}
for pa in push_actions: for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa) notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@ -356,13 +356,13 @@ class Mailer:
room_name = await calculate_room_name(self.store, room_state_ids, user_id) room_name = await calculate_room_name(self.store, room_state_ids, user_id)
room_vars = { room_vars: Dict[str, Any] = {
"title": room_name, "title": room_name,
"hash": string_ordinal_total(room_id), # See sender avatar hash "hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [], "notifs": [],
"invite": is_invite, "invite": is_invite,
"link": self._make_room_link(room_id), "link": self._make_room_link(room_id),
} # type: Dict[str, Any] }
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
@ -460,9 +460,9 @@ class Mailer:
type_state_key = ("m.room.member", event.sender) type_state_key = ("m.room.member", event.sender)
sender_state_event_id = room_state_ids.get(type_state_key) sender_state_event_id = room_state_ids.get(type_state_key)
if sender_state_event_id: if sender_state_event_id:
sender_state_event = await self.store.get_event( sender_state_event: Optional[EventBase] = await self.store.get_event(
sender_state_event_id sender_state_event_id
) # type: Optional[EventBase] )
else: else:
# Attempt to check the historical state for the room. # Attempt to check the historical state for the room.
historical_state = await self.state_store.get_state_for_event( historical_state = await self.state_store.get_state_for_event(

View File

@ -199,7 +199,7 @@ def name_from_member_event(member_event: EventBase) -> str:
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]: def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
ret = {} # type: Dict[str, Dict[str, str]] ret: Dict[str, Dict[str, str]] = {}
for k, v in state.items(): for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v ret.setdefault(k[0], {})[k[1]] = v
return ret return ret

View File

@ -195,9 +195,9 @@ class PushRuleEvaluatorForEvent:
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache( regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
50000, "regex_push_cache" 50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool], Pattern] )
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:

View File

@ -31,13 +31,13 @@ class PusherFactory:
self.hs = hs self.hs = hs
self.config = hs.config self.config = hs.config
self.pusher_types = { self.pusher_types: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] = {
"http": HttpPusher "http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]] }
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
self.mailers = {} # type: Dict[str, Mailer] self.mailers: Dict[str, Mailer] = {}
self._notif_template_html = hs.config.email_notif_template_html self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text self._notif_template_text = hs.config.email_notif_template_text

View File

@ -87,7 +87,7 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering() self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher # map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Pusher]] self.pushers: Dict[str, Dict[str, Pusher]] = {}
def start(self) -> None: def start(self) -> None:
"""Starts the pushers off in a background process.""" """Starts the pushers off in a background process."""

View File

@ -115,7 +115,7 @@ CONDITIONAL_REQUIREMENTS = {
"cache_memory": ["pympler"], "cache_memory": ["pympler"],
} }
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str] ALL_OPTIONAL_REQUIREMENTS: Set[str] = set()
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items(): for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement. # Exclude systemd as it's a system-based requirement.
@ -193,7 +193,7 @@ def check_requirements(for_feature=None):
if not for_feature: if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be # Check the optional dependencies are up to date. We allow them to not be
# installed. # installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str] OPTS: List[str] = sum(CONDITIONAL_REQUIREMENTS.values(), [])
for dependency in OPTS: for dependency in OPTS:
try: try:

View File

@ -85,17 +85,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
is received. is received.
""" """
NAME = abc.abstractproperty() # type: str # type: ignore NAME: str = abc.abstractproperty() # type: ignore
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore PATH_ARGS: Tuple[str, ...] = abc.abstractproperty() # type: ignore
METHOD = "POST" METHOD = "POST"
CACHE = True CACHE = True
RETRY_ON_TIMEOUT = True RETRY_ON_TIMEOUT = True
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
if self.CACHE: if self.CACHE:
self.response_cache = ResponseCache( self.response_cache: ResponseCache[str] = ResponseCache(
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str] )
# We reserve `instance_name` as a parameter to sending requests, so we # We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name. # assert here that sub classes don't try and use the name.
@ -232,7 +232,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# have a good idea that the request has either succeeded or failed on # have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not. # the master, and so whether we should clean up or not.
while True: while True:
headers = {} # type: Dict[bytes, List[bytes]] headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured. # Add an authorization header, if configured.
if replication_secret: if replication_secret:
headers[b"Authorization"] = [b"Bearer " + replication_secret] headers[b"Authorization"] = [b"Bearer " + replication_secret]

View File

@ -27,7 +27,9 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = MultiWriterIdGenerator( self._cache_id_gen: Optional[
MultiWriterIdGenerator
] = MultiWriterIdGenerator(
db_conn, db_conn,
database, database,
stream_name="caches", stream_name="caches",
@ -41,7 +43,7 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
], ],
sequence_name="cache_invalidation_stream_seq", sequence_name="cache_invalidation_stream_seq",
writers=[], writers=[],
) # type: Optional[MultiWriterIdGenerator] )
else: else:
self._cache_id_gen = None self._cache_id_gen = None

View File

@ -23,9 +23,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen = LruCache( self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
cache_name="client_ip_last_seen", max_size=50000 cache_name="client_ip_last_seen", max_size=50000
) # type: LruCache[tuple, int] )
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())

View File

@ -121,13 +121,13 @@ class ReplicationDataHandler:
self._pusher_pool = hs.get_pusherpool() self._pusher_pool = hs.get_pusherpool()
self._presence_handler = hs.get_presence_handler() self._presence_handler = hs.get_presence_handler()
self.send_handler = None # type: Optional[FederationSenderHandler] self.send_handler: Optional[FederationSenderHandler] = None
if hs.should_send_federation(): if hs.should_send_federation():
self.send_handler = FederationSenderHandler(hs) self.send_handler = FederationSenderHandler(hs)
# Map from stream to list of deferreds waiting for the stream to # Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position. # arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]] self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
async def on_rdata( async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list self, stream_name: str, instance_name: str, token: int, rows: list
@ -173,7 +173,7 @@ class ReplicationDataHandler:
if entities: if entities:
self.notifier.on_new_event("to_device_key", token, users=entities) self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == DeviceListsStream.NAME: elif stream_name == DeviceListsStream.NAME:
all_room_ids = set() # type: Set[str] all_room_ids: Set[str] = set()
for row in rows: for row in rows:
if row.entity.startswith("@"): if row.entity.startswith("@"):
room_ids = await self.store.get_rooms_for_user(row.entity) room_ids = await self.store.get_rooms_for_user(row.entity)
@ -201,7 +201,7 @@ class ReplicationDataHandler:
if row.data.rejected: if row.data.rejected:
continue continue
extra_users = () # type: Tuple[UserID, ...] extra_users: Tuple[UserID, ...] = ()
if row.data.type == EventTypes.Member and row.data.state_key: if row.data.type == EventTypes.Member and row.data.state_key:
extra_users = (UserID.from_string(row.data.state_key),) extra_users = (UserID.from_string(row.data.state_key),)
@ -348,7 +348,7 @@ class FederationSenderHandler:
# Stores the latest position in the federation stream we've gotten up # Stores the latest position in the federation stream we've gotten up
# to. This is always set before we use it. # to. This is always set before we use it.
self.federation_position = None # type: Optional[int] self.federation_position: Optional[int] = None
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

View File

@ -34,7 +34,7 @@ class Command(metaclass=abc.ABCMeta):
A full command line on the wire is constructed from `NAME + " " + to_line()` A full command line on the wire is constructed from `NAME + " " + to_line()`
""" """
NAME = None # type: str NAME: str
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
@ -380,7 +380,7 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP" NAME = "REMOTE_SERVER_UP"
_COMMANDS = ( _COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand, ServerCommand,
RdataCommand, RdataCommand,
PositionCommand, PositionCommand,
@ -393,7 +393,7 @@ _COMMANDS = (
UserIpCommand, UserIpCommand,
RemoteServerUpCommand, RemoteServerUpCommand,
ClearUserSyncsCommand, ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...] )
# Map of command name to command type. # Map of command name to command type.
COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS} COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}

View File

@ -105,12 +105,12 @@ class ReplicationCommandHandler:
hs.get_instance_name() in hs.config.worker.writers.presence hs.get_instance_name() in hs.config.worker.writers.presence
) )
self._streams = { self._streams: Dict[str, Stream] = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values() stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream] }
# List of streams that this instance is the source of # List of streams that this instance is the source of
self._streams_to_replicate = [] # type: List[Stream] self._streams_to_replicate: List[Stream] = []
for stream in self._streams.values(): for stream in self._streams.values():
if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME: if hs.config.redis.redis_enabled and stream.NAME == CachesStream.NAME:
@ -180,14 +180,14 @@ class ReplicationCommandHandler:
# Map of stream name to batched updates. See RdataCommand for info on # Map of stream name to batched updates. See RdataCommand for info on
# how batching works. # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]] self._pending_batches: Dict[str, List[Any]] = {}
# The factory used to create connections. # The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory] self._factory: Optional[ReconnectingClientFactory] = None
# The currently connected connections. (The list of places we need to send # The currently connected connections. (The list of places we need to send
# outgoing replication commands to.) # outgoing replication commands to.)
self._connections = [] # type: List[IReplicationConnection] self._connections: List[IReplicationConnection] = []
LaterGauge( LaterGauge(
"synapse_replication_tcp_resource_total_connections", "synapse_replication_tcp_resource_total_connections",
@ -200,7 +200,7 @@ class ReplicationCommandHandler:
# them in order in a separate background process. # them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_queue # the streams which are currently being processed by _unsafe_process_queue
self._processing_streams = set() # type: Set[str] self._processing_streams: Set[str] = set()
# for each stream, a queue of commands that are awaiting processing, and the # for each stream, a queue of commands that are awaiting processing, and the
# connection that they arrived on. # connection that they arrived on.
@ -210,7 +210,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION # For each connection, the incoming stream names that have received a POSITION
# from that connection. # from that connection.
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]] self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
LaterGauge( LaterGauge(
"synapse_replication_tcp_command_queue", "synapse_replication_tcp_command_queue",

View File

@ -102,7 +102,7 @@ tcp_outbound_commands_counter = Counter(
# A list of all connected protocols. This allows us to send metrics about the # A list of all connected protocols. This allows us to send metrics about the
# connections. # connections.
connected_connections = [] # type: List[BaseReplicationStreamProtocol] connected_connections: "List[BaseReplicationStreamProtocol]" = []
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -146,15 +146,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The transport is going to be an ITCPTransport, but that doesn't have the # The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation. # (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection transport: Connection
delimiter = b"\n" delimiter = b"\n"
# Valid commands we expect to receive # Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str] VALID_INBOUND_COMMANDS: Collection[str] = []
# Valid commands we can send # Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str] VALID_OUTBOUND_COMMANDS: Collection[str] = []
max_line_buffer = 10000 max_line_buffer = 10000
@ -165,7 +165,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec() self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0 self.last_sent_command = 0
# When we requested the connection be closed # When we requested the connection be closed
self.time_we_closed = None # type: Optional[int] self.time_we_closed: Optional[int] = None
self.received_ping = False # Have we received a ping from the other side self.received_ping = False # Have we received a ping from the other side
@ -175,10 +175,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes. self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection # List of pending commands to send once we've established the connection
self.pending_commands = [] # type: List[Command] self.pending_commands: List[Command] = []
# The LoopingCall for sending pings. # The LoopingCall for sending pings.
self._send_ping_loop = None # type: Optional[task.LoopingCall] self._send_ping_loop: Optional[task.LoopingCall] = None
# a logcontext which we use for processing incoming commands. We declare it as a # a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus. # background process so that the CPU stats get reported to prometheus.

View File

@ -57,7 +57,7 @@ class ConstantProperty(Generic[T, V]):
it. it.
""" """
constant = attr.ib() # type: V constant: V = attr.ib()
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V: def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
return self.constant return self.constant
@ -91,9 +91,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
commands. commands.
""" """
synapse_handler = None # type: ReplicationCommandHandler synapse_handler: "ReplicationCommandHandler"
synapse_stream_name = None # type: str synapse_stream_name: str
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol synapse_outbound_redis_connection: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@ -85,9 +85,9 @@ class Stream:
time it was called. time it was called.
""" """
NAME = None # type: str # The name of the stream NAME: str # The name of the stream
# The type of the row. Used by the default impl of parse_row. # The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any ROW_TYPE: Any = None
@classmethod @classmethod
def parse_row(cls, row: StreamRow): def parse_row(cls, row: StreamRow):
@ -283,9 +283,7 @@ class PresenceStream(Stream):
assert isinstance(presence_handler, PresenceHandler) assert isinstance(presence_handler, PresenceHandler)
update_function = ( update_function: UpdateFunction = presence_handler.get_all_presence_updates
presence_handler.get_all_presence_updates
) # type: UpdateFunction
else: else:
# Query presence writer process # Query presence writer process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
@ -334,9 +332,9 @@ class TypingStream(Stream):
if writer_instance == hs.get_instance_name(): if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler # On the writer, query the typing handler
typing_writer_handler = hs.get_typing_writer_handler() typing_writer_handler = hs.get_typing_writer_handler()
update_function = ( update_function: Callable[
typing_writer_handler.get_all_typing_updates [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] ] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token current_token_function = typing_writer_handler.get_current_token
else: else:
# Query the typing writer process # Query the typing writer process

View File

@ -65,7 +65,7 @@ class BaseEventsStreamRow:
""" """
# Unique string that ids the type. Must be overridden in sub classes. # Unique string that ids the type. Must be overridden in sub classes.
TypeId = None # type: str TypeId: str
@classmethod @classmethod
def from_data(cls, data): def from_data(cls, data):
@ -103,10 +103,10 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional event_id = attr.ib() # str, optional
_EventRows = ( _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
) # type: Tuple[Type[BaseEventsStreamRow], ...] )
TypeToRow = {Row.TypeId: Row for Row in _EventRows} TypeToRow = {Row.TypeId: Row for Row in _EventRows}
@ -157,9 +157,9 @@ class EventsStream(Stream):
# now we fetch up to that many rows from the events table # now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows( event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count instance_name, from_token, current_token, target_row_count
) # type: List[Tuple] )
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so # we rely on get_all_new_forward_event_rows strictly honouring the limit, so
# that we know it is safe to just take upper_limit = event_rows[-1][0]. # that we know it is safe to just take upper_limit = event_rows[-1][0].
@ -172,7 +172,7 @@ class EventsStream(Stream):
if len(event_rows) == target_row_count: if len(event_rows) == target_row_count:
limited = True limited = True
upper_limit = event_rows[-1][0] # type: int upper_limit: int = event_rows[-1][0]
else: else:
limited = False limited = False
upper_limit = current_token upper_limit = current_token
@ -191,30 +191,30 @@ class EventsStream(Stream):
# finally, fetch the ex-outliers rows. We assume there are few enough of these # finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit. # not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows( ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit instance_name, from_token, upper_limit
) # type: List[Tuple] )
# we now need to turn the raw database rows returned into tuples suitable # we now need to turn the raw database rows returned into tuples suitable
# for the replication protocol (basically, we add an identifier to # for the replication protocol (basically, we add an identifier to
# distinguish the row type). At the same time, we can limit the event_rows # distinguish the row type). At the same time, we can limit the event_rows
# to the max stream_id from state_rows. # to the max stream_id from state_rows.
event_updates = ( event_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest)) (stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in event_rows for (stream_id, *rest) in event_rows
if stream_id <= upper_limit if stream_id <= upper_limit
) # type: Iterable[Tuple[int, Tuple]] )
state_updates = ( state_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest)) (stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows for (stream_id, *rest) in state_rows
) # type: Iterable[Tuple[int, Tuple]] )
ex_outliers_updates = ( ex_outliers_updates: Iterable[Tuple[int, Tuple]] = (
(stream_id, (EventsStreamEventRow.TypeId, rest)) (stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in ex_outliers_rows for (stream_id, *rest) in ex_outliers_rows
) # type: Iterable[Tuple[int, Tuple]] )
# we need to return a sorted list, so merge them together. # we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates)) updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))

View File

@ -51,9 +51,9 @@ class FederationStream(Stream):
current_token = current_token_without_instance( current_token = current_token_without_instance(
federation_sender.get_current_token federation_sender.get_current_token
) )
update_function = ( update_function: Callable[
federation_sender.get_replication_rows [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
) # type: Callable[[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]] ] = federation_sender.get_replication_rows
elif hs.should_send_federation(): elif hs.should_send_federation():
# federation sender: Query master process # federation sender: Query master process

View File

@ -247,15 +247,15 @@ class HomeServer(metaclass=abc.ABCMeta):
# the key we use to sign events and requests # the key we use to sign events and requests
self.signing_key = config.key.signing_key[0] self.signing_key = config.key.signing_key[0]
self.config = config self.config = config
self._listening_services = [] # type: List[twisted.internet.tcp.Port] self._listening_services: List[twisted.internet.tcp.Port] = []
self.start_time = None # type: Optional[int] self.start_time: Optional[int] = None
self._instance_id = random_string(5) self._instance_id = random_string(5)
self._instance_name = config.worker.instance_name self._instance_name = config.worker.instance_name
self.version_string = version_string self.version_string = version_string
self.datastores = None # type: Optional[Databases] self.datastores: Optional[Databases] = None
self._module_web_resources: Dict[str, IResource] = {} self._module_web_resources: Dict[str, IResource] = {}
self._module_web_resources_consumed = False self._module_web_resources_consumed = False

View File

@ -34,7 +34,7 @@ class ConsentServerNotices:
self._server_notices_manager = hs.get_server_notices_manager() self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastore() self._store = hs.get_datastore()
self._users_in_progress = set() # type: Set[str] self._users_in_progress: Set[str] = set()
self._current_consent_version = hs.config.user_consent_version self._current_consent_version = hs.config.user_consent_version
self._server_notice_content = hs.config.user_consent_server_notice_content self._server_notice_content = hs.config.user_consent_server_notice_content

View File

@ -205,7 +205,7 @@ class ResourceLimitsServerNotices:
# The user has yet to join the server notices room # The user has yet to join the server notices room
pass pass
referenced_events = [] # type: List[str] referenced_events: List[str] = []
if pinned_state_event is not None: if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", [])) referenced_events = list(pinned_state_event.content.get("pinned", []))

View File

@ -32,10 +32,12 @@ class ServerNoticesSender(WorkerServerNoticesSender):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._server_notices = ( self._server_notices: Iterable[
Union[ConsentServerNotices, ResourceLimitsServerNotices]
] = (
ConsentServerNotices(hs), ConsentServerNotices(hs),
ResourceLimitsServerNotices(hs), ResourceLimitsServerNotices(hs),
) # type: Iterable[Union[ConsentServerNotices, ResourceLimitsServerNotices]] )
async def on_user_syncing(self, user_id: str) -> None: async def on_user_syncing(self, user_id: str) -> None:
"""Called when the user performs a sync operation. """Called when the user performs a sync operation.

View File

@ -309,9 +309,9 @@ class StateHandler:
if old_state: if old_state:
# if we're given the state before the event, then we use that # if we're given the state before the event, then we use that
state_ids_before_event = { state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state (s.type, s.state_key): s.event_id for s in old_state
} # type: StateMap[str] }
state_group_before_event = None state_group_before_event = None
state_group_before_event_prev_group = None state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None deltas_to_state_group_before_event = None
@ -513,23 +513,25 @@ class StateResolutionHandler:
self.resolve_linearizer = Linearizer(name="state_resolve_lock") self.resolve_linearizer = Linearizer(name="state_resolve_lock")
# dict of set of event_ids -> _StateCacheEntry. # dict of set of event_ids -> _StateCacheEntry.
self._state_cache = ExpiringCache( self._state_cache: ExpiringCache[
FrozenSet[int], _StateCacheEntry
] = ExpiringCache(
cache_name="state_cache", cache_name="state_cache",
clock=self.clock, clock=self.clock,
max_len=100000, max_len=100000,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True, iterable=True,
reset_expiry_on_get=True, reset_expiry_on_get=True,
) # type: ExpiringCache[FrozenSet[int], _StateCacheEntry] )
# #
# stuff for tracking time spent on state-res by room # stuff for tracking time spent on state-res by room
# #
# tracks the amount of work done on state res per room # tracks the amount of work done on state res per room
self._state_res_metrics = defaultdict( self._state_res_metrics: DefaultDict[str, _StateResMetrics] = defaultdict(
_StateResMetrics _StateResMetrics
) # type: DefaultDict[str, _StateResMetrics] )
self.clock.looping_call(self._report_metrics, 120 * 1000) self.clock.looping_call(self._report_metrics, 120 * 1000)
@ -700,9 +702,9 @@ class StateResolutionHandler:
items = self._state_res_metrics.items() items = self._state_res_metrics.items()
# log the N biggest rooms # log the N biggest rooms
biggest = heapq.nlargest( biggest: List[Tuple[str, _StateResMetrics]] = heapq.nlargest(
n_to_log, items, key=lambda i: extract_key(i[1]) n_to_log, items, key=lambda i: extract_key(i[1])
) # type: List[Tuple[str, _StateResMetrics]] )
metrics_logger.debug( metrics_logger.debug(
"%i biggest rooms for state-res by %s: %s", "%i biggest rooms for state-res by %s: %s",
len(biggest), len(biggest),
@ -754,7 +756,7 @@ def _make_state_cache_entry(
# failing that, look for the closest match. # failing that, look for the closest match.
prev_group = None prev_group = None
delta_ids = None # type: Optional[StateMap[str]] delta_ids: Optional[StateMap[str]] = None
for old_group, old_state in state_groups_ids.items(): for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v} n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}

View File

@ -159,7 +159,7 @@ def _seperate(
""" """
state_set_iterator = iter(state_sets) state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator)) unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {} # type: MutableStateMap[Set[str]] conflicted_state: MutableStateMap[Set[str]] = {}
for state_set in state_set_iterator: for state_set in state_set_iterator:
for key, value in state_set.items(): for key, value in state_set.items():

View File

@ -276,7 +276,7 @@ async def _get_auth_chain_difference(
# event IDs if they appear in the `event_map`. This is the intersection of # event IDs if they appear in the `event_map`. This is the intersection of
# the event's auth chain with the events in the `event_map` *plus* their # the event's auth chain with the events in the `event_map` *plus* their
# auth event IDs. # auth event IDs.
events_to_auth_chain = {} # type: Dict[str, Set[str]] events_to_auth_chain: Dict[str, Set[str]] = {}
for event in event_map.values(): for event in event_map.values():
chain = {event.event_id} chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain events_to_auth_chain[event.event_id] = chain
@ -301,17 +301,17 @@ async def _get_auth_chain_difference(
# ((type, state_key)->event_id) mappings; and (b) we have stripped out # ((type, state_key)->event_id) mappings; and (b) we have stripped out
# unpersisted events and replaced them with the persisted events in # unpersisted events and replaced them with the persisted events in
# their auth chain. # their auth chain.
state_sets_ids = [] # type: List[Set[str]] state_sets_ids: List[Set[str]] = []
# For each state set, the unpersisted event IDs reachable (by their auth # For each state set, the unpersisted event IDs reachable (by their auth
# chain) from the events in that set. # chain) from the events in that set.
unpersisted_set_ids = [] # type: List[Set[str]] unpersisted_set_ids: List[Set[str]] = []
for state_set in state_sets: for state_set in state_sets:
set_ids = set() # type: Set[str] set_ids: Set[str] = set()
state_sets_ids.append(set_ids) state_sets_ids.append(set_ids)
unpersisted_ids = set() # type: Set[str] unpersisted_ids: Set[str] = set()
unpersisted_set_ids.append(unpersisted_ids) unpersisted_set_ids.append(unpersisted_ids)
for event_id in state_set.values(): for event_id in state_set.values():
@ -334,7 +334,7 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:]) union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:]) intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
difference_from_event_map = union - intersection # type: Collection[str] difference_from_event_map: Collection[str] = union - intersection
else: else:
difference_from_event_map = () difference_from_event_map = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets] state_sets_ids = [set(state_set.values()) for state_set in state_sets]
@ -458,7 +458,7 @@ async def _reverse_topological_power_sort(
The sorted list The sorted list
""" """
graph = {} # type: Dict[str, Set[str]] graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1): for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph( await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff graph, room_id, event_id, event_map, state_res_store, auth_diff
@ -657,7 +657,7 @@ async def _get_mainline_depth_for_event(
""" """
room_id = event.room_id room_id = event.room_id
tmp_event = event # type: Optional[EventBase] tmp_event: Optional[EventBase] = event
# We do an iterative search, replacing `event with the power level in its # We do an iterative search, replacing `event with the power level in its
# auth events (if any) # auth events (if any)
@ -767,7 +767,7 @@ def lexicographical_topological_sort(
# outgoing edges, c.f. # outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph outdegree_map = graph
reverse_graph = {} # type: Dict[str, Set[str]] reverse_graph: Dict[str, Set[str]] = {}
# Lists of nodes with zero out degree. Is actually a tuple of # Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing # `(key(node), node)` so that sorting does the right thing

View File

@ -32,9 +32,9 @@ class EventSources:
} }
def __init__(self, hs): def __init__(self, hs):
self.sources = { self.sources: Dict[str, Any] = {
name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
} # type: Dict[str, Any] }
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken: def get_current_token(self) -> StreamToken:

View File

@ -210,7 +210,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name 'domain' : The domain part of the name
""" """
SIGIL = abc.abstractproperty() # type: str # type: ignore SIGIL: str = abc.abstractproperty() # type: ignore
localpart = attr.ib(type=str) localpart = attr.ib(type=str)
domain = attr.ib(type=str) domain = attr.ib(type=str)
@ -304,7 +304,7 @@ class GroupID(DomainSpecificString):
@classmethod @classmethod
def from_string(cls: Type[DS], s: str) -> DS: def from_string(cls: Type[DS], s: str) -> DS:
group_id = super().from_string(s) # type: DS # type: ignore group_id: DS = super().from_string(s) # type: ignore
if not group_id.localpart: if not group_id.localpart:
raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
@ -600,7 +600,7 @@ class StreamToken:
groups_key = attr.ib(type=int) groups_key = attr.ib(type=int)
_SEPARATOR = "_" _SEPARATOR = "_"
START = None # type: StreamToken START: "StreamToken"
@classmethod @classmethod
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":

View File

@ -90,7 +90,7 @@ async def filter_events_for_client(
AccountDataTypes.IGNORED_USER_LIST, user_id AccountDataTypes.IGNORED_USER_LIST, user_id
) )
ignore_list = frozenset() # type: FrozenSet[str] ignore_list: FrozenSet[str] = frozenset()
if ignore_dict_content: if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {}) ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict): if isinstance(ignored_users_dict, dict):