Replace make_awaitable with AsyncMock (#16179)

Python 3.8 provides a native AsyncMock, we can replace the
homegrown version we have.
This commit is contained in:
Patrick Cloke 2023-08-24 19:38:46 -04:00 committed by GitHub
parent 5856a8ba42
commit daf11e26ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 508 additions and 604 deletions

1
changelog.d/16179.misc Normal file
View File

@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.

View File

@ -13,7 +13,7 @@
# limitations under the License.
import time
from typing import Any, Dict, List, Optional, cast
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import attr
import canonicaljson
@ -45,7 +45,6 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import logcontext_clean, override_config
@ -291,7 +290,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms`
"""
mock_fetcher = Mock()
mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
mock_fetcher.get_keys = AsyncMock(return_value={})
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_signature_keys(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import AsyncMock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
@ -20,7 +20,6 @@ from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)
d = handler._remote_join(
@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)
d = handler._remote_join(
@ -143,9 +142,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)
# Artificially raise the complexity
@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)
d = handler._remote_join(
@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[assignment]
handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[assignment]
return_value=("", 1)
)
d = handler._remote_join(

View File

@ -1,6 +1,6 @@
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -19,7 +19,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase
@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# This mock is crucial for destination_rooms to be populated.
# TODO: this seems to no longer be the case---tests pass with this mock
# commented out.
state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value={"test", "host2"}
)
# whenever send_transaction is called, record the pdu data

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, FrozenSet, List, Optional, Set
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@ -29,7 +29,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase
@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
self.federation_transport_client.send_transaction = AsyncMock()
hs = self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
return_value=make_awaitable({"test", "host2"})
hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
return_value={"test", "host2"}
)
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_thread(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}
# Create receipts for:
#
@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"]
)
self.federation_transport_client.send_transaction = AsyncMock()
self.federation_transport_client.query_user_devices = AsyncMock()
return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.record_transaction
)
def record_transaction(
async def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
) -> "defer.Deferred[JsonDict]":
) -> JsonDict:
assert json_cb is not None
data = json_cb()
self.edus.extend(data["edus"])
return defer.succeed({})
return {}
def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU"""
@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
self.federation_transport_client.query_user_devices.return_value = {
"stream_id": "1",
"user_id": "@user2:host2",
"devices": [{"device_id": "D1"}],
}
)
)
self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery
"""
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
mock_send_txn.side_effect = AssertionError("fail")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import event_injection, make_awaitable, simple_async_mock
from tests.test_utils import event_injection, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.mock_store = Mock()
self.mock_as_api = Mock()
self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
None
)
self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested_in_event=False),
]
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_as_api.query_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable([])
self.mock_store.get_user_by_id = AsyncMock(return_value=[])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, {})),
make_awaitable((1, {event.event_id: 0})),
self.mock_store.get_all_new_event_ids_stream = AsyncMock(
side_effect=[
(0, {}),
(1, {event.event_id: 0}),
]
self.mock_store.get_events_as_list.side_effect = [
make_awaitable([]),
make_awaitable([event]),
)
self.mock_store.get_events_as_list = AsyncMock(
side_effect=[
[],
[event],
]
)
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
self.mock_store.get_user_by_id = AsyncMock(return_value=None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, {event.event_id: 0})),
self.mock_as_api.query_user.return_value = True
self.mock_store.get_all_new_event_ids_stream = AsyncMock(
side_effect=[
(0, {event.event_id: 0}),
]
self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
)
self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
self.mock_as_api.query_user.return_value = True
self.mock_store.get_all_new_event_ids_stream = AsyncMock(
side_effect=[
(0, [event], {event.event_id: 0}),
]
)
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_room_alias_in_namespace=False),
]
self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_as_api.query_alias = AsyncMock(return_value=True)
self.mock_store.get_app_services.return_value = services
self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
Mock(room_id=room_id, servers=servers)
self.mock_store.get_association_from_room_alias = AsyncMock(
return_value=Mock(room_id=room_id, servers=servers)
)
result = self.successResultOf(
@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
self.mock_as_api.get_3pe_protocol.return_value = None
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
{"x-protocol-data": 42, "instances": []}
)
self.mock_as_api.get_3pe_protocol.return_value = {
"x-protocol-data": 42,
"instances": [],
}
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
)
@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
{"x-protocol-data": 42, "instances": []}
)
self.mock_as_api.get_3pe_protocol.return_value = {
"x-protocol-data": 42,
"instances": [],
}
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
{"x-protocol-data": 42, "instances": []}
)
self.mock_as_api.get_3pe_protocol.return_value = {
"x-protocol-data": 42,
"instances": [],
}
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
579
)
self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = (
make_awaitable(([event], None))
self.event_source.sources.receipt.get_new_events_as = AsyncMock(
return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
580
)
self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
event = Mock(event_id="event_1")
self.event_source.sources.receipt.get_new_events_as.return_value = (
make_awaitable(([event], None))
self.event_source.sources.receipt.get_new_events_as = AsyncMock(
return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
A mock representing the ApplicationService.
"""
service = Mock()
service.is_interested_in_event.return_value = make_awaitable(
is_interested_in_event
)
service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest.mock import Mock
from unittest.mock import AsyncMock
import pymacaroons
@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=self.large_number_of_users
)
self.get_failure(
@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=self.large_number_of_users
)
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1)
@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True
# Set the server to be at the edge of too many users.
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.auth_blocking._max_mau_value)
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=self.auth_blocking._max_mau_value
)
# If not in monthly active cohort
@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.token_login(token))
# If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(self.clock.time_msec())
self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock(
return_value=self.clock.time_msec()
)
self.get_success(
self.auth_handler.create_access_token_for_user_id(
@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=self.small_number_of_users
)
# Ensure does not raise exception
self.get_success(
@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
return_value=self.small_number_of_users
)
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1)

View File

@ -32,7 +32,6 @@ from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
user1 = "@boris:aaa"
@ -41,7 +40,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
self.appservice_api = mock.AsyncMock()
hs = self.setup_test_homeserver(
"server",
application_service_api=self.appservice_api,
@ -375,13 +374,11 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
self.appservice_api.query_keys.return_value = make_awaitable(
{
self.appservice_api.query_keys.return_value = {
"device_keys": {
local_user: {device_2: device_key_2b, device_3: device_key_3}
}
}
)
# Request all devices.
res = self.get_success(

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_federation = AsyncMock()
self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
)
self.mock_federation.make_query.return_value = {
"room_id": "!8765qwer:test",
"servers": ["test", "remote"],
}
result = self.get_success(self.handler.get_association(self.remote_room))

View File

@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from typing import Dict, Iterable
from unittest import mock
from parameterized import parameterized
@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.appservice_api = mock.Mock()
self.appservice_api = mock.AsyncMock()
return self.setup_test_homeserver(
federation_client=mock.Mock(), application_service_api=self.appservice_api
)
@ -801,9 +800,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[assignment]
return_value={
"device_keys": {remote_user_id: {}},
"master_keys": {
remote_user_id: {
@ -824,7 +822,6 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
}
)
)
e2e_handler = self.hs.get_e2e_keys_handler()
@ -874,16 +871,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=make_awaitable({"some_room_id"})
)
self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[assignment]
return_value={
"user_id": remote_user_id,
"stream_id": 1,
"devices": [],
@ -896,13 +890,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
"ed25519:" + remote_self_signing_key: remote_self_signing_key
},
},
}
)
)
e2e_handler = self.hs.get_e2e_keys_handler()
@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
mock_get_rooms = mock.patch.object(
self.store,
"get_rooms_for_user",
new_callable=mock.MagicMock,
return_value=make_awaitable(["some_room_id"]),
new_callable=mock.AsyncMock,
return_value=["some_room_id"],
)
mock_get_users = mock.patch.object(
self.store,
"get_users_server_still_shares_room_with",
new_callable=mock.MagicMock,
return_value=make_awaitable({remote_user_id}),
new_callable=mock.AsyncMock,
return_value={remote_user_id},
)
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
new_callable=mock.MagicMock,
return_value=make_awaitable(response_body),
new_callable=mock.AsyncMock,
return_value=response_body,
)
with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response, but only for device 2.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
self.appservice_api.claim_client_keys.return_value = (
{local_user: {device_id_2: otk}},
[(local_user, device_id_1, "alg1", 1)],
)
# we shouldn't have any unused fallback keys yet
@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
)
response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
local_user: {device_id_1: {**as_otk, **as_fallback_key}}
}
self.appservice_api.claim_client_keys.return_value = (response, [])
# Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success(
@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"])
# The appservice will return only the OTK.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_otk}}, [])
self.appservice_api.claim_client_keys.return_value = (
{local_user: {device_id_1: as_otk}},
[],
)
# Claim OTKs, which should return the OTK from the appservice and the
@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"])
# Finally, return only the fallback key from the appservice.
self.appservice_api.claim_client_keys.return_value = make_awaitable(
({local_user: {device_id_1: as_fallback_key}}, [])
self.appservice_api.claim_client_keys.return_value = (
{local_user: {device_id_1: as_fallback_key}},
[],
)
# Claim OTKs, which will return only the fallback key from the database.
@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
self.appservice_api.query_keys.return_value = make_awaitable(
{
self.appservice_api.query_keys.return_value = {
"device_keys": {
local_user: {device_2: device_key_2b, device_3: device_key_3}
}
}
)
# Request all devices.
res = self.get_success(self.handler.query_local_devices({local_user: None}))

View File

@ -14,7 +14,7 @@
import logging
from typing import Collection, Optional, cast
from unittest import TestCase
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
@ -40,7 +40,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
@ -370,7 +370,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
federation_client_backfill_mock = AsyncMock(return_value=[event])
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
@ -631,18 +631,15 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
RoomVersions.V10,
)
mock_make_membership_event = Mock(
return_value=make_awaitable(
(
mock_make_membership_event = AsyncMock(
return_value=(
"example.com",
membership_event,
RoomVersions.V10,
)
)
)
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
mock_send_join = AsyncMock(
return_value=SendJoinResult(
membership_event,
"example.com",
state=[
@ -659,7 +656,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
servers_in_room={"example.com"},
)
)
)
with patch.object(
fed_client, "make_membership_event", mock_make_membership_event

View File

@ -35,7 +35,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
from tests.test_utils import event_injection
class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
self.mock_federation_transport_client.get_event = mock.AsyncMock()
self.mock_federation_transport_client.backfill = mock.AsyncMock()
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
)
@ -198,21 +202,15 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# we expect an outbound request to /state_ids, so stub that out
self.mock_federation_transport_client.get_room_state_ids.return_value = (
make_awaitable(
{
self.mock_federation_transport_client.get_room_state_ids.return_value = {
"pdu_ids": [e.event_id for e in state_at_prev_event],
"auth_chain_ids": [],
}
)
)
# we also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = (
make_awaitable(
StateRequestResponse(auth_events=[], state=state_at_prev_event)
)
)
# we have to bump the clock a bit, to keep the retry logic in
# FederationClient.get_pdu happy
@ -273,8 +271,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room_version = self.get_success(main_store.get_room_version(room_id))
# We expect an outbound request to /state_ids, so stub that out
self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
{
self.mock_federation_transport_client.get_room_state_ids.return_value = {
# Mimic the other server not knowing about the state at all.
# We want to cause Synapse to throw an error (`Unable to get
# missing prev_event $fake_prev_event`) and fail to backfill
@ -282,10 +279,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
"pdu_ids": [],
"auth_chain_ids": [],
}
)
# We also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
StateRequestResponse(
self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
# Mimic the other server not knowing about the state at all.
# We want to cause Synapse to throw an error (`Unable to get
# missing prev_event $fake_prev_event`) and fail to backfill
@ -293,7 +289,6 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
auth_events=[],
state=[],
)
)
pulled_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
@ -545,8 +540,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
self.mock_federation_transport_client.backfill.return_value = make_awaitable(
{
self.mock_federation_transport_client.backfill.return_value = {
"origin": self.OTHER_SERVER_NAME,
"origin_server_ts": 123,
"pdus": [
@ -563,7 +557,6 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
pulled_event.get_pdu_json(),
],
}
)
# Keep track of the count and make sure we don't make any of these requests
event_endpoint_requested_count = 0
@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
self.mock_federation_transport_client.backfill.return_value = make_awaitable(
{
self.mock_federation_transport_client.backfill.return_value = {
"origin": self.OTHER_SERVER_NAME,
"origin_server_ts": 123,
"pdus": [
pulled_event.get_pdu_json(),
],
}
)
# The function under test: try to backfill and process the pulled event
with LoggingContext("test"):

View File

@ -16,7 +16,7 @@
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -32,7 +32,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.unittest import override_config
# Login flows we expect to appear in the list after the normal ones.
@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""UI Auth should delegate correctly to the password provider"""
# log in twice, to get two devices
mock_password_provider.check_password.return_value = make_awaitable(True)
mock_password_provider.check_password = AsyncMock(return_value=True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
mock_password_provider.check_password.return_value = make_awaitable(False)
mock_password_provider.check_password = AsyncMock(return_value=False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
mock_password_provider.check_password.return_value = make_awaitable(True)
mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# have the auth provider deny the request
mock_password_provider.check_password.return_value = make_awaitable(False)
mock_password_provider.check_password = AsyncMock(return_value=False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# allow login via the auth provider
mock_password_provider.check_password.return_value = make_awaitable(True)
mock_password_provider.check_password = AsyncMock(return_value=True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
mock_password_provider.check_password.return_value = make_awaitable(False)
mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:test", None)
)
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:test", channel.json_body["user_id"])
@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:test", None)
)
mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# and finally, succeed
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
mock_password_provider.check_auth = AsyncMock(
return_value=("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self) -> None:
callback = Mock(return_value=make_awaitable(None))
callback = AsyncMock(return_value=None)
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:test", callback)
mock_password_provider.check_auth = AsyncMock(
return_value=("@user:test", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
mock_password_provider.check_auth = AsyncMock(
return_value=("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[assignment]
return_value=0
)
m = Mock(return_value=make_awaitable(False))
m = AsyncMock(return_value=False)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.register_user(username, "password")
@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "foo@test.com", registration)
m = Mock(return_value=make_awaitable(True))
m = AsyncMock(return_value=True)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
channel = self.make_request(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.mock_federation = Mock()
self.mock_federation = AsyncMock()
self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@ -135,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_get_other_name(self) -> None:
self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"}
)
self.mock_federation.make_query.return_value = {"displayname": "Alice"}
displayname = self.get_success(self.handler.get_displayname(self.alice))

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Collection, List, Optional, Tuple
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -38,7 +38,6 @@ from synapse.types import (
)
from synapse.util import Clock
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
self.store.count_monthly_users = AsyncMock( # type: ignore[assignment]
return_value=self.hs.config.server.max_mau_value - 1
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
self.store.get_monthly_active_count = AsyncMock(
return_value=self.hs.config.server.max_mau_value
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self) -> None:
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
self.store.get_monthly_active_count = AsyncMock(
return_value=self.hs.config.server.max_mau_value
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.is_real_user = Mock(return_value=make_awaitable(False))
self.store.is_real_user = AsyncMock(return_value=False)
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[assignment]
self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_directory_handler()
@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[assignment]
self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)

View File

@ -1,4 +1,4 @@
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor
@ -16,7 +16,6 @@ from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
@ -154,18 +153,15 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
None,
)
mock_make_membership_event = Mock(
return_value=make_awaitable(
(
mock_make_membership_event = AsyncMock(
return_value=(
self.OTHER_SERVER_NAME,
join_event,
self.hs.config.server.default_room_version,
)
)
)
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
mock_send_join = AsyncMock(
return_value=SendJoinResult(
join_event,
self.OTHER_SERVER_NAME,
state=[create_event],
@ -174,7 +170,6 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
servers_in_room=frozenset(),
)
)
)
with patch.object(
self.handler.federation_handler.federation_client,

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@ -29,7 +29,6 @@ from synapse.util import Clock
import tests.unittest
import tests.utils
from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
mocked_get_prev_events = patch.object(
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=MagicMock,
return_value=make_awaitable([last_room_creation_event_id]),
new_callable=AsyncMock,
return_value=[last_room_creation_event_id],
)
with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token)

View File

@ -15,7 +15,7 @@
import json
from typing import Dict, List, Set
from unittest.mock import ANY, Mock, call
from unittest.mock import ANY, AsyncMock, Mock, call
from netaddr import IPSet
@ -33,7 +33,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import ThreadedMemoryReactorClock
from tests.test_utils import make_awaitable
from tests.unittest import override_config
# Some local users to test with
@ -74,11 +73,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
# we mock out the federation client too
self.mock_federation_client = Mock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
self.mock_federation_client = AsyncMock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = (200, "OK")
self.mock_federation_client.agent = MatrixFederationAgent(
reactor,
tls_client_options_factory=None,
@ -121,20 +120,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
return_value=make_awaitable(None)
self.datastore.get_destination_retry_timings = AsyncMock(return_value=None)
self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[assignment]
return_value=(0, [])
)
self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment]
return_value=make_awaitable((0, []))
self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[assignment]
return_value=None
)
self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
self.datastore.get_received_txn_response = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
self.datastore.get_received_txn_response = AsyncMock( # type: ignore[assignment]
return_value=None
)
self.room_members: List[UserID] = []
@ -173,27 +170,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment]
side_effect=(
self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[assignment]
# we deliberately return a non-None stream pos to avoid
# doing an initial_sync
lambda: make_awaitable(1)
)
return_value=1
)
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment]
self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment]
side_effect=lambda: 0
return_value=0
)
self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment]
side_effect=lambda *args, **kargs: make_awaitable(([], 0))
self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[assignment]
return_value=([], 0)
)
self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment]
side_effect=lambda *args, **kargs: make_awaitable(None)
self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[assignment]
return_value=None
)
self.datastore.set_received_txn_response = Mock( # type: ignore[assignment]
side_effect=lambda *args, **kwargs: make_awaitable(None)
self.datastore.set_received_txn_response = AsyncMock( # type: ignore[assignment]
return_value=None
)
def test_started_typing_local(self) -> None:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import quote
from twisted.test.proto_helpers import MemoryReactor
@ -30,7 +30,7 @@ from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
from tests.test_utils import event_injection, make_awaitable
from tests.test_utils import event_injection
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@ -471,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None)
)
mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
mock_remove_from_user_dir = AsyncMock(return_value=None)
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):

View File

@ -14,8 +14,8 @@
import base64
import logging
import os
from typing import Any, Awaitable, Callable, Generator, List, Optional, cast
from unittest.mock import Mock, patch
from typing import Generator, List, Optional, cast
from unittest.mock import AsyncMock, patch
import treq
from netaddr import IPSet
@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.federation.srv_resolver import Server
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import (
WELL_KNOWN_MAX_SIZE,
WellKnownResolver,
@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
# Once Async Mocks or lambdas are supported this can go away.
def generate_resolve_service(
result: List[Server],
) -> Callable[[Any], Awaitable[List[Server]]]:
async def resolve_service(_: Any) -> List[Server]:
return result
return resolve_service
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.mock_resolver = Mock()
self.mock_resolver = AsyncMock(spec=SrvResolver)
config_dict = default_config("test", parse=False)
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
@ -636,7 +626,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv1"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar")
@ -722,7 +712,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -776,7 +766,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test the behaviour when the .well-known delegates elsewhere"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@ -840,7 +830,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["target-server"] = "1::f"
@ -930,7 +920,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -986,7 +976,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
# the config left to the default, which will not trust it (since the
# presented cert is signed by a test CA)
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
@ -1037,9 +1027,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)]
)
self.mock_resolver.resolve_service.return_value = [
Server(host=b"srvtarget", port=8443)
]
self.reactor.lookups["srvtarget"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")
@ -1094,9 +1084,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"srvtarget", port=8443)]
)
self.mock_resolver.resolve_service.return_value = [
Server(host=b"srvtarget", port=8443)
]
self._handle_well_known_connection(
client_factory,
@ -1137,7 +1127,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""test the behaviour when the server name has idna chars in"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.mock_resolver.resolve_service.return_value = []
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
@ -1201,9 +1191,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""test the behaviour when the target of a SRV record has idna chars"""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com
)
self.mock_resolver.resolve_service.return_value = [
Server(host=b"xn--trget-3qa.com", port=8443)
] # târget.com
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(
@ -1407,12 +1397,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
"""Test that other SRV results are tried if the first one fails."""
self.agent = self._make_agent()
self.mock_resolver.resolve_service.side_effect = generate_resolve_service(
[
self.mock_resolver.resolve_service.return_value = [
Server(host=b"target.com", port=8443),
Server(host=b"target.com", port=8444),
]
)
self.reactor.lookups["target.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from netaddr import IPSet
@ -26,7 +26,6 @@ from synapse.types import UserID, create_requester
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import get_clock
from tests.test_utils import make_awaitable
logger = logging.getLogger(__name__)
@ -62,7 +61,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new event.
"""
mock_client = Mock(spec=["put_json"])
mock_client.put_json.return_value = make_awaitable({})
mock_client.put_json = AsyncMock(return_value={})
mock_client.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
@ -93,7 +92,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new events.
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
mock_client1.put_json = AsyncMock(return_value={})
mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
@ -108,7 +107,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
mock_client2.put_json = AsyncMock(return_value={})
mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
@ -162,7 +161,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
new typing EDUs.
"""
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
mock_client1.put_json = AsyncMock(return_value={})
mock_client1.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",
@ -177,7 +176,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
)
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
mock_client2.put_json = AsyncMock(return_value={})
mock_client2.agent = self.matrix_federation_agent
self.make_worker_hs(
"synapse.app.generic_worker",

View File

@ -18,7 +18,7 @@ import os
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized, parameterized_class
@ -45,7 +45,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG, make_awaitable
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
@ -419,8 +419,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastores().main
# Set monthly active users to the limit
store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
store.get_monthly_active_count = AsyncMock(
return_value=self.hs.config.server.max_mau_value
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@ -1834,8 +1834,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
self.store.get_monthly_active_count = AsyncMock(
return_value=self.hs.config.server.max_mau_value
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@ -1871,8 +1871,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
handler = self.hs.get_registration_handler()
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
self.store.get_monthly_active_count = AsyncMock(
return_value=self.hs.config.server.max_mau_value
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit

View File

@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import AsyncMock
from synapse.rest import admin
from synapse.rest.client import account_data, login, room
from tests import unittest
from tests.test_utils import make_awaitable
class AccountDataTestCase(unittest.HomeserverTestCase):
@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Tests that the on_account_data_updated module callback is called correctly when
a user's account data changes.
"""
mocked_callback = Mock(return_value=make_awaitable(None))
mocked_callback = AsyncMock(return_value=None)
self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
mocked_callback
)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -23,7 +23,6 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase):
@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.presence_handler = Mock(spec=PresenceHandler)
self.presence_handler.set_state.return_value = make_awaitable(None)
self.presence_handler.set_state = AsyncMock(return_value=None)
hs = self.setup_test_homeserver(
"red",

View File

@ -15,7 +15,7 @@
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor
@ -28,7 +28,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config
@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
# Disable the validation to pretend this came over federation.
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None),
new_callable=AsyncMock,
return_value=None,
):
# Generate a various relations from a different room.
self.get_success(
@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# not an event the Client-Server API will allow..
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None),
new_callable=AsyncMock,
return_value=None,
):
# Create a sub-thread off the thread, which is not allowed.
self._send_relation(

View File

@ -20,7 +20,7 @@
import json
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from unittest.mock import Mock, call, patch
from unittest.mock import AsyncMock, Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import create_event
from tests.unittest import override_config
@ -70,8 +69,8 @@ class RoomBase(unittest.HomeserverTestCase):
)
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
return_value=make_awaitable(None)
self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
return_value=None
)
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
return self.setup_test_homeserver(federation_client=AsyncMock())
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass")
@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined]
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
make_awaitable({}),
{},
)
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
return_value=None,
)
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
mock = AsyncMock(return_value=True, spec=lambda *x: None)
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
mock
)
@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that
# we can't send the invite.
mock.return_value = make_awaitable(False)
mock.return_value = False
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
return_value=None,
)
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
mock = Mock(
return_value=make_awaitable(synapse.module_api.NOT_SPAM),
mock = AsyncMock(
return_value=synapse.module_api.NOT_SPAM,
spec=lambda *x: None,
)
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that
# we can't send the invite. We pick an arbitrary error code to be able to check
# that the same code has been returned
mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
mock.return_value = Codes.CONSENT_NOT_GIVEN
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
make_invite_mock.assert_called_once()
# Run variant with `Tuple[Codes, dict]`.
mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",

View File

@ -13,7 +13,7 @@
# limitations under the License.
import threading
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -33,7 +33,6 @@ from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
from tests.test_utils import make_awaitable
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
on_new_event = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
on_new_event
)
@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m
)
@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m
)
@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation.
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(None))
deactivation_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(
deactivation_mock,
@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Also register a mocked callback for profile updates, to check that the
# deactivation code calls it in a way that let modules know the user is being
# deactivated.
profile_mock = Mock(return_value=make_awaitable(None))
profile_mock = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
profile_mock,
)
@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
well as a reactivation.
"""
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
m = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation.
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation triggered by a server admin.
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing an admin's shutdown room request.
"""
# Register a mocked callback.
shutdown_mock = Mock(return_value=make_awaitable(False))
shutdown_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_shutdown_room_callbacks.append(
shutdown_mock,
@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
associating a 3PID to an account.
"""
# Register a mocked callback.
threepid_bind_mock = Mock(return_value=make_awaitable(None))
threepid_bind_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
just before associating and removing a 3PID to/from an account.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
on_add_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
on_remove_user_third_party_identifier_callback_mock = AsyncMock(
return_value=None
)
self.hs.get_module_api().register_third_party_rules_callbacks(
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
when a user is deactivated and their third-party ID associations are deleted.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
on_remove_user_third_party_identifier_callback_mock = AsyncMock(
return_value=None
)
self.hs.get_module_api().register_third_party_rules_callbacks(
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,

View File

@ -14,7 +14,7 @@
from http import HTTPStatus
from typing import Any, Generator, Tuple, cast
from unittest.mock import Mock, call
from unittest.mock import AsyncMock, Mock, call
from twisted.internet import defer, reactor as _reactor
@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import MockClock
reactor = cast(ISynapseReactor, _reactor)
@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
cb = AsyncMock(return_value=self.mock_http_response)
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_deduplicates_based_on_key(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
cb = AsyncMock(return_value=self.mock_http_response)
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute_request(
self.mock_request,
@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
cb = AsyncMock(return_value=self.mock_http_response)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -29,7 +29,6 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import default_config
@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
assert isinstance(rlsn, ResourceLimitsServerNotices)
self._rlsn = rlsn
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment]
return_value=make_awaitable(Mock())
self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[assignment]
return_value=Mock()
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = (
AsyncMock(return_value="!something:localhost")
)
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock(
return_value=make_awaitable("!something:localhost")
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock(
return_value="!something:localhost"
)
self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment]
self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[assignment]
self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[assignment]
@override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self) -> None:
@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value={"123": mock_event}
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None,
side_effect=ResourceLimitError(403, "foo"),
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value={"123": mock_event}
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None,
side_effect=ResourceLimitError(403, "foo"),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None)
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None
)
self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None,
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None,
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
),
@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
return_value=make_awaitable(None),
self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[assignment]
return_value=None,
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
return_value=make_awaitable((True, []))
self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[assignment]
return_value=(True, [])
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
return_value=make_awaitable({"123": mock_event})
self._rlsn._store.get_events = AsyncMock( # type: ignore[assignment]
return_value={"123": mock_event}
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self) -> None:
self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
self.store.get_monthly_active_count = AsyncMock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(1000)
)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000)
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
hasn't been reached (since it's the only user and the limit is 5), so users
shouldn't receive a server notice.
"""
m = Mock(return_value=make_awaitable(None))
m = AsyncMock(return_value=None)
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m
user_id = self.register_user("user", "password")

View File

@ -15,7 +15,7 @@ import json
import os
import tempfile
from typing import List, cast
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import yaml
@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
# we aren't testing store._base stuff here, so mock this out
# (ignore needed because Mypy won't allow us to assign to a method otherwise)
self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment]
self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[assignment]
self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
self.get_success(self._insert_txn(service.id, 10, events))

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
import yaml
@ -32,7 +32,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable, simple_async_mock
from tests.test_utils import simple_async_mock
from tests.unittest import override_config
@ -363,9 +363,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
# Register the callbacks with more mocks
self.hs.get_module_api().register_background_update_controller_callbacks(
on_update=self._on_update,
min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
default_batch_size=Mock(
return_value=make_awaitable(self._default_batch_size),
min_batch_size=AsyncMock(return_value=self._default_batch_size),
default_batch_size=AsyncMock(
return_value=self._default_batch_size,
),
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
from typing import Any, Dict
from unittest.mock import Mock
from unittest.mock import AsyncMock
from parameterized import parameterized
@ -30,7 +30,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import override_config
@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
lots_of_users = 100
user_id = "@user:server"
self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(lots_of_users)
)
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
from unittest.mock import Mock
from unittest.mock import AsyncMock
from twisted.test.proto_helpers import MemoryReactor
@ -21,7 +21,6 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
d = self.store.populate_monthly_active_users("user_id")
self.get_success(d)
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(self.hs.get_clock().time_msec())
self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = AsyncMock(
return_value=self.hs.get_clock().time_msec()
)
d = self.store.populate_monthly_active_users("user_id")
@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self) -> None:
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever"))

View File

@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
class PartialCurrentStateTrackerTestCase(TestCase):
def setUp(self) -> None:
self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
self.mock_store.is_partial_state_room = mock.AsyncMock()
self.tracker = PartialCurrentStateTracker(self.mock_store)
def test_does_not_block_for_full_state_rooms(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
self.mock_store.is_partial_state_room.return_value = False
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_blocks_for_partial_room_state(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
self.mock_store.is_partial_state_room.return_value = True
d = ensureDeferred(self.tracker.await_full_state("room_id"))
@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_cancellation(self) -> None:
self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
self.mock_store.is_partial_state_room.return_value = True
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
self.assertNoResult(d1)

View File

@ -13,7 +13,7 @@
# limitations under the License.
from typing import Collection, List, Optional, Union
from unittest.mock import Mock
from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@ -31,7 +31,6 @@ from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
from tests import unittest
from tests.test_utils import make_awaitable
class MessageAcceptTests(unittest.HomeserverTestCase):
@ -196,7 +195,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user.
store = self.hs.get_datastores().main
store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"])
# Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried.
@ -241,9 +240,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.hs.get_federation_client()
federation_client.query_user_devices = Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
federation_client.query_user_devices = AsyncMock( # type: ignore[assignment]
return_value={
"user_id": remote_user_id,
"stream_id": 1,
"devices": [],
@ -256,13 +254,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"user_id": remote_user_id,
"usage": ["self_signing"],
"keys": {
"ed25519:"
+ remote_self_signing_key: remote_self_signing_key
"ed25519:" + remote_self_signing_key: remote_self_signing_key
},
},
}
)
)
# Resync the device list.
device_handler = self.hs.get_device_handler()

View File

@ -18,7 +18,6 @@ Utilities for running the unit tests
import json
import sys
import warnings
from asyncio import Future
from binascii import unhexlify
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
@ -57,17 +56,6 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")
def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
future: Future[TV] = Future()
future.set_result(result)
return future
def setup_awaitable_errors() -> Callable[[], None]:
"""
Convert warnings from a non-awaited coroutines into errors.