Fixup synapse.rest to pass mypy (#6732)

This commit is contained in:
Erik Johnston 2020-01-20 17:38:21 +00:00 committed by GitHub
parent 74b74462f1
commit b0a66ab83c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 35 deletions

1
changelog.d/6732.misc Normal file
View File

@ -0,0 +1 @@
Fixup `synapse.rest` to pass mypy.

View File

@ -66,3 +66,12 @@ ignore_missing_imports = True
[mypy-sentry_sdk] [mypy-sentry_sdk]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-PIL.*]
ignore_missing_imports = True
[mypy-lxml]
ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True

View File

@ -338,21 +338,22 @@ class UserRegisterServlet(RestServlet):
got_mac = body["mac"] got_mac = body["mac"]
want_mac = hmac.new( want_mac_builder = hmac.new(
key=self.hs.config.registration_shared_secret.encode(), key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1, digestmod=hashlib.sha1,
) )
want_mac.update(nonce.encode("utf8")) want_mac_builder.update(nonce.encode("utf8"))
want_mac.update(b"\x00") want_mac_builder.update(b"\x00")
want_mac.update(username) want_mac_builder.update(username)
want_mac.update(b"\x00") want_mac_builder.update(b"\x00")
want_mac.update(password) want_mac_builder.update(password)
want_mac.update(b"\x00") want_mac_builder.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin") want_mac_builder.update(b"admin" if admin else b"notadmin")
if user_type: if user_type:
want_mac.update(b"\x00") want_mac_builder.update(b"\x00")
want_mac.update(user_type.encode("utf8")) want_mac_builder.update(user_type.encode("utf8"))
want_mac = want_mac.hexdigest()
want_mac = want_mac_builder.hexdigest()
if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
raise SynapseError(403, "HMAC incorrect") raise SynapseError(403, "HMAC incorrect")

View File

@ -514,7 +514,7 @@ class CasTicketServlet(RestServlet):
if user is None: if user is None:
raise Exception("CAS response does not contain user") raise Exception("CAS response does not contain user")
except Exception: except Exception:
logger.error("Error parsing CAS response", exc_info=1) logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success: if not success:
raise LoginError( raise LoginError(

View File

@ -16,6 +16,7 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging import logging
from typing import List, Optional
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
@ -207,7 +208,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
ret = {} ret = {} # type: dict
if event: if event:
set_tag("event_id", event.event_id) set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id} ret = {"event_id": event.event_id}
@ -285,7 +286,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
try: try:
remote_room_hosts = [ remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"] x.decode("ascii") for x in request.args[b"server_name"]
] ] # type: Optional[List[str]]
except Exception: except Exception:
remote_room_hosts = None remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
@ -375,7 +376,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None) server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100)) limit = int(content.get("limit", 100)) # type: Optional[int]
since_token = content.get("since", None) since_token = content.get("since", None)
search_filter = content.get("filter", None) search_filter = content.get("filter", None)
@ -504,11 +505,16 @@ class RoomMessageListRestServlet(RestServlet):
filter_bytes = parse_string(request, b"filter", encoding=None) filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if event_filter.filter_json.get("event_format", "client") == "federation": if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False as_client_event = False
else: else:
event_filter = None event_filter = None
msgs = await self.pagination_handler.get_messages( msgs = await self.pagination_handler.get_messages(
room_id=room_id, room_id=room_id,
requester=requester, requester=requester,
@ -611,7 +617,7 @@ class RoomEventContextServlet(RestServlet):
filter_bytes = parse_string(request, "filter") filter_bytes = parse_string(request, "filter")
if filter_bytes: if filter_bytes:
filter_json = urlparse.unquote(filter_bytes) filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json)) event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else: else:
event_filter = None event_filter = None

View File

