Add type hints to synapse/tests/rest/admin (#11501)

This commit is contained in:
Dirk Klimpel 2021-12-03 14:57:13 +01:00 committed by GitHub
parent 8cd68b8102
commit e5f426cd54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 257 additions and 228 deletions

1
changelog.d/11501.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse/tests/rest/admin`.

View File

@ -86,9 +86,6 @@ exclude = (?x)
|tests/push/test_presentable_names.py |tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py |tests/push/test_push_rule_evaluator.py
|tests/rest/admin/test_admin.py |tests/rest/admin/test_admin.py
|tests/rest/admin/test_device.py
|tests/rest/admin/test_media.py
|tests/rest/admin/test_server_notice.py
|tests/rest/admin/test_user.py |tests/rest/admin/test_user.py
|tests/rest/admin/test_username_available.py |tests/rest/admin/test_username_available.py
|tests/rest/client/test_account.py |tests/rest/client/test_account.py

View File

@ -16,11 +16,14 @@ from typing import Collection
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.background_updates import BackgroundUpdater
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -31,7 +34,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -44,7 +47,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
("POST", "/_synapse/admin/v1/background_updates/start_job"), ("POST", "/_synapse/admin/v1/background_updates/start_job"),
] ]
) )
def test_requester_is_no_admin(self, method: str, url: str): def test_requester_is_no_admin(self, method: str, url: str) -> None:
""" """
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
""" """
@ -62,7 +65,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
""" """
If parameters are invalid, an error is returned. If parameters are invalid, an error is returned.
""" """
@ -90,7 +93,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def _register_bg_update(self): def _register_bg_update(self) -> None:
"Adds a bg update but doesn't start it" "Adds a bg update but doesn't start it"
async def _fake_update(progress, batch_size) -> int: async def _fake_update(progress, batch_size) -> int:
@ -112,7 +115,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
) )
) )
def test_status_empty(self): def test_status_empty(self) -> None:
"""Test the status API works.""" """Test the status API works."""
channel = self.make_request( channel = self.make_request(
@ -127,7 +130,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
channel.json_body, {"current_updates": {}, "enabled": True} channel.json_body, {"current_updates": {}, "enabled": True}
) )
def test_status_bg_update(self): def test_status_bg_update(self) -> None:
"""Test the status API works with a background update.""" """Test the status API works with a background update."""
# Create a new background update # Create a new background update
@ -162,7 +165,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
}, },
) )
def test_enabled(self): def test_enabled(self) -> None:
"""Test the enabled API works.""" """Test the enabled API works."""
# Create a new background update # Create a new background update
@ -299,7 +302,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
), ),
] ]
) )
def test_start_backround_job(self, job_name: str, updates: Collection[str]): def test_start_backround_job(self, job_name: str, updates: Collection[str]) -> None:
""" """
Test that background updates add to database and be processed. Test that background updates add to database and be processed.
@ -341,7 +344,7 @@ class BackgroundUpdatesTestCase(unittest.HomeserverTestCase):
) )
) )
def test_start_backround_job_twice(self): def test_start_backround_job_twice(self) -> None:
"""Test that add a background update twice return an error.""" """Test that add a background update twice return an error."""
# add job to database # add job to database

View File

@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import urllib.parse import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -31,7 +34,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -48,7 +51,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
) )
@parameterized.expand(["GET", "PUT", "DELETE"]) @parameterized.expand(["GET", "PUT", "DELETE"])
def test_no_auth(self, method: str): def test_no_auth(self, method: str) -> None:
""" """
Try to get a device of an user without authentication. Try to get a device of an user without authentication.
""" """
@ -62,7 +65,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"]) @parameterized.expand(["GET", "PUT", "DELETE"])
def test_requester_is_no_admin(self, method: str): def test_requester_is_no_admin(self, method: str) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -80,7 +83,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"]) @parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_does_not_exist(self, method: str): def test_user_does_not_exist(self, method: str) -> None:
""" """
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
""" """
@ -99,7 +102,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@parameterized.expand(["GET", "PUT", "DELETE"]) @parameterized.expand(["GET", "PUT", "DELETE"])
def test_user_is_not_local(self, method: str): def test_user_is_not_local(self, method: str) -> None:
""" """
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
""" """
@ -117,7 +120,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_device(self): def test_unknown_device(self) -> None:
""" """
Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK. Tests that a lookup for a device that does not exist returns either HTTPStatus.NOT_FOUND or HTTPStatus.OK.
""" """
@ -151,7 +154,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
# Delete unknown device returns status HTTPStatus.OK # Delete unknown device returns status HTTPStatus.OK
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_update_device_too_long_display_name(self): def test_update_device_too_long_display_name(self) -> None:
""" """
Update a device with a display name that is invalid (too long). Update a device with a display name that is invalid (too long).
""" """
@ -189,7 +192,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"]) self.assertEqual("new display", channel.json_body["display_name"])
def test_update_no_display_name(self): def test_update_no_display_name(self) -> None:
""" """
Tests that a update for a device without JSON returns a HTTPStatus.OK Tests that a update for a device without JSON returns a HTTPStatus.OK
""" """
@ -219,7 +222,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new display", channel.json_body["display_name"]) self.assertEqual("new display", channel.json_body["display_name"])
def test_update_display_name(self): def test_update_display_name(self) -> None:
""" """
Tests a normal successful update of display name Tests a normal successful update of display name
""" """
@ -243,7 +246,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual("new displayname", channel.json_body["display_name"]) self.assertEqual("new displayname", channel.json_body["display_name"])
def test_get_device(self): def test_get_device(self) -> None:
""" """
Tests that a normal lookup for a device is successfully Tests that a normal lookup for a device is successfully
""" """
@ -262,7 +265,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertIn("last_seen_ip", channel.json_body) self.assertIn("last_seen_ip", channel.json_body)
self.assertIn("last_seen_ts", channel.json_body) self.assertIn("last_seen_ts", channel.json_body)
def test_delete_device(self): def test_delete_device(self) -> None:
""" """
Tests that a remove of a device is successfully Tests that a remove of a device is successfully
""" """
@ -292,7 +295,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -302,7 +305,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.other_user self.other_user
) )
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to list devices of an user without authentication. Try to list devices of an user without authentication.
""" """
@ -315,7 +318,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -334,7 +337,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self): def test_user_does_not_exist(self) -> None:
""" """
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
""" """
@ -348,7 +351,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self): def test_user_is_not_local(self) -> None:
""" """
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
""" """
@ -363,7 +366,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_user_has_no_devices(self): def test_user_has_no_devices(self) -> None:
""" """
Tests that a normal lookup for devices is successfully Tests that a normal lookup for devices is successfully
if user has no devices if user has no devices
@ -380,7 +383,7 @@ class DevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["devices"])) self.assertEqual(0, len(channel.json_body["devices"]))
def test_get_devices(self): def test_get_devices(self) -> None:
""" """
Tests that a normal lookup for devices is successfully Tests that a normal lookup for devices is successfully
""" """
@ -416,7 +419,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_device_handler() self.handler = hs.get_device_handler()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -428,7 +431,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
self.other_user self.other_user
) )
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to delete devices of an user without authentication. Try to delete devices of an user without authentication.
""" """
@ -441,7 +444,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -460,7 +463,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self): def test_user_does_not_exist(self) -> None:
""" """
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
""" """
@ -474,7 +477,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_user_is_not_local(self): def test_user_is_not_local(self) -> None:
""" """
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
""" """
@ -489,7 +492,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only lookup local users", channel.json_body["error"]) self.assertEqual("Can only lookup local users", channel.json_body["error"])
def test_unknown_devices(self): def test_unknown_devices(self) -> None:
""" """
Tests that a remove of a device that does not exist returns HTTPStatus.OK. Tests that a remove of a device that does not exist returns HTTPStatus.OK.
""" """
@ -503,7 +506,7 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
# Delete unknown devices returns status HTTPStatus.OK # Delete unknown devices returns status HTTPStatus.OK
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_delete_devices(self): def test_delete_devices(self) -> None:
""" """
Tests that a remove of devices is successfully Tests that a remove of devices is successfully
""" """

View File

@ -11,12 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import List
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login, report_event, room from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -29,7 +34,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
report_event.register_servlets, report_event.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -70,7 +75,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/event_reports" self.url = "/_synapse/admin/v1/event_reports"
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to get an event report without authentication. Try to get an event report without authentication.
""" """
@ -83,7 +88,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
""" """
@ -101,7 +106,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self): def test_default_success(self) -> None:
""" """
Testing list of reported events Testing list of reported events
""" """
@ -118,7 +123,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["event_reports"]) self._check_fields(channel.json_body["event_reports"])
def test_limit(self): def test_limit(self) -> None:
""" """
Testing list of reported events with limit Testing list of reported events with limit
""" """
@ -135,7 +140,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["event_reports"]) self._check_fields(channel.json_body["event_reports"])
def test_from(self): def test_from(self) -> None:
""" """
Testing list of reported events with a defined starting point (from) Testing list of reported events with a defined starting point (from)
""" """
@ -152,7 +157,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["event_reports"]) self._check_fields(channel.json_body["event_reports"])
def test_limit_and_from(self): def test_limit_and_from(self) -> None:
""" """
Testing list of reported events with a defined starting point and limit Testing list of reported events with a defined starting point and limit
""" """
@ -169,7 +174,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["event_reports"]), 10) self.assertEqual(len(channel.json_body["event_reports"]), 10)
self._check_fields(channel.json_body["event_reports"]) self._check_fields(channel.json_body["event_reports"])
def test_filter_room(self): def test_filter_room(self) -> None:
""" """
Testing list of reported events with a filter of room Testing list of reported events with a filter of room
""" """
@ -189,7 +194,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
for report in channel.json_body["event_reports"]: for report in channel.json_body["event_reports"]:
self.assertEqual(report["room_id"], self.room_id1) self.assertEqual(report["room_id"], self.room_id1)
def test_filter_user(self): def test_filter_user(self) -> None:
""" """
Testing list of reported events with a filter of user Testing list of reported events with a filter of user
""" """
@ -209,7 +214,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
for report in channel.json_body["event_reports"]: for report in channel.json_body["event_reports"]:
self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["user_id"], self.other_user)
def test_filter_user_and_room(self): def test_filter_user_and_room(self) -> None:
""" """
Testing list of reported events with a filter of user and room Testing list of reported events with a filter of user and room
""" """
@ -230,7 +235,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(report["user_id"], self.other_user) self.assertEqual(report["user_id"], self.other_user)
self.assertEqual(report["room_id"], self.room_id1) self.assertEqual(report["room_id"], self.room_id1)
def test_valid_search_order(self): def test_valid_search_order(self) -> None:
""" """
Testing search order. Order by timestamps. Testing search order. Order by timestamps.
""" """
@ -271,7 +276,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
report += 1 report += 1
def test_invalid_search_order(self): def test_invalid_search_order(self) -> None:
""" """
Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST Testing that a invalid search order returns a HTTPStatus.BAD_REQUEST
""" """
@ -290,7 +295,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"]) self.assertEqual("Unknown direction: bar", channel.json_body["error"])
def test_limit_is_negative(self): def test_limit_is_negative(self) -> None:
""" """
Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST Testing that a negative limit parameter returns a HTTPStatus.BAD_REQUEST
""" """
@ -308,7 +313,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_from_is_negative(self): def test_from_is_negative(self) -> None:
""" """
Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST Testing that a negative from parameter returns a HTTPStatus.BAD_REQUEST
""" """
@ -326,7 +331,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_next_token(self): def test_next_token(self) -> None:
""" """
Testing that `next_token` appears at the right place Testing that `next_token` appears at the right place
""" """
@ -384,7 +389,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["event_reports"]), 1) self.assertEqual(len(channel.json_body["event_reports"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
def _create_event_and_report(self, room_id, user_tok): def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
"""Create and report events""" """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok) resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"] event_id = resp["event_id"]
@ -397,7 +402,9 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _create_event_and_report_without_parameters(self, room_id, user_tok): def _create_event_and_report_without_parameters(
self, room_id: str, user_tok: str
) -> None:
"""Create and report an event, but omit reason and score""" """Create and report an event, but omit reason and score"""
resp = self.helper.send(room_id, tok=user_tok) resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"] event_id = resp["event_id"]
@ -410,7 +417,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _check_fields(self, content): def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in an event report""" """Checks that all attributes are present in an event report"""
for c in content: for c in content:
self.assertIn("id", c) self.assertIn("id", c)
@ -433,7 +440,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
report_event.register_servlets, report_event.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -453,7 +460,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
# first created event report gets `id`=2 # first created event report gets `id`=2
self.url = "/_synapse/admin/v1/event_reports/2" self.url = "/_synapse/admin/v1/event_reports/2"
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to get event report without authentication. Try to get event report without authentication.
""" """
@ -466,7 +473,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
""" """
@ -484,7 +491,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_default_success(self): def test_default_success(self) -> None:
""" """
Testing get a reported event Testing get a reported event
""" """
@ -498,7 +505,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self._check_fields(channel.json_body) self._check_fields(channel.json_body)
def test_invalid_report_id(self): def test_invalid_report_id(self) -> None:
""" """
Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST. Testing that an invalid `report_id` returns a HTTPStatus.BAD_REQUEST.
""" """
@ -557,7 +564,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], channel.json_body["error"],
) )
def test_report_id_not_found(self): def test_report_id_not_found(self) -> None:
""" """
Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND. Testing that a not existing `report_id` returns a HTTPStatus.NOT_FOUND.
""" """
@ -576,7 +583,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
self.assertEqual("Event report not found", channel.json_body["error"]) self.assertEqual("Event report not found", channel.json_body["error"])
def _create_event_and_report(self, room_id, user_tok): def _create_event_and_report(self, room_id: str, user_tok: str) -> None:
"""Create and report events""" """Create and report events"""
resp = self.helper.send(room_id, tok=user_tok) resp = self.helper.send(room_id, tok=user_tok)
event_id = resp["event_id"] event_id = resp["event_id"]
@ -589,7 +596,7 @@ class EventReportDetailTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def _check_fields(self, content): def _check_fields(self, content: JsonDict) -> None:
"""Checks that all attributes are present in a event report""" """Checks that all attributes are present in a event report"""
self.assertIn("id", content) self.assertIn("id", content)
self.assertIn("received_ts", content) self.assertIn("received_ts", content)

View File

@ -12,16 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from http import HTTPStatus from http import HTTPStatus
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login, profile, room from synapse.rest.client import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request from tests.server import FakeSite, make_request
@ -39,7 +42,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -48,7 +51,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to delete media without authentication. Try to delete media without authentication.
""" """
@ -63,7 +66,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -85,7 +88,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_does_not_exist(self): def test_media_does_not_exist(self) -> None:
""" """
Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND Tests that a lookup for a media that does not exist returns a HTTPStatus.NOT_FOUND
""" """
@ -100,7 +103,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_media_is_not_local(self): def test_media_is_not_local(self) -> None:
""" """
Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for a media that is not a local returns a HTTPStatus.BAD_REQUEST
""" """
@ -115,7 +118,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"]) self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_delete_media(self): def test_delete_media(self) -> None:
""" """
Tests that delete a media is successfully Tests that delete a media is successfully
""" """
@ -208,7 +211,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -221,7 +224,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
# Move clock up to somewhat realistic time # Move clock up to somewhat realistic time
self.reactor.advance(1000000000) self.reactor.advance(1000000000)
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to delete media without authentication. Try to delete media without authentication.
""" """
@ -235,7 +238,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -255,7 +258,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_media_is_not_local(self): def test_media_is_not_local(self) -> None:
""" """
Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for media that is not local returns a HTTPStatus.BAD_REQUEST
""" """
@ -270,7 +273,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual("Can only delete local media", channel.json_body["error"]) self.assertEqual("Can only delete local media", channel.json_body["error"])
def test_missing_parameter(self): def test_missing_parameter(self) -> None:
""" """
If the parameter `before_ts` is missing, an error is returned. If the parameter `before_ts` is missing, an error is returned.
""" """
@ -290,7 +293,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
"Missing integer query parameter 'before_ts'", channel.json_body["error"] "Missing integer query parameter 'before_ts'", channel.json_body["error"]
) )
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
""" """
If parameters are invalid, an error is returned. If parameters are invalid, an error is returned.
""" """
@ -363,7 +366,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], channel.json_body["error"],
) )
def test_delete_media_never_accessed(self): def test_delete_media_never_accessed(self) -> None:
""" """
Tests that media deleted if it is older than `before_ts` and never accessed Tests that media deleted if it is older than `before_ts` and never accessed
`last_access_ts` is `NULL` and `created_ts` < `before_ts` `last_access_ts` is `NULL` and `created_ts` < `before_ts`
@ -394,7 +397,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False) self._access_media(server_and_media_id, False)
def test_keep_media_by_date(self): def test_keep_media_by_date(self) -> None:
""" """
Tests that media is not deleted if it is newer than `before_ts` Tests that media is not deleted if it is newer than `before_ts`
""" """
@ -431,7 +434,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False) self._access_media(server_and_media_id, False)
def test_keep_media_by_size(self): def test_keep_media_by_size(self) -> None:
""" """
Tests that media is not deleted if its size is smaller than or equal Tests that media is not deleted if its size is smaller than or equal
to `size_gt` to `size_gt`
@ -466,7 +469,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False) self._access_media(server_and_media_id, False)
def test_keep_media_by_user_avatar(self): def test_keep_media_by_user_avatar(self) -> None:
""" """
Tests that we do not delete media if is used as a user avatar Tests that we do not delete media if is used as a user avatar
Tests parameter `keep_profiles` Tests parameter `keep_profiles`
@ -510,7 +513,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False) self._access_media(server_and_media_id, False)
def test_keep_media_by_room_avatar(self): def test_keep_media_by_room_avatar(self) -> None:
""" """
Tests that we do not delete media if it is used as a room avatar Tests that we do not delete media if it is used as a room avatar
Tests parameter `keep_profiles` Tests parameter `keep_profiles`
@ -555,7 +558,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self._access_media(server_and_media_id, False) self._access_media(server_and_media_id, False)
def _create_media(self): def _create_media(self) -> str:
""" """
Create a media and return media_id and server_and_media_id Create a media and return media_id and server_and_media_id
""" """
@ -577,7 +580,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
return server_and_media_id return server_and_media_id
def _access_media(self, server_and_media_id, expect_success=True): def _access_media(self, server_and_media_id, expect_success=True) -> None:
""" """
Try to access a media and check the result Try to access a media and check the result
""" """
@ -627,7 +630,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo = hs.get_media_repository_resource() media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -652,7 +655,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/media/%s/%s/%s" self.url = "/_synapse/admin/v1/media/%s/%s/%s"
@parameterized.expand(["quarantine", "unquarantine"]) @parameterized.expand(["quarantine", "unquarantine"])
def test_no_auth(self, action: str): def test_no_auth(self, action: str) -> None:
""" """
Try to protect media without authentication. Try to protect media without authentication.
""" """
@ -671,7 +674,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["quarantine", "unquarantine"]) @parameterized.expand(["quarantine", "unquarantine"])
def test_requester_is_no_admin(self, action: str): def test_requester_is_no_admin(self, action: str) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -691,7 +694,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_quarantine_media(self): def test_quarantine_media(self) -> None:
""" """
Tests that quarantining and remove from quarantine a media is successfully Tests that quarantining and remove from quarantine a media is successfully
""" """
@ -725,7 +728,7 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
media_info = self.get_success(self.store.get_local_media(self.media_id)) media_info = self.get_success(self.store.get_local_media(self.media_id))
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info["quarantined_by"])
def test_quarantine_protected_media(self): def test_quarantine_protected_media(self) -> None:
""" """
Tests that quarantining from protected media fails Tests that quarantining from protected media fails
""" """
@ -760,7 +763,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo = hs.get_media_repository_resource() media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -784,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/media/%s/%s" self.url = "/_synapse/admin/v1/media/%s/%s"
@parameterized.expand(["protect", "unprotect"]) @parameterized.expand(["protect", "unprotect"])
def test_no_auth(self, action: str): def test_no_auth(self, action: str) -> None:
""" """
Try to protect media without authentication. Try to protect media without authentication.
""" """
@ -799,7 +802,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
@parameterized.expand(["protect", "unprotect"]) @parameterized.expand(["protect", "unprotect"])
def test_requester_is_no_admin(self, action: str): def test_requester_is_no_admin(self, action: str) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -819,7 +822,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_protect_media(self): def test_protect_media(self) -> None:
""" """
Tests that protect and unprotect a media is successfully Tests that protect and unprotect a media is successfully
""" """
@ -864,7 +867,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname self.server_name = hs.hostname
@ -874,7 +877,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.url = "/_synapse/admin/v1/purge_media_cache" self.url = "/_synapse/admin/v1/purge_media_cache"
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to delete media without authentication. Try to delete media without authentication.
""" """
@ -888,7 +891,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self): def test_requester_is_not_admin(self) -> None:
""" """
If the user is not a server admin, an error is returned. If the user is not a server admin, an error is returned.
""" """
@ -908,7 +911,7 @@ class PurgeMediaCacheTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
""" """
If parameters are invalid, an error is returned. If parameters are invalid, an error is returned.
""" """

View File

@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
import string import string
from http import HTTPStatus from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -29,7 +32,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -39,7 +42,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/registration_tokens" self.url = "/_synapse/admin/v1/registration_tokens"
def _new_token(self, **kwargs): def _new_token(self, **kwargs) -> str:
"""Helper function to create a token.""" """Helper function to create a token."""
token = kwargs.get( token = kwargs.get(
"token", "token",
@ -61,7 +64,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# CREATION # CREATION
def test_create_no_auth(self): def test_create_no_auth(self) -> None:
"""Try to create a token without authentication.""" """Try to create a token without authentication."""
channel = self.make_request("POST", self.url + "/new", {}) channel = self.make_request("POST", self.url + "/new", {})
self.assertEqual( self.assertEqual(
@ -71,7 +74,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_create_requester_not_admin(self): def test_create_requester_not_admin(self) -> None:
"""Try to create a token while not an admin.""" """Try to create a token while not an admin."""
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -86,7 +89,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_create_using_defaults(self): def test_create_using_defaults(self) -> None:
"""Create a token using all the defaults.""" """Create a token using all the defaults."""
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -102,7 +105,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0) self.assertEqual(channel.json_body["completed"], 0)
def test_create_specifying_fields(self): def test_create_specifying_fields(self) -> None:
"""Create a token specifying the value of all fields.""" """Create a token specifying the value of all fields."""
# As many of the allowed characters as possible with length <= 64 # As many of the allowed characters as possible with length <= 64
token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
@ -126,7 +129,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0) self.assertEqual(channel.json_body["completed"], 0)
def test_create_with_null_value(self): def test_create_with_null_value(self) -> None:
"""Create a token specifying unlimited uses and no expiry.""" """Create a token specifying unlimited uses and no expiry."""
data = { data = {
"uses_allowed": None, "uses_allowed": None,
@ -147,7 +150,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["pending"], 0) self.assertEqual(channel.json_body["pending"], 0)
self.assertEqual(channel.json_body["completed"], 0) self.assertEqual(channel.json_body["completed"], 0)
def test_create_token_too_long(self): def test_create_token_too_long(self) -> None:
"""Check token longer than 64 chars is invalid.""" """Check token longer than 64 chars is invalid."""
data = {"token": "a" * 65} data = {"token": "a" * 65}
@ -165,7 +168,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_invalid_chars(self): def test_create_token_invalid_chars(self) -> None:
"""Check you can't create token with invalid characters.""" """Check you can't create token with invalid characters."""
data = { data = {
"token": "abc/def", "token": "abc/def",
@ -185,7 +188,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_token_already_exists(self): def test_create_token_already_exists(self) -> None:
"""Check you can't create token that already exists.""" """Check you can't create token that already exists."""
data = { data = {
"token": "abcd", "token": "abcd",
@ -208,7 +211,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel2.code, msg=channel2.json_body)
self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_unable_to_generate_token(self): def test_create_unable_to_generate_token(self) -> None:
"""Check right error is raised when server can't generate unique token.""" """Check right error is raised when server can't generate unique token."""
# Create all possible single character tokens # Create all possible single character tokens
tokens = [] tokens = []
@ -239,7 +242,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(500, channel.code, msg=channel.json_body) self.assertEqual(500, channel.code, msg=channel.json_body)
def test_create_uses_allowed(self): def test_create_uses_allowed(self) -> None:
"""Check you can only create a token with good values for uses_allowed.""" """Check you can only create a token with good values for uses_allowed."""
# Should work with 0 (token is invalid from the start) # Should work with 0 (token is invalid from the start)
channel = self.make_request( channel = self.make_request(
@ -279,7 +282,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_expiry_time(self): def test_create_expiry_time(self) -> None:
"""Check you can't create a token with an invalid expiry_time.""" """Check you can't create a token with an invalid expiry_time."""
# Should fail with a time in the past # Should fail with a time in the past
channel = self.make_request( channel = self.make_request(
@ -309,7 +312,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_create_length(self): def test_create_length(self) -> None:
"""Check you can only generate a token with a valid length.""" """Check you can only generate a token with a valid length."""
# Should work with 64 # Should work with 64
channel = self.make_request( channel = self.make_request(
@ -379,7 +382,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# UPDATING # UPDATING
def test_update_no_auth(self): def test_update_no_auth(self) -> None:
"""Try to update a token without authentication.""" """Try to update a token without authentication."""
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -393,7 +396,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_update_requester_not_admin(self): def test_update_requester_not_admin(self) -> None:
"""Try to update a token while not an admin.""" """Try to update a token while not an admin."""
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -408,7 +411,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_update_non_existent(self): def test_update_non_existent(self) -> None:
"""Try to update a token that doesn't exist.""" """Try to update a token that doesn't exist."""
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -424,7 +427,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_update_uses_allowed(self): def test_update_uses_allowed(self) -> None:
"""Test updating just uses_allowed.""" """Test updating just uses_allowed."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -490,7 +493,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_expiry_time(self): def test_update_expiry_time(self) -> None:
"""Test updating just expiry_time.""" """Test updating just expiry_time."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -547,7 +550,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM) self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
def test_update_both(self): def test_update_both(self) -> None:
"""Test updating both uses_allowed and expiry_time.""" """Test updating both uses_allowed and expiry_time."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -569,7 +572,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["uses_allowed"], 1)
self.assertEqual(channel.json_body["expiry_time"], new_expiry_time) self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
def test_update_invalid_type(self): def test_update_invalid_type(self) -> None:
"""Test using invalid types doesn't work.""" """Test using invalid types doesn't work."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -595,7 +598,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# DELETING # DELETING
def test_delete_no_auth(self): def test_delete_no_auth(self) -> None:
"""Try to delete a token without authentication.""" """Try to delete a token without authentication."""
channel = self.make_request( channel = self.make_request(
"DELETE", "DELETE",
@ -609,7 +612,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_delete_requester_not_admin(self): def test_delete_requester_not_admin(self) -> None:
"""Try to delete a token while not an admin.""" """Try to delete a token while not an admin."""
channel = self.make_request( channel = self.make_request(
"DELETE", "DELETE",
@ -624,7 +627,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_delete_non_existent(self): def test_delete_non_existent(self) -> None:
"""Try to delete a token that doesn't exist.""" """Try to delete a token that doesn't exist."""
channel = self.make_request( channel = self.make_request(
"DELETE", "DELETE",
@ -640,7 +643,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_delete(self): def test_delete(self) -> None:
"""Test deleting a token.""" """Test deleting a token."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -656,7 +659,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# GETTING ONE # GETTING ONE
def test_get_no_auth(self): def test_get_no_auth(self) -> None:
"""Try to get a token without authentication.""" """Try to get a token without authentication."""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -670,7 +673,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_get_requester_not_admin(self): def test_get_requester_not_admin(self) -> None:
"""Try to get a token while not an admin.""" """Try to get a token while not an admin."""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -685,7 +688,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_get_non_existent(self): def test_get_non_existent(self) -> None:
"""Try to get a token that doesn't exist.""" """Try to get a token that doesn't exist."""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -701,7 +704,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
def test_get(self): def test_get(self) -> None:
"""Test getting a token.""" """Test getting a token."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -722,7 +725,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
# LISTING # LISTING
def test_list_no_auth(self): def test_list_no_auth(self) -> None:
"""Try to list tokens without authentication.""" """Try to list tokens without authentication."""
channel = self.make_request("GET", self.url, {}) channel = self.make_request("GET", self.url, {})
self.assertEqual( self.assertEqual(
@ -732,7 +735,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_list_requester_not_admin(self): def test_list_requester_not_admin(self) -> None:
"""Try to list tokens while not an admin.""" """Try to list tokens while not an admin."""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -747,7 +750,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_list_all(self): def test_list_all(self) -> None:
"""Test listing all tokens.""" """Test listing all tokens."""
# Create new token using default values # Create new token using default values
token = self._new_token() token = self._new_token()
@ -768,7 +771,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertEqual(token_info["pending"], 0) self.assertEqual(token_info["pending"], 0)
self.assertEqual(token_info["completed"], 0) self.assertEqual(token_info["completed"], 0)
def test_list_invalid_query_parameter(self): def test_list_invalid_query_parameter(self) -> None:
"""Test with `valid` query parameter not `true` or `false`.""" """Test with `valid` query parameter not `true` or `false`."""
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -783,7 +786,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
msg=channel.json_body, msg=channel.json_body,
) )
def _test_list_query_parameter(self, valid: str): def _test_list_query_parameter(self, valid: str) -> None:
"""Helper used to test both valid=true and valid=false.""" """Helper used to test both valid=true and valid=false."""
# Create 2 valid and 2 invalid tokens. # Create 2 valid and 2 invalid tokens.
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
@ -820,10 +823,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
self.assertIn(token_info_1["token"], tokens) self.assertIn(token_info_1["token"], tokens)
self.assertIn(token_info_2["token"], tokens) self.assertIn(token_info_2["token"], tokens)
def test_list_valid(self): def test_list_valid(self) -> None:
"""Test listing just valid tokens.""" """Test listing just valid tokens."""
self._test_list_query_parameter(valid="true") self._test_list_query_parameter(valid="true")
def test_list_invalid(self): def test_list_invalid(self) -> None:
"""Test listing just invalid tokens.""" """Test listing just invalid tokens."""
self._test_list_query_parameter(valid="false") self._test_list_query_parameter(valid="false")

View File

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import urllib.parse import urllib.parse
from http import HTTPStatus from http import HTTPStatus
from typing import List, Optional from typing import List, Optional
@ -19,11 +18,15 @@ from unittest.mock import Mock
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.handlers.pagination import PaginationHandler from synapse.handlers.pagination import PaginationHandler
from synapse.rest.client import directory, events, login, room from synapse.rest.client import directory, events, login, room
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest from tests import unittest
@ -39,7 +42,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
room.register_deprecated_servlets, room.register_deprecated_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
hs.config.consent.user_consent_version = "1" hs.config.consent.user_consent_version = "1"
@ -455,7 +458,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase):
room.register_deprecated_servlets, room.register_deprecated_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
hs.config.consent.user_consent_version = "1" hs.config.consent.user_consent_version = "1"
@ -1062,12 +1065,12 @@ class RoomTestCase(unittest.HomeserverTestCase):
directory.register_servlets, directory.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Create user # Create user
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
def test_list_rooms(self): def test_list_rooms(self) -> None:
"""Test that we can list rooms""" """Test that we can list rooms"""
# Create 3 test rooms # Create 3 test rooms
total_rooms = 3 total_rooms = 3
@ -1131,7 +1134,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# We shouldn't receive a next token here as there's no further rooms to show # We shouldn't receive a next token here as there's no further rooms to show
self.assertNotIn("next_batch", channel.json_body) self.assertNotIn("next_batch", channel.json_body)
def test_list_rooms_pagination(self): def test_list_rooms_pagination(self) -> None:
"""Test that we can get a full list of rooms through pagination""" """Test that we can get a full list of rooms through pagination"""
# Create 5 test rooms # Create 5 test rooms
total_rooms = 5 total_rooms = 5
@ -1213,7 +1216,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
def test_correct_room_attributes(self): def test_correct_room_attributes(self) -> None:
"""Test the correct attributes for a room are returned""" """Test the correct attributes for a room are returned"""
# Create a test room # Create a test room
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1294,7 +1297,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(test_room_name, r["name"]) self.assertEqual(test_room_name, r["name"])
self.assertEqual(test_alias, r["canonical_alias"]) self.assertEqual(test_alias, r["canonical_alias"])
def test_room_list_sort_order(self): def test_room_list_sort_order(self) -> None:
"""Test room list sort ordering. alphabetical name versus number of members, """Test room list sort ordering. alphabetical name versus number of members,
reversing the order, etc. reversing the order, etc.
""" """
@ -1303,7 +1306,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
order_type: str, order_type: str,
expected_room_list: List[str], expected_room_list: List[str],
reverse: bool = False, reverse: bool = False,
): ) -> None:
"""Request the list of rooms in a certain order. Assert that order is what """Request the list of rooms in a certain order. Assert that order is what
we expect we expect
@ -1432,7 +1435,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
_order_test("state_events", [room_id_3, room_id_2, room_id_1]) _order_test("state_events", [room_id_3, room_id_2, room_id_1])
_order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True) _order_test("state_events", [room_id_1, room_id_2, room_id_3], reverse=True)
def test_search_term(self): def test_search_term(self) -> None:
"""Test that searching for a room works correctly""" """Test that searching for a room works correctly"""
# Create two test rooms # Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1461,7 +1464,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
expected_room_id: Optional[str], expected_room_id: Optional[str],
search_term: str, search_term: str,
expected_http_code: int = HTTPStatus.OK, expected_http_code: int = HTTPStatus.OK,
): ) -> None:
"""Search for a room and check that the returned room's id is a match """Search for a room and check that the returned room's id is a match
Args: Args:
@ -1535,7 +1538,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
# Test search local part of alias # Test search local part of alias
_search_test(room_id_1, "alias1") _search_test(room_id_1, "alias1")
def test_search_term_non_ascii(self): def test_search_term_non_ascii(self) -> None:
"""Test that searching for a room with non-ASCII characters works correctly""" """Test that searching for a room with non-ASCII characters works correctly"""
# Create test room # Create test room
@ -1562,7 +1565,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id")) self.assertEqual(room_id, channel.json_body.get("rooms")[0].get("room_id"))
self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name")) self.assertEqual("ж", channel.json_body.get("rooms")[0].get("name"))
def test_single_room(self): def test_single_room(self) -> None:
"""Test that a single room can be requested correctly""" """Test that a single room can be requested correctly"""
# Create two test rooms # Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1613,7 +1616,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(room_id_1, channel.json_body["room_id"]) self.assertEqual(room_id_1, channel.json_body["room_id"])
def test_single_room_devices(self): def test_single_room_devices(self) -> None:
"""Test that `joined_local_devices` can be requested correctly""" """Test that `joined_local_devices` can be requested correctly"""
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1652,7 +1655,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(0, channel.json_body["joined_local_devices"]) self.assertEqual(0, channel.json_body["joined_local_devices"])
def test_room_members(self): def test_room_members(self) -> None:
"""Test that room members can be requested correctly""" """Test that room members can be requested correctly"""
# Create two test rooms # Create two test rooms
room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id_1 = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1700,7 +1703,7 @@ class RoomTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(channel.json_body["total"], 3) self.assertEqual(channel.json_body["total"], 3)
def test_room_state(self): def test_room_state(self) -> None:
"""Test that room state can be requested correctly""" """Test that room state can be requested correctly"""
# Create two test rooms # Create two test rooms
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
@ -1717,7 +1720,9 @@ class RoomTestCase(unittest.HomeserverTestCase):
# the create_room already does the right thing, so no need to verify that we got # the create_room already does the right thing, so no need to verify that we got
# the state events it created. # the state events it created.
def _set_canonical_alias(self, room_id: str, test_alias: str, admin_user_tok: str): def _set_canonical_alias(
self, room_id: str, test_alias: str, admin_user_tok: str
) -> None:
# Create a new alias to this room # Create a new alias to this room
url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),) url = "/_matrix/client/r0/directory/room/%s" % (urllib.parse.quote(test_alias),)
channel = self.make_request( channel = self.make_request(
@ -1752,7 +1757,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -1767,7 +1772,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
) )
self.url = f"/_synapse/admin/v1/join/{self.public_room_id}" self.url = f"/_synapse/admin/v1/join/{self.public_room_id}"
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
""" """
@ -1782,7 +1787,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
""" """
If a parameter is missing, return an error If a parameter is missing, return an error
""" """
@ -1797,7 +1802,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
def test_local_user_does_not_exist(self): def test_local_user_does_not_exist(self) -> None:
""" """
Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND
""" """
@ -1812,7 +1817,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_remote_user(self): def test_remote_user(self) -> None:
""" """
Check that only local user can join rooms. Check that only local user can join rooms.
""" """
@ -1830,7 +1835,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], channel.json_body["error"],
) )
def test_room_does_not_exist(self): def test_room_does_not_exist(self) -> None:
""" """
Check that unknown rooms/server return error HTTPStatus.NOT_FOUND. Check that unknown rooms/server return error HTTPStatus.NOT_FOUND.
""" """
@ -1846,7 +1851,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, msg=channel.json_body)
self.assertEqual("No known servers", channel.json_body["error"]) self.assertEqual("No known servers", channel.json_body["error"])
def test_room_is_not_valid(self): def test_room_is_not_valid(self) -> None:
""" """
Check that invalid room names, return an error HTTPStatus.BAD_REQUEST. Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.
""" """
@ -1865,7 +1870,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], channel.json_body["error"],
) )
def test_join_public_room(self): def test_join_public_room(self) -> None:
""" """
Test joining a local user to a public room with "JoinRules.PUBLIC" Test joining a local user to a public room with "JoinRules.PUBLIC"
""" """
@ -1890,7 +1895,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(self.public_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_not_member(self): def test_join_private_room_if_not_member(self) -> None:
""" """
Test joining a local user to a private room with "JoinRules.INVITE" Test joining a local user to a private room with "JoinRules.INVITE"
when server admin is not member of this room. when server admin is not member of this room.
@ -1910,7 +1915,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_join_private_room_if_member(self): def test_join_private_room_if_member(self) -> None:
""" """
Test joining a local user to a private room with "JoinRules.INVITE", Test joining a local user to a private room with "JoinRules.INVITE",
when server admin is member of this room. when server admin is member of this room.
@ -1961,7 +1966,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_join_private_room_if_owner(self): def test_join_private_room_if_owner(self) -> None:
""" """
Test joining a local user to a private room with "JoinRules.INVITE", Test joining a local user to a private room with "JoinRules.INVITE",
when server admin is owner of this room. when server admin is owner of this room.
@ -1991,7 +1996,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEquals(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0]) self.assertEqual(private_room_id, channel.json_body["joined_rooms"][0])
def test_context_as_non_admin(self): def test_context_as_non_admin(self) -> None:
""" """
Test that, without being admin, one cannot use the context admin API Test that, without being admin, one cannot use the context admin API
""" """
@ -2025,7 +2030,7 @@ class JoinAliasRoomTestCase(unittest.HomeserverTestCase):
self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body) self.assertEquals(HTTPStatus.FORBIDDEN, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_context_as_admin(self): def test_context_as_admin(self) -> None:
""" """
Test that, as admin, we can find the context of an event without having joined the room. Test that, as admin, we can find the context of an event without having joined the room.
""" """
@ -2081,7 +2086,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -2098,7 +2103,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
self.public_room_id self.public_room_id
) )
def test_public_room(self): def test_public_room(self) -> None:
"""Test that getting admin in a public room works.""" """Test that getting admin in a public room works."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True self.creator, tok=self.creator_tok, is_public=True
@ -2123,7 +2128,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok, tok=self.admin_user_tok,
) )
def test_private_room(self): def test_private_room(self) -> None:
"""Test that getting admin in a private room works and we get invited.""" """Test that getting admin in a private room works and we get invited."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.creator, self.creator,
@ -2151,7 +2156,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.admin_user_tok, tok=self.admin_user_tok,
) )
def test_other_user(self): def test_other_user(self) -> None:
"""Test that giving admin in a public room works to a non-admin user works.""" """Test that giving admin in a public room works to a non-admin user works."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True self.creator, tok=self.creator_tok, is_public=True
@ -2176,7 +2181,7 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
tok=self.second_tok, tok=self.second_tok,
) )
def test_not_enough_power(self): def test_not_enough_power(self) -> None:
"""Test that we get a sensible error if there are no local room admins.""" """Test that we get a sensible error if there are no local room admins."""
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(
self.creator, tok=self.creator_tok, is_public=True self.creator, tok=self.creator_tok, is_public=True
@ -2216,7 +2221,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self._store = hs.get_datastore() self._store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -2231,7 +2236,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/rooms/%s/block" self.url = "/_synapse/admin/v1/rooms/%s/block"
@parameterized.expand([("PUT",), ("GET",)]) @parameterized.expand([("PUT",), ("GET",)])
def test_requester_is_no_admin(self, method: str): def test_requester_is_no_admin(self, method: str) -> None:
"""If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.""" """If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned."""
channel = self.make_request( channel = self.make_request(
@ -2245,7 +2250,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@parameterized.expand([("PUT",), ("GET",)]) @parameterized.expand([("PUT",), ("GET",)])
def test_room_is_not_valid(self, method: str): def test_room_is_not_valid(self, method: str) -> None:
"""Check that invalid room names, return an error HTTPStatus.BAD_REQUEST.""" """Check that invalid room names, return an error HTTPStatus.BAD_REQUEST."""
channel = self.make_request( channel = self.make_request(
@ -2261,7 +2266,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
channel.json_body["error"], channel.json_body["error"],
) )
def test_block_is_not_valid(self): def test_block_is_not_valid(self) -> None:
"""If parameter `block` is not valid, return an error.""" """If parameter `block` is not valid, return an error."""
# `block` is not valid # `block` is not valid
@ -2296,7 +2301,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
def test_block_room(self): def test_block_room(self) -> None:
"""Test that block a room is successful.""" """Test that block a room is successful."""
def _request_and_test_block_room(room_id: str) -> None: def _request_and_test_block_room(room_id: str) -> None:
@ -2320,7 +2325,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room # unknown remote room
_request_and_test_block_room("!unknown:remote") _request_and_test_block_room("!unknown:remote")
def test_block_room_twice(self): def test_block_room_twice(self) -> None:
"""Test that block a room that is already blocked is successful.""" """Test that block a room that is already blocked is successful."""
self._is_blocked(self.room_id, expect=False) self._is_blocked(self.room_id, expect=False)
@ -2335,7 +2340,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertTrue(channel.json_body["block"]) self.assertTrue(channel.json_body["block"])
self._is_blocked(self.room_id, expect=True) self._is_blocked(self.room_id, expect=True)
def test_unblock_room(self): def test_unblock_room(self) -> None:
"""Test that unblock a room is successful.""" """Test that unblock a room is successful."""
def _request_and_test_unblock_room(room_id: str) -> None: def _request_and_test_unblock_room(room_id: str) -> None:
@ -2360,7 +2365,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room # unknown remote room
_request_and_test_unblock_room("!unknown:remote") _request_and_test_unblock_room("!unknown:remote")
def test_unblock_room_twice(self): def test_unblock_room_twice(self) -> None:
"""Test that unblock a room that is not blocked is successful.""" """Test that unblock a room that is not blocked is successful."""
self._block_room(self.room_id) self._block_room(self.room_id)
@ -2375,7 +2380,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
self.assertFalse(channel.json_body["block"]) self.assertFalse(channel.json_body["block"])
self._is_blocked(self.room_id, expect=False) self._is_blocked(self.room_id, expect=False)
def test_get_blocked_room(self): def test_get_blocked_room(self) -> None:
"""Test get status of a blocked room""" """Test get status of a blocked room"""
def _request_blocked_room(room_id: str) -> None: def _request_blocked_room(room_id: str) -> None:
@ -2399,7 +2404,7 @@ class BlockRoomTestCase(unittest.HomeserverTestCase):
# unknown remote room # unknown remote room
_request_blocked_room("!unknown:remote") _request_blocked_room("!unknown:remote")
def test_get_unblocked_room(self): def test_get_unblocked_room(self) -> None:
"""Test get status of a unblocked room""" """Test get status of a unblocked room"""
def _request_unblocked_room(room_id: str) -> None: def _request_unblocked_room(room_id: str) -> None:

View File

@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import List from typing import List
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login, room, sync from synapse.rest.client import login, room, sync
from synapse.server import HomeServer
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.unittest import override_config from tests.unittest import override_config
@ -34,7 +37,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
sync.register_servlets, sync.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.room_shutdown_handler = hs.get_room_shutdown_handler() self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
@ -49,7 +52,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/send_server_notice" self.url = "/_synapse/admin/v1/send_server_notice"
def test_no_auth(self): def test_no_auth(self) -> None:
"""Try to send a server notice without authentication.""" """Try to send a server notice without authentication."""
channel = self.make_request("POST", self.url) channel = self.make_request("POST", self.url)
@ -60,7 +63,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
"""If the user is not a server admin, an error is returned.""" """If the user is not a server admin, an error is returned."""
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -76,7 +79,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_does_not_exist(self): def test_user_does_not_exist(self) -> None:
"""Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND""" """Tests that a lookup for a user that does not exist returns a HTTPStatus.NOT_FOUND"""
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -89,7 +92,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_user_is_not_local(self): def test_user_is_not_local(self) -> None:
""" """
Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST Tests that a lookup for a user that is not a local returns a HTTPStatus.BAD_REQUEST
""" """
@ -109,7 +112,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
) )
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
"""If parameters are invalid, an error is returned.""" """If parameters are invalid, an error is returned."""
# no content, no user # no content, no user
@ -157,7 +160,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
self.assertEqual("'msgtype' not in content", channel.json_body["error"]) self.assertEqual("'msgtype' not in content", channel.json_body["error"])
def test_server_notice_disabled(self): def test_server_notice_disabled(self) -> None:
"""Tests that server returns error if server notice is disabled""" """Tests that server returns error if server notice is disabled"""
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -176,7 +179,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
) )
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_send_server_notice(self): def test_send_server_notice(self) -> None:
""" """
Tests that sending two server notices is successfully, Tests that sending two server notices is successfully,
the server uses the same room and do not send messages twice. the server uses the same room and do not send messages twice.
@ -240,7 +243,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[1]["sender"], "@notices:test") self.assertEqual(messages[1]["sender"], "@notices:test")
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_send_server_notice_leave_room(self): def test_send_server_notice_leave_room(self) -> None:
""" """
Tests that sending a server notices is successfully. Tests that sending a server notices is successfully.
The user leaves the room and the second message appears The user leaves the room and the second message appears
@ -324,7 +327,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(first_room_id, second_room_id) self.assertNotEqual(first_room_id, second_room_id)
@override_config({"server_notices": {"system_mxid_localpart": "notices"}}) @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
def test_send_server_notice_delete_room(self): def test_send_server_notice_delete_room(self) -> None:
""" """
Tests that the user get server notice in a new room Tests that the user get server notice in a new room
after the first server notice room was deleted. after the first server notice room was deleted.
@ -414,7 +417,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
def _check_invite_and_join_status( def _check_invite_and_join_status(
self, user_id: str, expected_invites: int, expected_memberships: int self, user_id: str, expected_invites: int, expected_memberships: int
) -> RoomsForUser: ) -> List[RoomsForUser]:
"""Check invite and room membership status of a user. """Check invite and room membership status of a user.
Args Args

View File

@ -12,13 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.rest.client import login from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
@ -30,7 +34,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def prepare(self, reactor, clock, hs): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource() self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -41,7 +45,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/statistics/users/media" self.url = "/_synapse/admin/v1/statistics/users/media"
def test_no_auth(self): def test_no_auth(self) -> None:
""" """
Try to list users without authentication. Try to list users without authentication.
""" """
@ -54,7 +58,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self): def test_requester_is_no_admin(self) -> None:
""" """
If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned. If the user is not a server admin, an error HTTPStatus.FORBIDDEN is returned.
""" """
@ -72,7 +76,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_invalid_parameter(self): def test_invalid_parameter(self) -> None:
""" """
If parameters are invalid, an error is returned. If parameters are invalid, an error is returned.
""" """
@ -188,7 +192,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
) )
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
def test_limit(self): def test_limit(self) -> None:
""" """
Testing list of media with limit Testing list of media with limit
""" """
@ -206,7 +210,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["next_token"], 5) self.assertEqual(channel.json_body["next_token"], 5)
self._check_fields(channel.json_body["users"]) self._check_fields(channel.json_body["users"])
def test_from(self): def test_from(self) -> None:
""" """
Testing list of media with a defined starting point (from) Testing list of media with a defined starting point (from)
""" """
@ -224,7 +228,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"]) self._check_fields(channel.json_body["users"])
def test_limit_and_from(self): def test_limit_and_from(self) -> None:
""" """
Testing list of media with a defined starting point and limit Testing list of media with a defined starting point and limit
""" """
@ -242,7 +246,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 10) self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"]) self._check_fields(channel.json_body["users"])
def test_next_token(self): def test_next_token(self) -> None:
""" """
Testing that `next_token` appears at the right place Testing that `next_token` appears at the right place
""" """
@ -302,7 +306,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(channel.json_body["users"]), 1) self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body) self.assertNotIn("next_token", channel.json_body)
def test_no_media(self): def test_no_media(self) -> None:
""" """
Tests that a normal lookup for statistics is successfully Tests that a normal lookup for statistics is successfully
if users have no media created if users have no media created
@ -318,7 +322,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["total"]) self.assertEqual(0, channel.json_body["total"])
self.assertEqual(0, len(channel.json_body["users"])) self.assertEqual(0, len(channel.json_body["users"]))
def test_order_by(self): def test_order_by(self) -> None:
""" """
Testing order list with parameter `order_by` Testing order list with parameter `order_by`
""" """
@ -396,7 +400,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
"b", "b",
) )
def test_from_until_ts(self): def test_from_until_ts(self) -> None:
""" """
Testing filter by time with parameters `from_ts` and `until_ts` Testing filter by time with parameters `from_ts` and `until_ts`
""" """
@ -448,7 +452,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["users"][0]["media_count"], 6) self.assertEqual(channel.json_body["users"][0]["media_count"], 6)
def test_search_term(self): def test_search_term(self) -> None:
self._create_users_with_media(20, 1) self._create_users_with_media(20, 1)
# check without filter get all users # check without filter get all users
@ -488,7 +492,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.json_body)
self.assertEqual(channel.json_body["total"], 0) self.assertEqual(channel.json_body["total"], 0)
def _create_users_with_media(self, number_users: int, media_per_user: int): def _create_users_with_media(self, number_users: int, media_per_user: int) -> None:
""" """
Create a number of users with a number of media Create a number of users with a number of media
Args: Args:
@ -500,7 +504,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
user_tok = self.login("foo_user_%s" % i, "pass") user_tok = self.login("foo_user_%s" % i, "pass")
self._create_media(user_tok, media_per_user) self._create_media(user_tok, media_per_user)
def _create_media(self, user_token: str, number_media: int): def _create_media(self, user_token: str, number_media: int) -> None:
""" """
Create a number of media for a specific user Create a number of media for a specific user
Args: Args:
@ -514,7 +518,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK upload_resource, SMALL_PNG, tok=user_token, expect_code=HTTPStatus.OK
) )
def _check_fields(self, content: List[Dict[str, Any]]): def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in content """Checks that all attributes are present in content
Args: Args:
content: List that is checked for content content: List that is checked for content
@ -527,7 +531,7 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
def _order_test( def _order_test(
self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None self, order_type: str, expected_user_list: List[str], dir: Optional[str] = None
): ) -> None:
"""Request the list of users in a certain order. Assert that order is what """Request the list of users in a certain order. Assert that order is what
we expect we expect
Args: Args:

View File

@ -127,14 +127,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob", "username": "bob",
"password": "abc123", "password": "abc123",
"admin": True, "admin": True,
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -153,7 +153,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update( want_mac.update(
nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
) )
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
@ -161,7 +161,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"password": "abc123", "password": "abc123",
"admin": True, "admin": True,
"user_type": UserTypes.SUPPORT, "user_type": UserTypes.SUPPORT,
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -177,14 +177,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob", "username": "bob",
"password": "abc123", "password": "abc123",
"admin": True, "admin": True,
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -308,13 +308,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob1", "username": "bob1",
"password": "abc123", "password": "abc123",
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -332,14 +332,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob2", "username": "bob2",
"displayname": None, "displayname": None,
"password": "abc123", "password": "abc123",
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -356,14 +356,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob3", "username": "bob3",
"displayname": "", "displayname": "",
"password": "abc123", "password": "abc123",
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -379,14 +379,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
"username": "bob4", "username": "bob4",
"displayname": "Bob's Name", "displayname": "Bob's Name",
"password": "abc123", "password": "abc123",
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -426,7 +426,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update( want_mac.update(
nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support" nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin\x00support"
) )
want_mac = want_mac.hexdigest() want_mac_str = want_mac.hexdigest()
body = { body = {
"nonce": nonce, "nonce": nonce,
@ -434,7 +434,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
"password": "abc123", "password": "abc123",
"admin": True, "admin": True,
"user_type": UserTypes.SUPPORT, "user_type": UserTypes.SUPPORT,
"mac": want_mac, "mac": want_mac_str,
} }
channel = self.make_request("POST", self.url, body) channel = self.make_request("POST", self.url, body)
@ -870,7 +870,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(expected_user_list, returned_order) self.assertEqual(expected_user_list, returned_order)
self._check_fields(channel.json_body["users"]) self._check_fields(channel.json_body["users"])
def _check_fields(self, content: JsonDict): def _check_fields(self, content: List[JsonDict]):
"""Checks that the expected user attributes are present in content """Checks that the expected user attributes are present in content
Args: Args:
content: List that is checked for content content: List that is checked for content
@ -3235,7 +3235,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
return media_id return media_id
def _check_fields(self, content: JsonDict): def _check_fields(self, content: List[JsonDict]):
"""Checks that the expected user attributes are present in content """Checks that the expected user attributes are present in content
Args: Args:
content: List that is checked for content content: List that is checked for content