Add type hints to tests/rest. (#12208)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Dirk Klimpel 2022-03-11 13:42:22 +01:00 committed by GitHub
parent e10a2fe0c2
commit 32c828d0f7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 129 additions and 85 deletions

View file

@ -16,16 +16,21 @@ import base64
import json
import os
import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
@ -52,7 +57,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["url_preview_enabled"] = True
@ -113,22 +118,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
self.lookups = {}
self.lookups: Dict[str, Any] = {}
class Resolver:
def resolveHostName(
_self,
resolutionReceiver,
hostName,
portNumber=0,
addressTypes=None,
transportSemantics="TCP",
):
resolutionReceiver: IResolutionReceiver,
hostName: str,
portNumber: int = 0,
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IResolutionReceiver:
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
@ -140,9 +145,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
resolutionReceiver.resolutionComplete()
return resolutionReceiver
self.reactor.nameResolver = Resolver()
self.reactor.nameResolver = Resolver() # type: ignore[assignment]
def create_test_resource(self):
def create_test_resource(self) -> MediaRepositoryResource:
return self.hs.get_media_repository_resource()
def _assert_small_png(self, json_body: JsonDict) -> None:
@ -153,7 +158,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(json_body["og:image:type"], "image/png")
self.assertEqual(json_body["matrix:image:size"], 67)
def test_cache_returns_correct_type(self):
def test_cache_returns_correct_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request(
@ -207,7 +212,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_non_ascii_preview_httpequiv(self):
def test_non_ascii_preview_httpequiv(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@ -243,7 +248,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_video_rejected(self):
def test_video_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@ -279,7 +284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_audio_rejected(self):
def test_audio_rejected(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@ -315,7 +320,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_non_ascii_preview_content_type(self):
def test_non_ascii_preview_content_type(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@ -350,7 +355,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self):
def test_overlong_title(self) -> None:
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@ -387,7 +392,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# We should only see the `og:description` field, as `title` is too long and should be stripped out
self.assertCountEqual(["og:description"], res.keys())
def test_ipaddr(self):
def test_ipaddr(self) -> None:
"""
IP addresses can be previewed directly.
"""
@ -417,7 +422,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_blacklisted_ip_specific(self):
def test_blacklisted_ip_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@ -438,7 +443,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ip_range(self):
def test_blacklisted_ip_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@ -457,7 +462,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ip_specific_direct(self):
def test_blacklisted_ip_specific_direct(self) -> None:
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
@ -476,7 +481,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 403)
def test_blacklisted_ip_range_direct(self):
def test_blacklisted_ip_range_direct(self) -> None:
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
@ -493,7 +498,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ip_range_whitelisted_ip(self):
def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
"""
Blacklisted but then subsequently whitelisted IP addresses can be
spidered.
@ -526,7 +531,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_blacklisted_ip_with_external_ip(self):
def test_blacklisted_ip_with_external_ip(self) -> None:
"""
If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected.
@ -549,7 +554,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ipv6_specific(self):
def test_blacklisted_ipv6_specific(self) -> None:
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@ -572,7 +577,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_blacklisted_ipv6_range(self):
def test_blacklisted_ipv6_range(self) -> None:
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@ -591,7 +596,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_OPTIONS(self):
def test_OPTIONS(self) -> None:
"""
OPTIONS returns the OPTIONS.
"""
@ -601,7 +606,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
def test_accept_language_config_option(self):
def test_accept_language_config_option(self) -> None:
"""
Accept-Language header is sent to the remote server
"""
@ -652,7 +657,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
def test_data_url(self):
def test_data_url(self) -> None:
"""
Requesting to preview a data URL is not supported.
"""
@ -675,7 +680,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 500)
def test_inline_data_url(self):
def test_inline_data_url(self) -> None:
"""
An inline image (as a data URL) should be parsed properly.
"""
@ -712,7 +717,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body)
def test_oembed_photo(self):
def test_oembed_photo(self) -> None:
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@ -771,7 +776,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
self._assert_small_png(body)
def test_oembed_rich(self):
def test_oembed_rich(self) -> None:
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@ -817,7 +822,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_oembed_format(self):
def test_oembed_format(self) -> None:
"""Test an oEmbed endpoint which requires the format in the URL."""
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
@ -866,7 +871,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_oembed_autodiscovery(self):
def test_oembed_autodiscovery(self) -> None:
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
1. Request a preview of a URL which is not known to the oEmbed code.
@ -962,7 +967,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
def _download_image(self):
def _download_image(self) -> Tuple[str, str]:
"""Downloads an image into the URL cache.
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
@ -995,7 +1000,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
def test_storage_providers_exclude_files(self):
def test_storage_providers_exclude_files(self) -> None:
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@ -1037,7 +1042,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
def test_storage_providers_exclude_thumbnails(self):
def test_storage_providers_exclude_thumbnails(self) -> None:
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@ -1090,7 +1095,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
def test_cache_expiry(self):
def test_cache_expiry(self) -> None:
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
self.preview_url.clock = MockClock()