Require types in tests.storage. (#14646)

Adds missing type hints to `tests.storage` package
and does not allow untyped definitions.
This commit is contained in:
Patrick Cloke 2022-12-09 12:36:32 -05:00 committed by GitHub
parent 94bc21e69f
commit 3ac412b4e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 489 additions and 341 deletions

View file

@ -13,15 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.http.site import XForwardedForRequest
from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.server import make_request
@ -30,14 +35,10 @@ from tests.unittest import override_config
class ClientIpStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
def prepare(self, hs, reactor, clock):
self.store = self.hs.get_datastores().main
def test_insert_new_client_ip(self):
def test_insert_new_client_ip(self) -> None:
self.reactor.advance(12345678)
user_id = "@user:id"
@ -76,7 +77,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
def test_insert_new_client_ip_none_device_id(self):
def test_insert_new_client_ip_none_device_id(self) -> None:
"""
An insert with a device ID of NULL will not create a new entry, but
update an existing entry in the user_ips table.
@ -148,7 +149,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
def test_get_last_client_ip_by_device(self, after_persisting: bool):
def test_get_last_client_ip_by_device(self, after_persisting: bool) -> None:
"""Test `get_last_client_ip_by_device` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@ -213,7 +214,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
},
)
def test_get_last_client_ip_by_device_combined_data(self):
def test_get_last_client_ip_by_device_combined_data(self) -> None:
"""Test that `get_last_client_ip_by_device` combines persisted and unpersisted
data together correctly
"""
@ -312,7 +313,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@parameterized.expand([(False,), (True,)])
def test_get_user_ip_and_agents(self, after_persisting: bool):
def test_get_user_ip_and_agents(self, after_persisting: bool) -> None:
"""Test `get_user_ip_and_agents` for persisted and unpersisted data"""
self.reactor.advance(12345678)
@ -352,7 +353,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
],
)
def test_get_user_ip_and_agents_combined_data(self):
def test_get_user_ip_and_agents_combined_data(self) -> None:
"""Test that `get_user_ip_and_agents` combines persisted and unpersisted data
together correctly
"""
@ -429,7 +430,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
)
@override_config({"limit_usage_by_mau": False, "max_mau_value": 50})
def test_disabled_monthly_active_user(self):
def test_disabled_monthly_active_user(self) -> None:
user_id = "@user:server"
self.get_success(
self.store.insert_client_ip(
@ -440,7 +441,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_full(self):
def test_adding_monthly_active_user_when_full(self) -> None:
lots_of_users = 100
user_id = "@user:server"
@ -456,7 +457,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertFalse(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_adding_monthly_active_user_when_space(self):
def test_adding_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertFalse(active)
@ -473,7 +474,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertTrue(active)
@override_config({"limit_usage_by_mau": True, "max_mau_value": 50})
def test_updating_monthly_active_user_when_space(self):
def test_updating_monthly_active_user_when_space(self) -> None:
user_id = "@user:server"
self.get_success(self.store.register_user(user_id=user_id, password_hash=None))
@ -491,7 +492,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
active = self.get_success(self.store.user_last_seen_monthly_active(user_id))
self.assertTrue(active)
def test_devices_last_seen_bg_update(self):
def test_devices_last_seen_bg_update(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@ -576,7 +577,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
r,
)
def test_old_user_ips_pruned(self):
def test_old_user_ips_pruned(self) -> None:
# First make sure we have completed all updates.
self.wait_for_background_updates()
@ -639,11 +640,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.assertEqual(result, [])
# But we should still get the correct values for the device
result = self.get_success(
result2 = self.get_success(
self.store.get_last_client_ip_by_device(user_id, device_id)
)
r = result[(user_id, device_id)]
r = result2[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
@ -663,15 +664,11 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver()
return hs
def prepare(self, hs, reactor, clock):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user_id = self.register_user("bob", "abc123", True)
def test_request_with_xforwarded(self):
def test_request_with_xforwarded(self) -> None:
"""
The IP in X-Forwarded-For is entered into the client IPs table.
"""
@ -681,14 +678,19 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
{"request": XForwardedForRequest},
)
def test_request_from_getPeer(self):
def test_request_from_getPeer(self) -> None:
"""
The IP returned by getPeer is entered into the client IPs table, if
there's no X-Forwarded-For header.
"""
self._runtest({}, "127.0.0.1", {})
def _runtest(self, headers, expected_ip, make_request_args):
def _runtest(
self,
headers: Dict[bytes, bytes],
expected_ip: str,
make_request_args: Dict[str, Any],
) -> None:
device_id = "bleb"
access_token = self.login("bob", "abc123", device_id=device_id)