Uniformize spam-checker API, part 4: port other spam-checker callbacks to return Union[Allow, Codes]. (#12857)

Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
This commit is contained in:
David Teller 2022-06-13 20:16:16 +02:00 committed by GitHub
parent 53b77b203a
commit a164a46038
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 604 additions and 182 deletions

View file

@ -18,10 +18,13 @@
"""Tests REST events for /rooms paths."""
import json
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Union
from unittest.mock import Mock, call
from urllib import parse as urlparse
# `Literal` appears with Python 3.8.
from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@ -777,9 +780,11 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
def test_spam_checker_may_join_room(self) -> None:
def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
In this test, we use the deprecated API in which callbacks return a bool.
"""
async def user_may_join_room(
@ -801,6 +806,32 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(join_mock.call_count, 0)
def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`.
"""
async def user_may_join_room(
mxid: str,
room_id: str,
is_invite: bool,
) -> Codes:
return Codes.CONSENT_NOT_GIVEN
join_mock = Mock(side_effect=user_may_join_room)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
channel = self.make_request(
"POST",
"/createRoom",
{},
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(join_mock.call_count, 0)
class RoomTopicTestCase(RoomBase):
"""Tests /rooms/$room_id/topic REST events."""
@ -1011,9 +1042,11 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
def test_spam_checker_may_join_room(self) -> None:
def test_spam_checker_may_join_room_deprecated(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
This test uses the deprecated API, in which callbacks return booleans.
"""
# Register a dummy callback. Make it allow all room joins for now.
@ -1026,6 +1059,8 @@ class RoomJoinTestCase(RoomBase):
) -> bool:
return return_value
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
@ -1068,6 +1103,67 @@ class RoomJoinTestCase(RoomBase):
return_value = False
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
This test uses the latest API to this day, in which callbacks return `NOT_SPAM` or `Codes`.
"""
# Register a dummy callback. Make it allow all room joins for now.
return_value: Union[Literal["NOT_SPAM"], Codes] = synapse.module_api.NOT_SPAM
async def user_may_join_room(
userid: str,
room_id: str,
is_invited: bool,
) -> Union[Literal["NOT_SPAM"], Codes]:
return return_value
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
self.helper.join(self.room1, self.user2, tok=self.tok2)
# Check that the callback was called with the right arguments.
expected_call_args = (
(
self.user2,
self.room1,
False,
),
)
self.assertEqual(
callback_mock.call_args,
expected_call_args,
callback_mock.call_args,
)
# Join a second room, this time with an invite for it.
self.helper.invite(self.room2, self.user1, self.user2, tok=self.tok1)
self.helper.join(self.room2, self.user2, tok=self.tok2)
# Check that the callback was called with the right arguments.
expected_call_args = (
(
self.user2,
self.room2,
True,
),
)
self.assertEqual(
callback_mock.call_args,
expected_call_args,
callback_mock.call_args,
)
# Now make the callback deny all room joins, and check that a join actually fails.
return_value = Codes.CONSENT_NOT_GIVEN
self.helper.join(self.room3, self.user2, expect_code=403, tok=self.tok2)
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
@ -2945,9 +3041,14 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_threepid_invite_spamcheck(self) -> None:
def test_threepid_invite_spamcheck_deprecated(self) -> None:
"""
Test allowing/blocking threepid invites with a spam-check module.
In this test, we use the deprecated API in which callbacks return a bool.
"""
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable(0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
@ -3001,3 +3102,67 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()
def test_threepid_invite_spamcheck(self) -> None:
"""
Test allowing/blocking threepid invites with a spam-check module.
In this test, we use the more recent API in which callbacks return a `Union[Codes, Literal["NOT_SPAM"]]`."""
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
make_invite_mock = Mock(return_value=make_awaitable(0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
self.hs.get_identity_handler().lookup_3pid = Mock(
return_value=make_awaitable(None),
)
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
mock = Mock(
return_value=make_awaitable(synapse.module_api.NOT_SPAM),
spec=lambda *x: None,
)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
email_to_invite = "teresa@example.com"
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
content={
"id_server": "example.com",
"id_access_token": "sometoken",
"medium": "email",
"address": email_to_invite,
},
access_token=self.tok,
)
self.assertEqual(channel.code, 200)
# Check that the callback was called with the right params.
mock.assert_called_with(self.user_id, "email", email_to_invite, self.room_id)
# Check that the call to send the invite was made.
make_invite_mock.assert_called_once()
# Now change the return value of the callback to deny any invite and test that
# we can't send the invite.
mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
content={
"id_server": "example.com",
"id_access_token": "sometoken",
"medium": "email",
"address": email_to_invite,
},
access_token=self.tok,
)
self.assertEqual(channel.code, 403)
# Also check that it stopped before calling _make_and_store_3pid_invite.
make_invite_mock.assert_called_once()