Add some type hints to the tests.handlers module. (#12207)

This commit is contained in:
Patrick Cloke 2022-03-11 07:07:15 -05:00 committed by GitHub
parent bc9dff1d95
commit e10a2fe0c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 74 additions and 61 deletions

View File

@ -1 +1 @@
Add type hints to `tests/rest/client`. Add type hints to tests files.

View File

@ -1 +1 @@
Add type hints to `tests/rest`. Add type hints to tests files.

1
changelog.d/12207.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to tests files.

View File

@ -15,11 +15,15 @@
from collections import Counter from collections import Counter
from unittest.mock import Mock from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
import synapse.storage import synapse.storage
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.rest.client import knock, login, room from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -32,7 +36,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
knock.register_servlets, knock.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password") self.user1 = self.register_user("user1", "password")
@ -41,7 +45,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.user2 = self.register_user("user2", "password") self.user2 = self.register_user("user2", "password")
self.token2 = self.login("user2", "password") self.token2 = self.login("user2", "password")
def test_single_public_joined_room(self): def test_single_public_joined_room(self) -> None:
"""Test that we write *all* events for a public room""" """Test that we write *all* events for a public room"""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.user1, tok=self.token1, is_public=True self.user1, tok=self.token1, is_public=True
@ -74,7 +78,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_private_joined_room(self): def test_single_private_joined_room(self) -> None:
"""Tests that we correctly write state when we can't see all events in """Tests that we correctly write state when we can't see all events in
a room. a room.
""" """
@ -112,7 +116,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1) self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
def test_single_left_room(self): def test_single_left_room(self) -> None:
"""Tests that we don't see events in the room after we leave.""" """Tests that we don't see events in the room after we leave."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1) room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1)
@ -144,7 +148,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2) self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
def test_single_left_rejoined_private_room(self): def test_single_left_rejoined_private_room(self) -> None:
"""Tests that see the correct events in private rooms when we """Tests that see the correct events in private rooms when we
repeatedly join and leave. repeatedly join and leave.
""" """
@ -185,7 +189,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1) self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3) self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
def test_invite(self): def test_invite(self) -> None:
"""Tests that pending invites get handled correctly.""" """Tests that pending invites get handled correctly."""
room_id = self.helper.create_room_as(self.user1, tok=self.token1) room_id = self.helper.create_room_as(self.user1, tok=self.token1)
self.helper.send(room_id, body="Hello!", tok=self.token1) self.helper.send(room_id, body="Hello!", tok=self.token1)
@ -204,7 +208,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[1].content["membership"], "invite") self.assertEqual(args[1].content["membership"], "invite")
self.assertTrue(args[2]) # Assert there is at least one bit of state self.assertTrue(args[2]) # Assert there is at least one bit of state
def test_knock(self): def test_knock(self) -> None:
"""Tests that knock get handled correctly.""" """Tests that knock get handled correctly."""
# create a knockable v7 room # create a knockable v7 room
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(

View File

@ -15,8 +15,12 @@ from unittest.mock import Mock
import pymacaroons import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin from synapse.rest import admin
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@ -27,7 +31,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
admin.register_servlets, admin.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.macaroon_generator = hs.get_macaroon_generator() self.macaroon_generator = hs.get_macaroon_generator()
@ -42,23 +46,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass") self.user1 = self.register_user("a_user", "pass")
def test_macaroon_caveats(self): def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user") token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat: str) -> bool:
return caveat == "gen = 1" return caveat == "gen = 1"
def verify_user(caveat): def verify_user(caveat: str) -> bool:
return caveat == "user_id = a_user" return caveat == "user_id = a_user"
def verify_type(caveat): def verify_type(caveat: str) -> bool:
return caveat == "type = access" return caveat == "type = access"
def verify_nonce(caveat): def verify_nonce(caveat: str) -> bool:
return caveat.startswith("nonce =") return caveat.startswith("nonce =")
def verify_guest(caveat): def verify_guest(caveat: str) -> bool:
return caveat == "guest = true" return caveat == "guest = true"
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -69,7 +73,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest) v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key) v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000 self.user1, "", duration_in_ms=5000
) )
@ -84,7 +88,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError, AuthError,
) )
def test_short_term_login_token_gives_auth_provider(self): def test_short_term_login_token_gives_auth_provider(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, auth_provider_id="my_idp" self.user1, auth_provider_id="my_idp"
) )
@ -92,7 +96,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.user1, res.user_id) self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id) self.assertEqual("my_idp", res.auth_provider_id)
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000 self.user1, "", duration_in_ms=5000
) )
@ -112,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
AuthError, AuthError,
) )
def test_mau_limits_disabled(self): def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
self.get_success( self.get_success(
@ -127,7 +131,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
) )
def test_mau_limits_exceeded_large(self): def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
@ -150,7 +154,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError, ResourceLimitError,
) )
def test_mau_limits_parity(self): def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch. # Ensure we're not at the unix epoch.
self.reactor.advance(1) self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@ -189,7 +193,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
) )
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = Mock(
@ -211,7 +215,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
) )
def _get_macaroon(self): def _get_macaroon(self) -> pymacaroons.Macaroon:
token = self.macaroon_generator.generate_short_term_login_token( token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000 self.user1, "", duration_in_ms=5000
) )

View File

@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
self.token = self.login("user", "pass") self.token = self.login("user", "pass")
def _deactivate_my_account(self): def _deactivate_my_account(self) -> None:
""" """
Deactivates the account `self.user` using `self.token` and asserts Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code. that it returns a 200 success code.

View File

@ -14,9 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.api.errors from typing import Optional
import synapse.handlers.device
import synapse.storage from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -25,28 +30,27 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase): class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# These tests assume that it starts 1000 seconds in. # These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000) self.reactor.advance(1000)
def test_device_is_created_with_invalid_name(self): def test_device_is_created_with_invalid_name(self) -> None:
self.get_failure( self.get_failure(
self.handler.check_device_registered( self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
device_id="foo", device_id="foo",
initial_device_display_name="a" initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
), ),
synapse.api.errors.SynapseError, SynapseError,
) )
def test_device_is_created_if_doesnt_exist(self): def test_device_is_created_if_doesnt_exist(self) -> None:
res = self.get_success( res = self.get_success(
self.handler.check_device_registered( self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
@ -59,7 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self): def test_device_is_preserved_if_exists(self) -> None:
res1 = self.get_success( res1 = self.get_success(
self.handler.check_device_registered( self.handler.check_device_registered(
user_id="@boris:foo", user_id="@boris:foo",
@ -81,7 +85,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco")) dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self): def test_device_id_is_made_up_if_unspecified(self) -> None:
device_id = self.get_success( device_id = self.get_success(
self.handler.check_device_registered( self.handler.check_device_registered(
user_id="@theresa:foo", user_id="@theresa:foo",
@ -93,7 +97,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id)) dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self): def test_get_devices_by_user(self) -> None:
self._record_users() self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1)) res = self.get_success(self.handler.get_devices_by_user(user1))
@ -131,7 +135,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
device_map["abc"], device_map["abc"],
) )
def test_get_device(self): def test_get_device(self) -> None:
self._record_users() self._record_users()
res = self.get_success(self.handler.get_device(user1, "abc")) res = self.get_success(self.handler.get_device(user1, "abc"))
@ -146,21 +150,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res, res,
) )
def test_delete_device(self): def test_delete_device(self) -> None:
self._record_users() self._record_users()
# delete the device # delete the device
self.get_success(self.handler.delete_device(user1, "abc")) self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted # check the device was deleted
self.get_failure( self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a # we'd like to check the access token was invalidated, but that's a
# bit of a PITA. # bit of a PITA.
def test_delete_device_and_device_inbox(self): def test_delete_device_and_device_inbox(self) -> None:
self._record_users() self._record_users()
# add an device_inbox # add an device_inbox
@ -191,7 +193,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
) )
self.assertIsNone(res) self.assertIsNone(res)
def test_update_device(self): def test_update_device(self) -> None:
self._record_users() self._record_users()
update = {"display_name": "new display"} update = {"display_name": "new display"}
@ -200,32 +202,29 @@ class DeviceTestCase(unittest.HomeserverTestCase):
res = self.get_success(self.handler.get_device(user1, "abc")) res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display") self.assertEqual(res["display_name"], "new display")
def test_update_device_too_long_display_name(self): def test_update_device_too_long_display_name(self) -> None:
"""Update a device with a display name that is invalid (too long).""" """Update a device with a display name that is invalid (too long)."""
self._record_users() self._record_users()
# Request to update a device display name with a new value that is longer than allowed. # Request to update a device display name with a new value that is longer than allowed.
update = { update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
"display_name": "a"
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
}
self.get_failure( self.get_failure(
self.handler.update_device(user1, "abc", update), self.handler.update_device(user1, "abc", update),
synapse.api.errors.SynapseError, SynapseError,
) )
# Ensure the display name was not updated. # Ensure the display name was not updated.
res = self.get_success(self.handler.get_device(user1, "abc")) res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "display 2") self.assertEqual(res["display_name"], "display 2")
def test_update_unknown_device(self): def test_update_unknown_device(self) -> None:
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
self.get_failure( self.get_failure(
self.handler.update_device("user_id", "unknown_device_id", update), self.handler.update_device("user_id", "unknown_device_id", update),
synapse.api.errors.NotFoundError, NotFoundError,
) )
def _record_users(self): def _record_users(self) -> None:
# check this works for both devices which have a recorded client_ip, # check this works for both devices which have a recorded client_ip,
# and those which don't. # and those which don't.
self._record_user(user1, "xyz", "display 0") self._record_user(user1, "xyz", "display 0")
@ -238,8 +237,13 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10000) self.reactor.advance(10000)
def _record_user( def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None self,
): user_id: str,
device_id: str,
display_name: str,
access_token: Optional[str] = None,
ip: Optional[str] = None,
) -> None:
device_id = self.get_success( device_id = self.get_success(
self.handler.check_device_registered( self.handler.check_device_registered(
user_id=user_id, user_id=user_id,
@ -248,7 +252,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
) )
) )
if ip is not None: if access_token is not None and ip is not None:
self.get_success( self.get_success(
self.store.insert_client_ip( self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id user_id, access_token, ip, "user_agent", device_id
@ -258,7 +262,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None) hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler() self.registration = hs.get_registration_handler()
@ -266,7 +270,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
return hs return hs
def test_dehydrate_and_rehydrate_device(self): def test_dehydrate_and_rehydrate_device(self) -> None:
user_id = "@boris:dehydration" user_id = "@boris:dehydration"
self.get_success(self.store.register_user(user_id, "foobar")) self.get_success(self.store.register_user(user_id, "foobar"))
@ -303,7 +307,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
device_id="not the right device ID", device_id="not the right device ID",
), ),
synapse.api.errors.NotFoundError, NotFoundError,
) )
# dehydrating the right devices should succeed and change our device ID # dehydrating the right devices should succeed and change our device ID
@ -331,7 +335,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that the device ID that we were initially assigned no longer exists # make sure that the device ID that we were initially assigned no longer exists
self.get_failure( self.get_failure(
self.handler.get_device(user_id, device_id), self.handler.get_device(user_id, device_id),
synapse.api.errors.NotFoundError, NotFoundError,
) )
# make sure that there's no device available for dehydrating now # make sure that there's no device available for dehydrating now