Port the ThirdPartyEventRules module interface to the new generic interface (#10386)

Port the third-party event rules interface to the generic module interface introduced in v1.37.0
This commit is contained in:
Brendan Abolivier 2021-07-20 12:39:46 +02:00 committed by GitHub
parent f3ac9c6750
commit a743bf4694
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 402 additions and 107 deletions

View File

@ -0,0 +1 @@
The third-party event rules module interface is deprecated in favour of the generic module interface introduced in Synapse v1.37.0. See the [upgrade notes](https://matrix-org.github.io/synapse/latest/upgrade.html#upgrading-to-v1390) for more information.

View File

@ -186,7 +186,7 @@ The arguments passed to this callback are:
```python ```python
async def check_media_file_for_spam( async def check_media_file_for_spam(
file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper", file_wrapper: "synapse.rest.media.v1.media_storage.ReadableFileWrapper",
file_info: "synapse.rest.media.v1._base.FileInfo" file_info: "synapse.rest.media.v1._base.FileInfo",
) -> bool ) -> bool
``` ```
@ -223,6 +223,66 @@ Called after successfully registering a user, in case the module needs to perfor
operations to keep track of them. (e.g. add them to a database table). The user is operations to keep track of them. (e.g. add them to a database table). The user is
represented by their Matrix user ID. represented by their Matrix user ID.
#### Third party rules callbacks
Third party rules callbacks allow module developers to add extra checks to verify the
validity of incoming events. Third party event rules callbacks can be registered using
the module API's `register_third_party_rules_callbacks` method.
The available third party rules callbacks are:
```python
async def check_event_allowed(
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict]]
```
**<span style="color:red">
This callback is very experimental and can and will break without notice. Module developers
are encouraged to implement `check_event_for_spam` from the spam checker category instead.
</span>**
Called when processing any incoming event, with the event and a `StateMap`
representing the current state of the room the event is being sent into. A `StateMap` is
a dictionary that maps tuples containing an event type and a state key to the
corresponding state event. For example retrieving the room's `m.room.create` event from
the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
The module must return a boolean indicating whether the event can be allowed.
Note that this callback function processes incoming events coming via federation
traffic (on top of client traffic). This means denying an event might cause the local
copy of the room's history to diverge from that of remote servers. This may cause
federation issues in the room. It is strongly recommended to only deny events using this
callback function if the sender is a local user, or in a private federation in which all
servers are using the same module, with the same configuration.
If the boolean returned by the module is `True`, it may also tell Synapse to replace the
event with new data by returning the new event's data as a dictionary. In order to do
that, it is recommended the module calls `event.get_dict()` to get the current event as a
dictionary, and modify the returned dictionary accordingly.
Note that replacing the event only works for events sent by local users, not for events
received over federation.
```python
async def on_create_room(
requester: "synapse.types.Requester",
request_content: dict,
is_requester_admin: bool,
) -> None
```
Called when processing a room creation request, with the `Requester` object for the user
performing the request, a dictionary representing the room creation request's JSON body
(see [the spec](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-createroom)
for a list of possible parameters), and a boolean indicating whether the user performing
the request is a server admin.
Modules can modify the `request_content` (by e.g. adding events to its `initial_state`),
or deny the room's creation by raising a `module_api.errors.SynapseError`.
### Porting an existing module that uses the old interface ### Porting an existing module that uses the old interface
In order to port a module that uses Synapse's old module interface, its author needs to: In order to port a module that uses Synapse's old module interface, its author needs to:

View File

@ -2654,19 +2654,6 @@ stats:
# action: allow # action: allow
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
# override the methods defined in synapse/events/third_party_rules.py.
#
# This feature is designed to be used in closed federations only, where each
# participating server enforces the same rules.
#
#third_party_event_rules:
# module: "my_custom_project.SuperRulesSet"
# config:
# example_option: 'things'
## Opentracing ## ## Opentracing ##
# These settings enable opentracing, which implements distributed tracing. # These settings enable opentracing, which implements distributed tracing.

View File

@ -86,6 +86,19 @@ process, for example:
``` ```
# Upgrading to v1.39.0
## Deprecation of the current third-party rules module interface
The current third-party rules module interface is deprecated in favour of the new generic
modules system introduced in Synapse v1.37.0. Authors of third-party rules modules can refer
to [this documentation](modules.md#porting-an-existing-module-that-uses-the-old-interface)
to update their modules. Synapse administrators can refer to [this documentation](modules.md#using-modules)
to update their configuration once the modules they are using have been updated.
We plan to remove support for the current third-party rules interface in September 2021.
# Upgrading to v1.38.0 # Upgrading to v1.38.0
## Re-indexing of `events` table on Postgres databases ## Re-indexing of `events` table on Postgres databases

View File

@ -38,6 +38,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.logging.context import PreserveLoggingContext from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.metrics.jemalloc import setup_jemalloc_stats
@ -368,6 +369,7 @@ async def start(hs: "HomeServer"):
module(config=config, api=module_api) module(config=config, api=module_api)
load_legacy_spam_checkers(hs) load_legacy_spam_checkers(hs)
load_legacy_third_party_event_rules(hs)
# If we've configured an expiry time for caches, start the background job now. # If we've configured an expiry time for caches, start the background job now.
setup_expire_lru_cache_entries(hs) setup_expire_lru_cache_entries(hs)

View File

@ -28,18 +28,3 @@ class ThirdPartyRulesConfig(Config):
self.third_party_event_rules = load_module( self.third_party_event_rules = load_module(
provider, ("third_party_event_rules",) provider, ("third_party_event_rules",)
) )
def generate_config_section(self, **kwargs):
return """\
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
# override the methods defined in synapse/events/third_party_rules.py.
#
# This feature is designed to be used in closed federations only, where each
# participating server enforces the same rules.
#
#third_party_event_rules:
# module: "my_custom_project.SuperRulesSet"
# config:
# example_option: 'things'
"""

View File

@ -11,16 +11,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Union from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
]
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
[str, StateMap[EventBase], str], Awaitable[bool]
]
def load_legacy_third_party_event_rules(hs: "HomeServer"):
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.third_party_event_rules is None:
return
module, config = hs.config.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
third_party_event_rules_methods = {
"check_event_allowed",
"on_create_room",
"check_threepid_can_be_invited",
"check_visibility_can_be_modified",
}
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We return a separate wrapper for these methods because, in order to wrap them
# correctly, we need to await its result. Therefore it doesn't make a lot of
# sense to make it go through the run() wrapper.
if f.__name__ == "check_event_allowed":
# We need to wrap check_event_allowed because its old form would return either
# a boolean or a dict, but now we want to return the dict separately from the
# boolean.
async def wrap_check_event_allowed(
event: EventBase,
state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
res = await f(event, state_events)
if isinstance(res, dict):
return True, res
else:
return res, None
return wrap_check_event_allowed
if f.__name__ == "on_create_room":
# We need to wrap on_create_room because its old form would return a boolean
# if the room creation is denied, but now we just want it to raise an
# exception.
async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool
) -> None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
assert f is not None
res = await f(requester, config, is_requester_admin)
if res is False:
raise SynapseError(
403,
"Room creation forbidden with these parameters",
)
return wrap_on_create_room
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(third_party_rules, hook, None))
for hook in third_party_event_rules_methods
}
api.register_third_party_rules_callbacks(**hooks)
class ThirdPartyEventRules: class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra """Allows server admins to provide a Python module implementing an extra
@ -35,36 +143,65 @@ class ThirdPartyEventRules:
self.store = hs.get_datastore() self.store = hs.get_datastore()
module = None self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
config = None self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
if hs.config.third_party_event_rules: self._check_threepid_can_be_invited_callbacks: List[
module, config = hs.config.third_party_event_rules CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = []
self._check_visibility_can_be_modified_callbacks: List[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = []
if module is not None: def register_third_party_rules_callbacks(
self.third_party_rules = module( self,
config=config, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
module_api=hs.get_module_api(), on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = None,
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
):
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)
if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room)
if check_threepid_can_be_invited is not None:
self._check_threepid_can_be_invited_callbacks.append(
check_threepid_can_be_invited,
)
if check_visibility_can_be_modified is not None:
self._check_visibility_can_be_modified_callbacks.append(
check_visibility_can_be_modified,
) )
async def check_event_allowed( async def check_event_allowed(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
) -> Union[bool, dict]: ) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context. """Check if a provided event should be allowed in the given context.
The module can return: The module can return:
* True: the event is allowed. * True: the event is allowed.
* False: the event is not allowed, and should be rejected with M_FORBIDDEN. * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
* a dict: replacement event data.
If the event is allowed, the module can also return a dictionary to use as a
replacement for the event.
Args: Args:
event: The event to be checked. event: The event to be checked.
context: The context of the event. context: The context of the event.
Returns: Returns:
The result from the ThirdPartyRules module, as above The result from the ThirdPartyRules module, as above.
""" """
if self.third_party_rules is None: # Bail out early without hitting the store if we don't have any callbacks to run.
return True if len(self._check_event_allowed_callbacks) == 0:
return True, None
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
@ -77,29 +214,46 @@ class ThirdPartyEventRules:
# the hashes and signatures. # the hashes and signatures.
event.freeze() event.freeze()
return await self.third_party_rules.check_event_allowed(event, state_events) for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await callback(event, state_events)
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
if res is False:
return res, None
elif isinstance(replacement_data, dict):
return True, replacement_data
return True, None
async def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool: ) -> None:
"""Intercept requests to create room to allow, deny or update the """Intercept requests to create room to maybe deny it (via an exception) or
request config. update the request config.
Args: Args:
requester requester
config: The creation config from the client. config: The creation config from the client.
is_requester_admin: If the requester is an admin is_requester_admin: If the requester is an admin
Returns:
Whether room creation is allowed or denied.
""" """
for callback in self._on_create_room_callbacks:
try:
await callback(requester, config, is_requester_admin)
except Exception as e:
# Don't silence the errors raised by this callback since we expect it to
# raise an exception to deny the creation of the room; instead make sure
# it's a SynapseError we can send to clients.
if not isinstance(e, SynapseError):
e = SynapseError(
403, "Room creation forbidden with these parameters"
)
if self.third_party_rules is None: raise e
return True
return await self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
async def check_threepid_can_be_invited( async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str self, medium: str, address: str, room_id: str
@ -114,15 +268,20 @@ class ThirdPartyEventRules:
Returns: Returns:
True if the 3PID can be invited, False if not. True if the 3PID can be invited, False if not.
""" """
# Bail out early without hitting the store if we don't have any callbacks to run.
if self.third_party_rules is None: if len(self._check_threepid_can_be_invited_callbacks) == 0:
return True return True
state_events = await self._get_state_map_for_room(room_id) state_events = await self._get_state_map_for_room(room_id)
return await self.third_party_rules.check_threepid_can_be_invited( for callback in self._check_threepid_can_be_invited_callbacks:
medium, address, state_events try:
) if await callback(medium, address, state_events) is False:
return False
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
return True
async def check_visibility_can_be_modified( async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str self, room_id: str, new_visibility: str
@ -137,18 +296,20 @@ class ThirdPartyEventRules:
Returns: Returns:
True if the room's visibility can be modified, False if not. True if the room's visibility can be modified, False if not.
""" """
if self.third_party_rules is None: # Bail out early without hitting the store if we don't have any callback
return True if len(self._check_visibility_can_be_modified_callbacks) == 0:
check_func = getattr(
self.third_party_rules, "check_visibility_can_be_modified", None
)
if not check_func or not callable(check_func):
return True return True
state_events = await self._get_state_map_for_room(room_id) state_events = await self._get_state_map_for_room(room_id)
return await check_func(room_id, state_events, new_visibility) for callback in self._check_visibility_can_be_modified_callbacks:
try:
if await callback(room_id, state_events, new_visibility) is False:
return False
except Exception as e:
logger.warning("Failed to run module API callback %s: %s", callback, e)
return True
async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
"""Given a room ID, return the state events of that room. """Given a room ID, return the state events of that room.

View File

@ -1934,7 +1934,7 @@ class FederationHandler(BaseHandler):
builder=builder builder=builder
) )
event_allowed = await self.third_party_event_rules.check_event_allowed( event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:
@ -2026,7 +2026,7 @@ class FederationHandler(BaseHandler):
# for knock events, we run the third-party event rules. It's not entirely clear # for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events. # why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK: if event.membership == Membership.KNOCK:
event_allowed = await self.third_party_event_rules.check_event_allowed( event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:

View File

@ -949,10 +949,10 @@ class EventCreationHandler:
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
third_party_result = await self.third_party_event_rules.check_event_allowed( res, new_content = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not third_party_result: if res is False:
logger.info( logger.info(
"Event %s forbidden by third-party rules", "Event %s forbidden by third-party rules",
event, event,
@ -960,11 +960,11 @@ class EventCreationHandler:
raise SynapseError( raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN 403, "This event is not allowed in this context", Codes.FORBIDDEN
) )
elif isinstance(third_party_result, dict): elif new_content is not None:
# the third-party rules want to replace the event. We'll need to build a new # the third-party rules want to replace the event. We'll need to build a new
# event. # event.
event, context = await self._rebuild_event_after_third_party_rules( event, context = await self._rebuild_event_after_third_party_rules(
third_party_result, event new_content, event
) )
self.validator.validate_new(event, self.config) self.validator.validate_new(event, self.config)

View File

@ -618,15 +618,11 @@ class RoomCreationHandler(BaseHandler):
else: else:
is_requester_admin = await self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create # Let the third party rules modify the room creation config if needed, or abort
# request. # the room creation entirely with an exception.
event_allowed = await self.third_party_event_rules.on_create_room( await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin requester, config, is_requester_admin=is_requester_admin
) )
if not event_allowed:
raise SynapseError(
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
if not is_requester_admin and not await self.spam_checker.user_may_create_room( if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id user_id

View File

@ -110,6 +110,7 @@ class ModuleApi:
self._spam_checker = hs.get_spam_checker() self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler() self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
################################################################################# #################################################################################
# The following methods should only be called during the module's initialisation. # The following methods should only be called during the module's initialisation.
@ -124,6 +125,11 @@ class ModuleApi:
"""Registers callbacks for account validity capabilities.""" """Registers callbacks for account validity capabilities."""
return self._account_validity_handler.register_account_validity_callbacks return self._account_validity_handler.register_account_validity_callbacks
@property
def register_third_party_rules_callbacks(self):
"""Registers callbacks for third party event rules capabilities."""
return self._third_party_event_rules.register_third_party_rules_callbacks
def register_web_resource(self, path: str, resource: IResource): def register_web_resource(self, path: str, resource: IResource):
"""Registers a web resource to be served at the given path. """Registers a web resource to be served at the given path.

View File

@ -16,17 +16,19 @@ from typing import Dict
from unittest.mock import Mock from unittest.mock import Mock
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.types import Requester, StateMap from synapse.types import Requester, StateMap
from synapse.util.frozenutils import unfreeze
from tests import unittest from tests import unittest
thread_local = threading.local() thread_local = threading.local()
class ThirdPartyRulesTestModule: class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: ModuleApi): def __init__(self, config: Dict, module_api: ModuleApi):
# keep a record of the "current" rules module, so that the test can patch # keep a record of the "current" rules module, so that the test can patch
# it if desired. # it if desired.
@ -46,8 +48,26 @@ class ThirdPartyRulesTestModule:
return config return config
def current_rules_module() -> ThirdPartyRulesTestModule: class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
return thread_local.rules_module def __init__(self, config: Dict, module_api: ModuleApi):
super().__init__(config, module_api)
def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: ModuleApi):
super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
d["content"] = content
return d
class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
@ -57,20 +77,23 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def default_config(self): def make_homeserver(self, reactor, clock):
config = super().default_config() hs = self.setup_test_homeserver()
config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule", load_legacy_third_party_event_rules(hs)
"config": {},
} return hs
return config
def prepare(self, reactor, clock, homeserver): def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests # Create a user and room to play with during the tests
self.user_id = self.register_user("kermit", "monkey") self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey") self.tok = self.login("kermit", "monkey")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) # Some tests might prevent room creation on purpose.
try:
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
except Exception:
pass
def test_third_party_rules(self): def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one """Tests that a forbidden event is forbidden from being sent, but an allowed one
@ -79,10 +102,12 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# patch the rules module with a Mock which will return False for some event # patch the rules module with a Mock which will return False for some event
# types # types
async def check(ev, state): async def check(ev, state):
return ev.type != "foo.bar.forbidden" return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check) callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
callback
]
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
@ -116,9 +141,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
# first patch the event checker so that it will try to modify the event # first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state): async def check(ev: EventBase, state):
ev.content = {"x": "y"} ev.content = {"x": "y"}
return True return True, None
current_rules_module().check_event_allowed = check self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event # now send the event
channel = self.make_request( channel = self.make_request(
@ -127,7 +152,19 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
{"x": "x"}, {"x": "x"},
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(channel.result["code"], b"500", channel.result) # check_event_allowed has some error handling, so it shouldn't 500 just because a
# module did something bad.
self.assertEqual(channel.code, 200, channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "x")
def test_modify_event(self): def test_modify_event(self):
"""The module can return a modified version of the event""" """The module can return a modified version of the event"""
@ -135,9 +172,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
async def check(ev: EventBase, state): async def check(ev: EventBase, state):
d = ev.get_dict() d = ev.get_dict()
d["content"] = {"x": "y"} d["content"] = {"x": "y"}
return d return True, d
current_rules_module().check_event_allowed = check self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# now send the event # now send the event
channel = self.make_request( channel = self.make_request(
@ -168,9 +205,9 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"msgtype": "m.text", "msgtype": "m.text",
"body": d["content"]["body"].upper(), "body": d["content"]["body"].upper(),
} }
return d return True, d
current_rules_module().check_event_allowed = check self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
# Send an event, then edit it. # Send an event, then edit it.
channel = self.make_request( channel = self.make_request(
@ -222,7 +259,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.assertEqual(ev["content"]["body"], "EDITED BODY") self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self): def test_send_event(self):
"""Tests that the module can send an event into a room via the module api""" """Tests that a module can send an event into a room via the module api"""
content = { content = {
"msgtype": "m.text", "msgtype": "m.text",
"body": "Hello!", "body": "Hello!",
@ -234,12 +271,59 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
"sender": self.user_id, "sender": self.user_id,
} }
event: EventBase = self.get_success( event: EventBase = self.get_success(
current_rules_module().module_api.create_and_send_event_into_room( self.hs.get_module_api().create_and_send_event_into_room(event_dict)
event_dict
)
) )
self.assertEquals(event.sender, self.user_id) self.assertEquals(event.sender, self.user_id)
self.assertEquals(event.room_id, self.room_id) self.assertEquals(event.room_id, self.room_id)
self.assertEquals(event.type, "m.room.message") self.assertEquals(event.type, "m.room.message")
self.assertEquals(event.content, content) self.assertEquals(event.content, content)
@unittest.override_config(
{
"third_party_event_rules": {
"module": __name__ + ".LegacyChangeEvents",
"config": {},
}
}
)
def test_legacy_check_event_allowed(self):
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/m.room.message/1" % self.room_id,
{
"msgtype": "m.text",
"body": "Original body",
},
access_token=self.tok,
)
self.assertEqual(channel.result["code"], b"200", channel.result)
event_id = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
access_token=self.tok,
)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertIn("foo", channel.json_body["content"].keys())
self.assertEqual(channel.json_body["content"]["foo"], "bar")
@unittest.override_config(
{
"third_party_event_rules": {
"module": __name__ + ".LegacyDenyNewRooms",
"config": {},
}
}
)
def test_legacy_on_create_room(self):
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)