Use inline type hints in handlers/ and rest/. (#10382)

This commit is contained in:
Jonathan de Jong 2021-07-16 19:22:36 +02:00 committed by GitHub
parent 36dc15412d
commit 98aec1cc9d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
43 changed files with 212 additions and 215 deletions

View file

@ -402,9 +402,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
# Get the room ID from the identifier.
try:
remote_room_hosts = [
remote_room_hosts: Optional[List[str]] = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
]
except Exception:
remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id(
@ -659,9 +659,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None

View file

@ -357,7 +357,7 @@ class UserRegisterServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {} # type: Dict[str, int]
self.nonces: Dict[str, int] = {}
self.hs = hs
def _clear_old_nonces(self):

View file

@ -121,7 +121,7 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
sso_flow = {
sso_flow: JsonDict = {
"type": LoginRestServlet.SSO_TYPE,
"identity_providers": [
_get_auth_flow_dict_for_idp(
@ -129,7 +129,7 @@ class LoginRestServlet(RestServlet):
)
for idp in self._sso_handler.get_identity_providers().values()
],
} # type: JsonDict
}
if self._msc2858_enabled:
# backwards-compatibility support for clients which don't
@ -447,7 +447,7 @@ def _get_auth_flow_dict_for_idp(
use_unstable_brands: whether we should use brand identifiers suitable
for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
e: JsonDict = {"id": idp.idp_id, "name": idp.idp_name}
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
@ -561,7 +561,7 @@ class SsoRedirectServlet(RestServlet):
finish_request(request)
return
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
args: Dict[bytes, List[bytes]] = request.args # type: ignore
client_redirect_url = parse_bytes_from_args(args, "redirectUrl", required=True)
sso_url = await self._sso_handler.handle_redirect_request(
request,

View file

@ -783,7 +783,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100)) # type: Optional[int]
limit: Optional[int] = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@ -929,9 +929,7 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@ -1044,9 +1042,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
else:
event_filter = None

View file

@ -59,7 +59,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester, message_type, content["messages"]
)
response = (200, {}) # type: Tuple[int, dict]
response: Tuple[int, dict] = (200, {})
return response

View file

@ -117,7 +117,7 @@ class ConsentResource(DirectServeHtmlResource):
has_consented = False
public_version = username == ""
if not public_version:
args = request.args # type: Dict[bytes, List[bytes]]
args: Dict[bytes, List[bytes]] = request.args
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@ -154,7 +154,7 @@ class ConsentResource(DirectServeHtmlResource):
"""
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
args = request.args # type: Dict[bytes, List[bytes]]
args: Dict[bytes, List[bytes]] = request.args
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)

View file

@ -97,7 +97,7 @@ class RemoteKey(DirectServeJsonResource):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
query = {server.decode("ascii"): {}} # type: dict
query: dict = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@ -141,7 +141,7 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
# Note that the value is unused.
cache_misses = {} # type: Dict[str, Dict[str, int]]
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]

View file

@ -49,7 +49,7 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore
postpath: List[bytes] = request.postpath # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for

View file

@ -78,16 +78,16 @@ class MediaRepository:
Thumbnailer.set_limits(self.max_image_pixels)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.primary_base_path: str = hs.config.media_store_path
self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]
self.recently_accessed_remotes: Set[Tuple[str, str]] = set()
self.recently_accessed_locals: Set[str] = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -711,7 +711,7 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)

View file

@ -191,7 +191,7 @@ class MediaStorage:
for provider in self.storage_providers:
for path in paths:
res = await provider.fetch(path, file_info) # type: Any
res: Any = await provider.fetch(path, file_info)
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
@ -233,7 +233,7 @@ class MediaStorage:
os.makedirs(dirname)
for provider in self.storage_providers:
res = await provider.fetch(path, file_info) # type: Any
res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(

View file

@ -169,12 +169,12 @@ class PreviewUrlResource(DirectServeJsonResource):
# memory cache mapping urls to an ObservableDeferred returning
# JSON-encoded OG metadata
self._cache = ExpiringCache(
self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=ONE_HOUR,
) # type: ExpiringCache[str, ObservableDeferred]
)
if self._worker_run_media_background_jobs:
self._cleaner_loop = self.clock.looping_call(
@ -460,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource):
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
# If this URL can be accessed via oEmbed, use that instead.
url_to_download = url # type: Optional[str]
url_to_download: Optional[str] = url
oembed_url = self._get_oembed_url(url)
if oembed_url:
# The result might be a new URL to download, or it might be HTML content.
@ -788,7 +788,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {} # type: Dict[str, Optional[str]]
og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss

View file

@ -61,11 +61,11 @@ class UploadResource(DirectServeJsonResource):
errcode=Codes.TOO_LARGE,
)
args = request.args # type: Dict[bytes, List[bytes]] # type: ignore
args: Dict[bytes, List[bytes]] = request.args # type: ignore
upload_name_bytes = parse_bytes_from_args(args, "filename")
if upload_name_bytes:
try:
upload_name = upload_name_bytes.decode("utf8") # type: Optional[str]
upload_name: Optional[str] = upload_name_bytes.decode("utf8")
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
@ -89,7 +89,7 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion
try:
content = request.content # type: IO # type: ignore
content: IO = request.content # type: ignore
content_uri = await self.media_repo.create_content(
media_type, upload_name, content, content_length, requester.user
)

View file

@ -118,9 +118,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
use_display_name = parse_boolean(request, "use_display_name", default=False)
try:
emails_to_use = [
emails_to_use: List[str] = [
val.decode("utf-8") for val in request.args.get(b"use_email", [])
] # type: List[str]
]
except ValueError:
raise SynapseError(400, "Query parameter use_email must be utf-8")
except SynapseError as e: