Add a callback to react to 3PID associations (#12302)

This commit is contained in:
Brendan Abolivier 2022-03-31 18:27:21 +02:00 committed by GitHub
parent 34a8370d7b
commit 5e88143dff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 92 additions and 0 deletions

View File

@ -0,0 +1 @@
Add a module callback to react to new 3PID (email address, phone number) associations.

View File

@ -247,6 +247,24 @@ admin API.
If multiple modules implement this callback, Synapse runs them all in order. If multiple modules implement this callback, Synapse runs them all in order.
### `on_threepid_bind`
_First introduced in Synapse v1.56.0_
```python
async def on_threepid_bind(user_id: str, medium: str, address: str) -> None:
```
Called after creating an association between a local user and a third-party identifier
(email address, phone number). The module is given the Matrix ID of the user the
association is for, as well as the medium (`email` or `msisdn`) and address of the
third-party identifier.
Note that this callback is _not_ called after a successful association on an _identity
server_.
If multiple modules implement this callback, Synapse runs them all in order.
## Example ## Example
The example below is a module that implements the third-party rules callback The example below is a module that implements the third-party rules callback

View File

@ -42,6 +42,7 @@ CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]] CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable] ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable] ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None: def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@ -169,6 +170,7 @@ class ThirdPartyEventRules:
self._on_user_deactivation_status_changed_callbacks: List[ self._on_user_deactivation_status_changed_callbacks: List[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = [] ] = []
self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
def register_third_party_rules_callbacks( def register_third_party_rules_callbacks(
self, self,
@ -187,6 +189,7 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed: Optional[ on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None, ] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
) -> None: ) -> None:
"""Register callbacks from modules for each hook.""" """Register callbacks from modules for each hook."""
if check_event_allowed is not None: if check_event_allowed is not None:
@ -221,6 +224,9 @@ class ThirdPartyEventRules:
on_user_deactivation_status_changed, on_user_deactivation_status_changed,
) )
if on_threepid_bind is not None:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
async def check_event_allowed( async def check_event_allowed(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
) -> Tuple[bool, Optional[dict]]: ) -> Tuple[bool, Optional[dict]]:
@ -479,3 +485,23 @@ class ThirdPartyEventRules:
logger.exception( logger.exception(
"Failed to run module API callback %s: %s", callback, e "Failed to run module API callback %s: %s", callback, e
) )
async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
"""Called after a threepid association has been verified and stored.
Note that this callback is called when an association is created on the
local homeserver, not when it's created on an identity server (and then kept track
of so that it can be unbound on the same IS later on).
Args:
user_id: the user being associated with the threepid.
medium: the threepid's medium.
address: the threepid's address.
"""
for callback in self._on_threepid_bind_callbacks:
try:
await callback(user_id, medium, address)
except Exception as e:
logger.exception(
"Failed to run module API callback %s: %s", callback, e
)

View File

@ -211,6 +211,7 @@ class AuthHandler:
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.auth.password_enabled self._password_enabled = hs.config.auth.password_enabled
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
self._third_party_rules = hs.get_third_party_event_rules()
# Ratelimiter for failed auth during UIA. Uses same ratelimit config # Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`. # as per `rc_login.failed_attempts`.
@ -1505,6 +1506,8 @@ class AuthHandler:
user_id, medium, address, validated_at, self.hs.get_clock().time_msec() user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
) )
await self._third_party_rules.on_threepid_bind(user_id, medium, address)
async def delete_threepid( async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
) -> bool: ) -> bool:

View File

@ -62,6 +62,7 @@ from synapse.events.third_party_rules import (
ON_CREATE_ROOM_CALLBACK, ON_CREATE_ROOM_CALLBACK,
ON_NEW_EVENT_CALLBACK, ON_NEW_EVENT_CALLBACK,
ON_PROFILE_UPDATE_CALLBACK, ON_PROFILE_UPDATE_CALLBACK,
ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK, ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
) )
from synapse.handlers.account_validity import ( from synapse.handlers.account_validity import (
@ -293,6 +294,7 @@ class ModuleApi:
on_user_deactivation_status_changed: Optional[ on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None, ] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
) -> None: ) -> None:
"""Registers callbacks for third party event rules capabilities. """Registers callbacks for third party event rules capabilities.
@ -308,6 +310,7 @@ class ModuleApi:
check_can_deactivate_user=check_can_deactivate_user, check_can_deactivate_user=check_can_deactivate_user,
on_profile_update=on_profile_update, on_profile_update=on_profile_update,
on_user_deactivation_status_changed=on_user_deactivation_status_changed, on_user_deactivation_status_changed=on_user_deactivation_status_changed,
on_threepid_bind=on_threepid_bind,
) )
def register_presence_router_callbacks( def register_presence_router_callbacks(

View File

@ -896,3 +896,44 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Check that the mock was called with the right room ID # Check that the mock was called with the right room ID
self.assertEqual(args[1], self.room_id) self.assertEqual(args[1], self.room_id)
def test_on_threepid_bind(self) -> None:
"""Tests that the on_threepid_bind module callback is called correctly after
associating a 3PID to an account.
"""
# Register a mocked callback.
threepid_bind_mock = Mock(return_value=make_awaitable(None))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
# Register an admin user.
self.register_user("admin", "password", admin=True)
admin_tok = self.login("admin", "password")
# Also register a normal user we can modify.
user_id = self.register_user("user", "password")
# Add a 3PID to the user.
channel = self.make_request(
"PUT",
"/_synapse/admin/v2/users/%s" % user_id,
{
"threepids": [
{
"medium": "email",
"address": "foo@example.com",
},
],
},
access_token=admin_tok,
)
# Check that the shutdown was blocked
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the mock was called once.
threepid_bind_mock.assert_called_once()
args = threepid_bind_mock.call_args[0]
# Check that the mock was called with the right parameters
self.assertEqual(args, (user_id, "email", "foo@example.com"))