@ -21,6 +21,7 @@ from typing import List, Union
from six import string_types from six import string_types
import synapse import synapse
import synapse.api.auth
import synapse.types import synapse.types
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
@ -405,7 +406,7 @@ class RegisterRestServlet(RestServlet):
return ret return ret
elif kind != b"user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind.decode("utf8"),)
) )
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Tuple
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
@ -60,7 +61,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
sender_user_id, message_type, content["messages"] sender_user_id, message_type, content["messages"]
) )
response = (200, {}) response = (200, {}) # type: Tuple[int, dict]
return response return response

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Set
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -103,7 +104,7 @@ class RemoteKey(DirectServeResource):
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
query = {server.decode("ascii"): {}} query = {server.decode("ascii"): {}} # type: dict
elif len(request.postpath) == 2: elif len(request.postpath) == 2:
server, key_id = request.postpath server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts") minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@ -148,7 +149,7 @@ class RemoteKey(DirectServeResource):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
cache_misses = dict() cache_misses = dict() # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items(): for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results] results = [(result["ts_added_ms"], result) for result in results]

View File

@ -18,6 +18,7 @@ import errno
import logging import logging
import os import os
import shutil import shutil
from typing import Dict, Tuple
from six import iteritems from six import iteritems
@ -605,7 +606,7 @@ class MediaRepository(object):
# We deduplicate the thumbnail sizes by ignoring the cropped versions if # We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one. # they have the same dimensions of a scaled one.
thumbnails = {} thumbnails = {} # type: Dict[Tuple[int, int, str], str]
for r_width, r_height, r_method, r_type in requirements: for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop": if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method) thumbnails.setdefault((r_width, r_height, r_type), r_method)

View File

@ -23,6 +23,7 @@ import re
import shutil import shutil
import sys import sys
import traceback import traceback
from typing import Dict, Optional
import six import six
from six import string_types from six import string_types
@ -237,8 +238,8 @@ class PreviewUrlResource(DirectServeResource):
# If we don't find a match, we'll look at the HTTP Content-Type, and # If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8. # if that doesn't exist, we'll fall back to UTF-8.
if not encoding: if not encoding:
match = _content_type_match.match(media_info["media_type"]) content_match = _content_type_match.match(media_info["media_type"])
encoding = match.group(1) if match else "utf-8" encoding = content_match.group(1) if content_match else "utf-8"
og = decode_and_calc_og(body, media_info["uri"], encoding) og = decode_and_calc_og(body, media_info["uri"], encoding)
@ -518,7 +519,7 @@ def _calc_og(tree, media_uri):
# "og:video:height" : "720", # "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {} og = {} # type: Dict[str, Optional[str]]
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib: if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss # if we've got more than 50 tags, someone is taking the piss

View File

@ -296,8 +296,8 @@ class ThumbnailResource(DirectServeResource):
d_h = desired_height d_h = desired_height
if desired_method.lower() == "crop": if desired_method.lower() == "crop":
info_list = [] crop_info_list = []
info_list2 = [] crop_info_list2 = []
for info in thumbnail_infos: for info in thumbnail_infos:
t_w = info["thumbnail_width"] t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"] t_h = info["thumbnail_height"]
@ -309,7 +309,7 @@ class ThumbnailResource(DirectServeResource):
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h: if t_w >= d_w or t_h >= d_h:
info_list.append( crop_info_list.append(
( (
aspect_quality, aspect_quality,
min_quality, min_quality,
@ -320,7 +320,7 @@ class ThumbnailResource(DirectServeResource):
) )
) )
else: else:
info_list2.append( crop_info_list2.append(
( (
aspect_quality, aspect_quality,
min_quality, min_quality,
@ -330,10 +330,10 @@ class ThumbnailResource(DirectServeResource):
info, info,
) )
) )
if info_list: if crop_info_list:
return min(info_list)[-1] return min(crop_info_list2)[-1]
else: else:
return min(info_list2)[-1] return min(crop_info_list2)[-1]
else: else:
info_list = [] info_list = []
info_list2 = [] info_list2 = []

View File

@ -183,8 +183,7 @@ commands = mypy \
synapse/logging/ \ synapse/logging/ \
synapse/module_api \ synapse/module_api \
synapse/replication \ synapse/replication \
synapse/rest/consent \ synapse/rest \
synapse/rest/saml2 \
synapse/spam_checker_api \ synapse/spam_checker_api \
synapse/storage/engines \ synapse/storage/engines \
synapse/streams synapse/streams