Deduplicate is_server_notices_room. (#13780)

This commit is contained in:
reivilibre 2022-09-14 15:53:18 +00:00 committed by GitHub
parent cf65433de2
commit 6302753012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 18 deletions

1
changelog.d/13780.misc Normal file
View File

@ -0,0 +1 @@
Deduplicate `is_server_notices_room`.

View File

@ -752,20 +752,12 @@ class EventCreationHandler:
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
return await self._is_server_notices_room(builder.room_id) return await self.store.is_server_notice_room(builder.room_id)
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
# the user is always allowed to leave (but not kick people) # the user is always allowed to leave (but not kick people)
return builder.state_key == requester.user.to_string() return builder.state_key == requester.user.to_string()
return False return False
async def _is_server_notices_room(self, room_id: str) -> bool:
if self.config.servernotices.server_notices_mxid is None:
return False
is_server_notices_room = await self.store.check_local_user_in_room(
user_id=self.config.servernotices.server_notices_mxid, room_id=room_id
)
return is_server_notices_room
async def assert_accepted_privacy_policy(self, requester: Requester) -> None: async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
"""Check if a user has accepted the privacy policy """Check if a user has accepted the privacy policy

View File

@ -837,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
old_membership == Membership.INVITE old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE and effective_membership_state == Membership.LEAVE
): ):
is_blocked = await self._is_server_notice_room(room_id) is_blocked = await self.store.is_server_notice_room(room_id)
if is_blocked: if is_blocked:
raise SynapseError( raise SynapseError(
HTTPStatus.FORBIDDEN, HTTPStatus.FORBIDDEN,
@ -1617,14 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
return False return False
async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None:
return False
is_server_notices_room = await self.store.check_local_user_in_room(
user_id=self._server_notices_mxid, room_id=room_id
)
return is_server_notices_room
class RoomMemberMasterHandler(RoomMemberHandler): class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):

View File

@ -88,6 +88,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# at a time. Keyed by room_id. # at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache") self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
if ( if (
self.hs.config.worker.run_background_tasks self.hs.config.worker.run_background_tasks
and self.hs.config.metrics.metrics_flags.known_servers and self.hs.config.metrics.metrics_flags.known_servers
@ -504,6 +506,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return membership == Membership.JOIN return membership == Membership.JOIN
async def is_server_notice_room(self, room_id: str) -> bool:
"""
Determines whether the given room is a 'Server Notices' room, used for
sending server notices to a user.
This is determined by seeing whether the server notices user is present
in the room.
"""
if self._server_notices_mxid is None:
return False
is_server_notices_room = await self.check_local_user_in_room(
user_id=self._server_notices_mxid, room_id=room_id
)
return is_server_notices_room
async def get_local_current_membership_for_user_in_room( async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]: ) -> Tuple[Optional[str], Optional[str]]: