Misc typing fixes for tests, part 1 of N (#11323)

* Annotate HomeserverTestCase.servlets
* Correct annotation of federation_auth_origin
* Use AnyStr custom_headers instead of a Union

This allows (str, str) and (bytes, bytes).
This disallows (str, bytes) and (bytes, str)

* DomainSpecificString.SIGIL is a ClassVar
This commit is contained in:
David Robertson 2021-11-12 15:50:54 +00:00 committed by GitHub
parent 95547e5300
commit 4c96ce396e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 53 additions and 29 deletions

1
changelog.d/11323.misc Normal file
View File

@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Callable
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin from synapse.rest import admin
@ -62,6 +62,8 @@ from synapse.rest.client import (
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""Matrix Client API REST resource. """Matrix Client API REST resource.

View File

@ -19,6 +19,7 @@ from collections import namedtuple
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
ClassVar,
Dict, Dict,
Mapping, Mapping,
MutableMapping, MutableMapping,
@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name 'domain' : The domain part of the name
""" """
SIGIL: str = abc.abstractproperty() # type: ignore SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
localpart = attr.ib(type=str) localpart = attr.ib(type=str)
domain = attr.ib(type=str) domain = attr.ib(type=str)

View File

@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.protocol import Protocol from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.client import ReplicationDataHandler
@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`. unlike `BaseStreamTestCase`.
""" """
servlets: List[Callable[[HomeServer, JsonResource], None]] = []
def setUp(self): def setUp(self):
super().setUp() super().setUp()

View File

@ -19,7 +19,17 @@ import json
import re import re
import time import time
import urllib.parse import urllib.parse
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union from typing import (
Any,
AnyStr,
Dict,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
)
from unittest.mock import patch from unittest.mock import patch
import attr import attr
@ -53,9 +63,7 @@ class RestHelper:
tok: Optional[str] = None, tok: Optional[str] = None,
expect_code: int = 200, expect_code: int = 200,
extra_content: Optional[Dict] = None, extra_content: Optional[Dict] = None,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
) -> str: ) -> str:
""" """
Create a room. Create a room.
@ -227,9 +235,7 @@ class RestHelper:
txn_id=None, txn_id=None,
tok=None, tok=None,
expect_code=200, expect_code=200,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
if body is None: if body is None:
body = "body_text_here" body = "body_text_here"
@ -418,7 +424,7 @@ class RestHelper:
path, path,
content=image_data, content=image_data,
access_token=tok, access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))], custom_headers=[("Content-Length", str(image_length))],
) )
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % ( assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (

View File

@ -16,7 +16,16 @@ import json
import logging import logging
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union from typing import (
AnyStr,
Callable,
Dict,
Iterable,
MutableMapping,
Optional,
Tuple,
Union,
)
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -222,9 +231,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1", client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """

View File

@ -20,7 +20,20 @@ import inspect
import logging import logging
import secrets import secrets
import time import time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from typing import (
Any,
AnyStr,
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from canonicaljson import json from canonicaljson import json
@ -45,6 +58,7 @@ from synapse.logging.context import (
current_context, current_context,
set_current_context, set_current_context,
) )
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
config dict. config dict.
Attributes: Attributes:
servlets (list[function]): List of servlet registration function. servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked. user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id. in user_id.
""" """
servlets = []
hijack_auth = True hijack_auth = True
needs_threadpool = False needs_threadpool = False
servlets: ClassVar[List[RegisterServletsFunc]] = []
def __init__(self, methodName, *args, **kwargs): def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs) super().__init__(methodName, *args, **kwargs)
@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
access_token: Optional[str] = None, access_token: Optional[str] = None,
request: Type[T] = SynapseRequest, request: Type[T] = SynapseRequest,
shorthand: bool = True, shorthand: bool = True,
federation_auth_origin: str = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
await_result: bool = True, await_result: bool = True,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
client_ip: str = "127.0.0.1", client_ip: str = "127.0.0.1",
) -> FakeChannel: ) -> FakeChannel:
""" """
@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
a dict. a dict.
shorthand: Whether to try and be helpful and prefix the given URL shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it. with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name. Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header. 'Content-Type': 'application/x-www-form-urlencoded' header.
@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
username, username,
password, password,
device_id=None, device_id=None,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
""" """
Log in a user, and get an access token. Requires the Login API be Log in a user, and get an access token. Requires the Login API be