mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add login spam checker API (#15838)
This commit is contained in:
parent
52d8131e87
commit
25c55a9d22
1
changelog.d/15838.feature
Normal file
1
changelog.d/15838.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add spam checker module API for logins.
|
@ -348,6 +348,42 @@ callback returns `False`, Synapse falls through to the next one. The value of th
|
|||||||
callback that does not return `False` will be used. If this happens, Synapse will not call
|
callback that does not return `False` will be used. If this happens, Synapse will not call
|
||||||
any of the subsequent implementations of this callback.
|
any of the subsequent implementations of this callback.
|
||||||
|
|
||||||
|
|
||||||
|
### `check_login_for_spam`
|
||||||
|
|
||||||
|
_First introduced in Synapse v1.87.0_
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def check_login_for_spam(
|
||||||
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
initial_display_name: Optional[str],
|
||||||
|
request_info: Collection[Tuple[Optional[str], str]],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Called when a user logs in.
|
||||||
|
|
||||||
|
The arguments passed to this callback are:
|
||||||
|
|
||||||
|
* `user_id`: The user ID the user is logging in with
|
||||||
|
* `device_id`: The device ID the user is re-logging into.
|
||||||
|
* `initial_display_name`: The device display name, if any.
|
||||||
|
* `request_info`: A collection of tuples, which first item is a user agent, and which
|
||||||
|
second item is an IP address. These user agents and IP addresses are the ones that were
|
||||||
|
used during the login process.
|
||||||
|
* `auth_provider_id`: The identifier of the SSO authentication provider, if any.
|
||||||
|
|
||||||
|
If multiple modules implement this callback, they will be considered in order. If a
|
||||||
|
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
|
||||||
|
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
|
||||||
|
be used. If this happens, Synapse will not call any of the subsequent implementations of
|
||||||
|
this callback.
|
||||||
|
|
||||||
|
*Note:* This will not be called when a user registers.
|
||||||
|
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|
||||||
The example below is a module that implements the spam checker callback
|
The example below is a module that implements the spam checker callback
|
||||||
|
@ -521,6 +521,11 @@ class SynapseRequest(Request):
|
|||||||
else:
|
else:
|
||||||
return self.getClientAddress().host
|
return self.getClientAddress().host
|
||||||
|
|
||||||
|
def request_info(self) -> "RequestInfo":
|
||||||
|
h = self.getHeader(b"User-Agent")
|
||||||
|
user_agent = h.decode("ascii", "replace") if h else None
|
||||||
|
return RequestInfo(user_agent=user_agent, ip=self.get_client_ip_if_available())
|
||||||
|
|
||||||
|
|
||||||
class XForwardedForRequest(SynapseRequest):
|
class XForwardedForRequest(SynapseRequest):
|
||||||
"""Request object which honours proxy headers
|
"""Request object which honours proxy headers
|
||||||
@ -661,3 +666,9 @@ class SynapseSite(Site):
|
|||||||
|
|
||||||
def log(self, request: SynapseRequest) -> None:
|
def log(self, request: SynapseRequest) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||||
|
class RequestInfo:
|
||||||
|
user_agent: Optional[str]
|
||||||
|
ip: str
|
||||||
|
@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
|
|||||||
)
|
)
|
||||||
from synapse.module_api.callbacks.spamchecker_callbacks import (
|
from synapse.module_api.callbacks.spamchecker_callbacks import (
|
||||||
CHECK_EVENT_FOR_SPAM_CALLBACK,
|
CHECK_EVENT_FOR_SPAM_CALLBACK,
|
||||||
|
CHECK_LOGIN_FOR_SPAM_CALLBACK,
|
||||||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
|
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
|
||||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
|
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
|
||||||
CHECK_USERNAME_FOR_SPAM_CALLBACK,
|
CHECK_USERNAME_FOR_SPAM_CALLBACK,
|
||||||
@ -302,6 +303,7 @@ class ModuleApi:
|
|||||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||||
|
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Registers callbacks for spam checking capabilities.
|
"""Registers callbacks for spam checking capabilities.
|
||||||
|
|
||||||
@ -319,6 +321,7 @@ class ModuleApi:
|
|||||||
check_username_for_spam=check_username_for_spam,
|
check_username_for_spam=check_username_for_spam,
|
||||||
check_registration_for_spam=check_registration_for_spam,
|
check_registration_for_spam=check_registration_for_spam,
|
||||||
check_media_file_for_spam=check_media_file_for_spam,
|
check_media_file_for_spam=check_media_file_for_spam,
|
||||||
|
check_login_for_spam=check_login_for_spam,
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_account_validity_callbacks(
|
def register_account_validity_callbacks(
|
||||||
|
@ -196,6 +196,26 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
|
|||||||
]
|
]
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
CHECK_LOGIN_FOR_SPAM_CALLBACK = Callable[
|
||||||
|
[
|
||||||
|
str,
|
||||||
|
Optional[str],
|
||||||
|
Optional[str],
|
||||||
|
Collection[Tuple[Optional[str], str]],
|
||||||
|
Optional[str],
|
||||||
|
],
|
||||||
|
Awaitable[
|
||||||
|
Union[
|
||||||
|
Literal["NOT_SPAM"],
|
||||||
|
Codes,
|
||||||
|
# Highly experimental, not officially part of the spamchecker API, may
|
||||||
|
# disappear without warning depending on the results of ongoing
|
||||||
|
# experiments.
|
||||||
|
# Use this to return additional information as part of an error.
|
||||||
|
Tuple[Codes, JsonDict],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
|
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
|
||||||
@ -315,6 +335,7 @@ class SpamCheckerModuleApiCallbacks:
|
|||||||
self._check_media_file_for_spam_callbacks: List[
|
self._check_media_file_for_spam_callbacks: List[
|
||||||
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
|
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
|
||||||
] = []
|
] = []
|
||||||
|
self._check_login_for_spam_callbacks: List[CHECK_LOGIN_FOR_SPAM_CALLBACK] = []
|
||||||
|
|
||||||
def register_callbacks(
|
def register_callbacks(
|
||||||
self,
|
self,
|
||||||
@ -335,6 +356,7 @@ class SpamCheckerModuleApiCallbacks:
|
|||||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||||
|
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register callbacks from module for each hook."""
|
"""Register callbacks from module for each hook."""
|
||||||
if check_event_for_spam is not None:
|
if check_event_for_spam is not None:
|
||||||
@ -378,6 +400,9 @@ class SpamCheckerModuleApiCallbacks:
|
|||||||
if check_media_file_for_spam is not None:
|
if check_media_file_for_spam is not None:
|
||||||
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
|
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
|
||||||
|
|
||||||
|
if check_login_for_spam is not None:
|
||||||
|
self._check_login_for_spam_callbacks.append(check_login_for_spam)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def check_event_for_spam(
|
async def check_event_for_spam(
|
||||||
self, event: "synapse.events.EventBase"
|
self, event: "synapse.events.EventBase"
|
||||||
@ -819,3 +844,58 @@ class SpamCheckerModuleApiCallbacks:
|
|||||||
return synapse.api.errors.Codes.FORBIDDEN, {}
|
return synapse.api.errors.Codes.FORBIDDEN, {}
|
||||||
|
|
||||||
return self.NOT_SPAM
|
return self.NOT_SPAM
|
||||||
|
|
||||||
|
async def check_login_for_spam(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
initial_display_name: Optional[str],
|
||||||
|
request_info: Collection[Tuple[Optional[str], str]],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
|
||||||
|
"""Checks if we should allow the given registration request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The request user ID
|
||||||
|
request_info: List of tuples of user agent and IP that
|
||||||
|
were used during the registration process.
|
||||||
|
auth_provider_id: The SSO IdP the user used, e.g "oidc", "saml",
|
||||||
|
"cas". If any. Note this does not include users registered
|
||||||
|
via a password provider.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Enum for how the request should be handled
|
||||||
|
"""
|
||||||
|
|
||||||
|
for callback in self._check_login_for_spam_callbacks:
|
||||||
|
with Measure(
|
||||||
|
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
|
||||||
|
):
|
||||||
|
res = await delay_cancellation(
|
||||||
|
callback(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
initial_display_name,
|
||||||
|
request_info,
|
||||||
|
auth_provider_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Normalize return values to `Codes` or `"NOT_SPAM"`.
|
||||||
|
if res is self.NOT_SPAM:
|
||||||
|
continue
|
||||||
|
elif isinstance(res, synapse.api.errors.Codes):
|
||||||
|
return res, {}
|
||||||
|
elif (
|
||||||
|
isinstance(res, tuple)
|
||||||
|
and len(res) == 2
|
||||||
|
and isinstance(res[0], synapse.api.errors.Codes)
|
||||||
|
and isinstance(res[1], dict)
|
||||||
|
):
|
||||||
|
return res
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Module returned invalid value, rejecting login as spam"
|
||||||
|
)
|
||||||
|
return synapse.api.errors.Codes.FORBIDDEN, {}
|
||||||
|
|
||||||
|
return self.NOT_SPAM
|
||||||
|
@ -50,7 +50,7 @@ from synapse.http.servlet import (
|
|||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import RequestInfo, SynapseRequest
|
||||||
from synapse.rest.client._base import client_patterns
|
from synapse.rest.client._base import client_patterns
|
||||||
from synapse.rest.well_known import WellKnownBuilder
|
from synapse.rest.well_known import WellKnownBuilder
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
@ -114,6 +114,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
self.registration_handler = hs.get_registration_handler()
|
self.registration_handler = hs.get_registration_handler()
|
||||||
self._sso_handler = hs.get_sso_handler()
|
self._sso_handler = hs.get_sso_handler()
|
||||||
|
self._spam_checker = hs.get_module_api_callbacks().spam_checker
|
||||||
|
|
||||||
self._well_known_builder = WellKnownBuilder(hs)
|
self._well_known_builder = WellKnownBuilder(hs)
|
||||||
self._address_ratelimiter = Ratelimiter(
|
self._address_ratelimiter = Ratelimiter(
|
||||||
@ -197,6 +198,8 @@ class LoginRestServlet(RestServlet):
|
|||||||
self._refresh_tokens_enabled and client_requested_refresh_token
|
self._refresh_tokens_enabled and client_requested_refresh_token
|
||||||
)
|
)
|
||||||
|
|
||||||
|
request_info = request.request_info()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
@ -216,6 +219,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission,
|
login_submission,
|
||||||
appservice,
|
appservice,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
self.jwt_enabled
|
self.jwt_enabled
|
||||||
@ -227,6 +231,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
result = await self._do_jwt_login(
|
result = await self._do_jwt_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
await self._address_ratelimiter.ratelimit(
|
await self._address_ratelimiter.ratelimit(
|
||||||
@ -235,6 +240,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
result = await self._do_token_login(
|
result = await self._do_token_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await self._address_ratelimiter.ratelimit(
|
await self._address_ratelimiter.ratelimit(
|
||||||
@ -243,6 +249,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
result = await self._do_other_login(
|
result = await self._do_other_login(
|
||||||
login_submission,
|
login_submission,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
@ -265,6 +272,8 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission: JsonDict,
|
login_submission: JsonDict,
|
||||||
appservice: ApplicationService,
|
appservice: ApplicationService,
|
||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
|
*,
|
||||||
|
request_info: RequestInfo,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
identifier = login_submission.get("identifier")
|
identifier = login_submission.get("identifier")
|
||||||
logger.info("Got appservice login request with identifier: %r", identifier)
|
logger.info("Got appservice login request with identifier: %r", identifier)
|
||||||
@ -300,10 +309,15 @@ class LoginRestServlet(RestServlet):
|
|||||||
# The user represented by an appservice's configured sender_localpart
|
# The user represented by an appservice's configured sender_localpart
|
||||||
# is not actually created in Synapse.
|
# is not actually created in Synapse.
|
||||||
should_check_deactivated=qualified_user_id != appservice.sender,
|
should_check_deactivated=qualified_user_id != appservice.sender,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_other_login(
|
async def _do_other_login(
|
||||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
self,
|
||||||
|
login_submission: JsonDict,
|
||||||
|
should_issue_refresh_token: bool = False,
|
||||||
|
*,
|
||||||
|
request_info: RequestInfo,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""Handle non-token/saml/jwt logins
|
"""Handle non-token/saml/jwt logins
|
||||||
|
|
||||||
@ -333,6 +347,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission,
|
login_submission,
|
||||||
callback,
|
callback,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -347,6 +362,8 @@ class LoginRestServlet(RestServlet):
|
|||||||
should_issue_refresh_token: bool = False,
|
should_issue_refresh_token: bool = False,
|
||||||
auth_provider_session_id: Optional[str] = None,
|
auth_provider_session_id: Optional[str] = None,
|
||||||
should_check_deactivated: bool = True,
|
should_check_deactivated: bool = True,
|
||||||
|
*,
|
||||||
|
request_info: RequestInfo,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
actually login them in (e.g. create devices). This gets called on
|
actually login them in (e.g. create devices). This gets called on
|
||||||
@ -371,6 +388,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
|
|
||||||
This exists purely for appservice's configured sender_localpart
|
This exists purely for appservice's configured sender_localpart
|
||||||
which doesn't have an associated user in the database.
|
which doesn't have an associated user in the database.
|
||||||
|
request_info: The user agent/IP address of the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of account information after successful login.
|
Dictionary of account information after successful login.
|
||||||
@ -417,6 +435,22 @@ class LoginRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
|
spam_check = await self._spam_checker.check_login_for_spam(
|
||||||
|
user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
initial_display_name=initial_display_name,
|
||||||
|
request_info=[(request_info.user_agent, request_info.ip)],
|
||||||
|
auth_provider_id=auth_provider_id,
|
||||||
|
)
|
||||||
|
if spam_check != self._spam_checker.NOT_SPAM:
|
||||||
|
logger.info("Blocking login due to spam checker")
|
||||||
|
raise SynapseError(
|
||||||
|
403,
|
||||||
|
msg="Login was blocked by the server",
|
||||||
|
errcode=spam_check[0],
|
||||||
|
additional_fields=spam_check[1],
|
||||||
|
)
|
||||||
|
|
||||||
(
|
(
|
||||||
device_id,
|
device_id,
|
||||||
access_token,
|
access_token,
|
||||||
@ -451,7 +485,11 @@ class LoginRestServlet(RestServlet):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def _do_token_login(
|
async def _do_token_login(
|
||||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
self,
|
||||||
|
login_submission: JsonDict,
|
||||||
|
should_issue_refresh_token: bool = False,
|
||||||
|
*,
|
||||||
|
request_info: RequestInfo,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""
|
"""
|
||||||
Handle token login.
|
Handle token login.
|
||||||
@ -474,10 +512,15 @@ class LoginRestServlet(RestServlet):
|
|||||||
auth_provider_id=res.auth_provider_id,
|
auth_provider_id=res.auth_provider_id,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
auth_provider_session_id=res.auth_provider_session_id,
|
auth_provider_session_id=res.auth_provider_session_id,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _do_jwt_login(
|
async def _do_jwt_login(
|
||||||
self, login_submission: JsonDict, should_issue_refresh_token: bool = False
|
self,
|
||||||
|
login_submission: JsonDict,
|
||||||
|
should_issue_refresh_token: bool = False,
|
||||||
|
*,
|
||||||
|
request_info: RequestInfo,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""
|
"""
|
||||||
Handle the custom JWT login.
|
Handle the custom JWT login.
|
||||||
@ -496,6 +539,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
login_submission,
|
login_submission,
|
||||||
create_non_existent_users=True,
|
create_non_existent_users=True,
|
||||||
should_issue_refresh_token=should_issue_refresh_token,
|
should_issue_refresh_token=should_issue_refresh_token,
|
||||||
|
request_info=request_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,11 +13,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Collection, Dict, List, Optional, Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
@ -26,11 +27,12 @@ import synapse.rest.admin
|
|||||||
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
from synapse.api.constants import ApprovalNoticeMedium, LoginType
|
||||||
from synapse.api.errors import Codes
|
from synapse.api.errors import Codes
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.rest.client import devices, login, logout, register
|
from synapse.rest.client import devices, login, logout, register
|
||||||
from synapse.rest.client.account import WhoamiRestServlet
|
from synapse.rest.client.account import WhoamiRestServlet
|
||||||
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
from synapse.rest.synapse.client import build_synapse_client_resource_tree
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import create_requester
|
from synapse.types import JsonDict, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpamChecker:
|
||||||
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
|
api.register_spam_checker_callbacks(
|
||||||
|
check_login_for_spam=self.check_login_for_spam,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config: JsonDict) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def check_login_for_spam(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
initial_display_name: Optional[str],
|
||||||
|
request_info: Collection[Tuple[Optional[str], str]],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
) -> Union[
|
||||||
|
Literal["NOT_SPAM"],
|
||||||
|
Tuple["synapse.module_api.errors.Codes", JsonDict],
|
||||||
|
]:
|
||||||
|
return "NOT_SPAM"
|
||||||
|
|
||||||
|
|
||||||
|
class DenyAllSpamChecker:
|
||||||
|
def __init__(self, config: None, api: ModuleApi):
|
||||||
|
api.register_spam_checker_callbacks(
|
||||||
|
check_login_for_spam=self.check_login_for_spam,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_config(config: JsonDict) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def check_login_for_spam(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
initial_display_name: Optional[str],
|
||||||
|
request_info: Collection[Tuple[Optional[str], str]],
|
||||||
|
auth_provider_id: Optional[str] = None,
|
||||||
|
) -> Union[
|
||||||
|
Literal["NOT_SPAM"],
|
||||||
|
Tuple["synapse.module_api.errors.Codes", JsonDict],
|
||||||
|
]:
|
||||||
|
# Return an odd set of values to ensure that they get correctly passed
|
||||||
|
# to the client.
|
||||||
|
return Codes.LIMIT_EXCEEDED, {"extra": "value"}
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||||
@ -469,6 +521,58 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"modules": [
|
||||||
|
{
|
||||||
|
"module": TestSpamChecker.__module__
|
||||||
|
+ "."
|
||||||
|
+ TestSpamChecker.__qualname__
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_spam_checker_allow(self) -> None:
|
||||||
|
"""Check that that adding a spam checker doesn't break login."""
|
||||||
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
|
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 200, channel.result)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"modules": [
|
||||||
|
{
|
||||||
|
"module": DenyAllSpamChecker.__module__
|
||||||
|
+ "."
|
||||||
|
+ DenyAllSpamChecker.__qualname__
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_spam_checker_deny(self) -> None:
|
||||||
|
"""Check that login"""
|
||||||
|
|
||||||
|
self.register_user("kermit", "monkey")
|
||||||
|
|
||||||
|
body = {"type": "m.login.password", "user": "kermit", "password": "monkey"}
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"POST",
|
||||||
|
"/_matrix/client/r0/login",
|
||||||
|
body,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, 403, channel.result)
|
||||||
|
self.assertDictContainsSubset(
|
||||||
|
{"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}, channel.json_body
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
|
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
|
||||||
class MultiSSOTestCase(unittest.HomeserverTestCase):
|
class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user