Merge branch 'release-v1.13.0' into rav/fix_dropped_messages

This commit is contained in:
Richard van der Hoff 2020-05-05 22:38:44 +01:00
commit 1242267316
66 changed files with 847 additions and 690 deletions

View File

@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude Dockerfile exclude .codecov.yml
exclude .coveragerc
exclude .dockerignore exclude .dockerignore
exclude test_postgresql.sh
exclude .editorconfig exclude .editorconfig
exclude Dockerfile
exclude mypy.ini
exclude sytest-blacklist exclude sytest-blacklist
exclude test_postgresql.sh
include pyproject.toml include pyproject.toml
recursive-include changelog.d * recursive-include changelog.d *
prune .buildkite prune .buildkite
prune .circleci prune .circleci
prune .codecov.yml
prune .coveragerc
prune .github prune .github
prune contrib
prune debian prune debian
prune demo/etc prune demo/etc
prune docker prune docker
prune mypy.ini
prune snap prune snap
prune stubs prune stubs

View File

@ -75,6 +75,37 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.13.0
====================
Incorrect database migration in old synapse versions
----------------------------------------------------
A bug was introduced in Synapse 1.4.0 which could cause the room directory to
be incomplete or empty if Synapse was upgraded directly from v1.2.1 or earlier,
to versions between v1.4.0 and v1.12.x.
This will *not* be a problem for Synapse installations which were:
* created at v1.4.0 or later,
* upgraded via v1.3.x, or
* upgraded straight from v1.2.1 or earlier to v1.13.0 or later.
If completeness of the room directory is a concern, installations which are
affected can be repaired as follows:
1. Run the following sql from a `psql` or `sqlite3` console:
.. code:: sql
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_rooms', '{}', 'current_state_events_membership');
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
2. Restart synapse.
Upgrading to v1.12.0 Upgrading to v1.12.0
==================== ====================

1
changelog.d/7172.misc Normal file
View File

@ -0,0 +1 @@
Use `stream.current_token()` and remove `stream_positions()`.

1
changelog.d/7363.misc Normal file
View File

@ -0,0 +1 @@
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.

1
changelog.d/7368.bugfix Normal file
View File

@ -0,0 +1 @@
Improve error responses when accessing remote public room lists.

1
changelog.d/7369.misc Normal file
View File

@ -0,0 +1 @@
Thread through instance name to replication client.

1
changelog.d/7387.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information.

1
changelog.d/7393.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug in `EventContext.deserialize`.

1
changelog.d/7394.misc Normal file
View File

@ -0,0 +1 @@
Convert synapse.server_notices to async/await.

1
changelog.d/7395.misc Normal file
View File

@ -0,0 +1 @@
Convert synapse.notifier to async/await.

1
changelog.d/7401.feature Normal file
View File

@ -0,0 +1 @@
Add support for running replication over Redis when using workers.

1
changelog.d/7404.misc Normal file
View File

@ -0,0 +1 @@
Fix issues with the Python package manifest.

1
changelog.d/7408.misc Normal file
View File

@ -0,0 +1 @@
Clean up some LoggingContext code.

View File

@ -22,7 +22,10 @@ class RedisProtocol:
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol: class SubscriberProtocol:
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ... def subscribe(self, channels: Union[str, List[str]]): ...
def connectionMade(self): ...
def connectionLost(self, reason): ...
def lazyConnection( def lazyConnection(
host: str = ..., host: str = ...,

View File

@ -537,8 +537,7 @@ class Auth(object):
return defer.succeed(auth_ids) return defer.succeed(auth_ids)
@defer.inlineCallbacks async def check_can_change_room_list(self, room_id: str, user: UserID):
def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the """Determine whether the user is allowed to edit the room's entry in the
published room list. published room list.
@ -547,17 +546,17 @@ class Auth(object):
user user
""" """
is_admin = yield self.is_server_admin(user) is_admin = await self.is_server_admin(user)
if is_admin: if is_admin:
return True return True
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_in_room(room_id, user_id) await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this # We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the # by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events # m.room.canonical_alias events
power_level_event = yield self.state.get_current_state( power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, "" room_id, EventTypes.PowerLevels, ""
) )

View File

@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing # map room IDs to sets of users currently typing
self._room_typing = {} self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows): def process_replication_rows(self, token, rows):
if self._latest_room_serial > token: if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just # The master has gone backwards. To prevent inconsistent data, just
@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
else: else:
self.send_handler = None self.send_handler = None
async def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, instance_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata( await super().on_rdata(stream_name, instance_name, token, rows)
stream_name, token, rows await self._process_and_notify(stream_name, instance_name, token, rows)
)
await self.process_and_notify(stream_name, token, rows)
def get_streams_to_replicate(self): async def _process_and_notify(self, stream_name, instance_name, token, rows):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args
async def process_and_notify(self, stream_name, token, rows):
try: try:
if self.send_handler: if self.send_handler:
await self.send_handler.process_replication_rows( await self.send_handler.process_replication_rows(
@ -799,9 +784,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str): def wake_destination(self, server: str):
self.federation_sender.wake_destination(server) self.federation_sender.wake_destination(server)
def stream_positions(self):
return {"federation": self.federation_position}
async def process_replication_rows(self, stream_name, token, rows): async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g. # The federation stream contains things that we want to send out, e.g.
# presence, typing, etc. # presence, typing, etc.

View File

@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
self._current_state_ids = yield self._storage.state.get_state_ids_for_group( self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self.state_group self.state_group
) )
if self._prev_state_id and self._event_state_key is not None: if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids) self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key) key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else: else:
self._prev_state_ids = self._current_state_ids self._prev_state_ids = self._current_state_ids

View File

@ -883,18 +883,37 @@ class FederationClient(FederationBase):
def get_public_rooms( def get_public_rooms(
self, self,
destination, remote_server: str,
limit=None, limit: Optional[int] = None,
since_token=None, since_token: Optional[str] = None,
search_filter=None, search_filter: Optional[Dict] = None,
include_all_networks=False, include_all_networks: bool = False,
third_party_instance_id=None, third_party_instance_id: Optional[str] = None,
): ):
if destination == self.server_name: """Get the list of public rooms from a remote homeserver
return
Args:
remote_server: The name of the remote server
limit: Maximum amount of rooms to return
since_token: Used for result pagination
search_filter: A filter dictionary to send the remote homeserver
and filter the result set
include_all_networks: Whether to include results from all third party instances
third_party_instance_id: Whether to only include results from a specific third
party instance
Returns:
Deferred[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises:
HttpResponseException: There was an exception returned from the remote server
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
requests over federation
"""
return self.transport_layer.get_public_rooms( return self.transport_layer.get_public_rooms(
destination, remote_server,
limit, limit,
since_token, since_token,
search_filter, search_filter,
@ -957,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events return signed_events
@defer.inlineCallbacks async def forward_third_party_invite(self, destinations, room_id, event_dict):
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
try: try:
yield self.transport_layer.exchange_third_party_invite( await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict destination=destination, room_id=room_id, event_dict=event_dict
) )
return None return None

View File

@ -15,13 +15,14 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict from typing import Any, Dict, Optional
from six.moves import urllib from six.moves import urllib
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import ( from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX, FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX, FEDERATION_V1_PREFIX,
@ -326,18 +327,25 @@ class TransportLayerClient(object):
@log_function @log_function
def get_public_rooms( def get_public_rooms(
self, self,
remote_server, remote_server: str,
limit, limit: Optional[int] = None,
since_token, since_token: Optional[str] = None,
search_filter=None, search_filter: Optional[Dict] = None,
include_all_networks=False, include_all_networks: bool = False,
third_party_instance_id=None, third_party_instance_id: Optional[str] = None,
): ):
"""Get the list of public rooms from a remote homeserver
See synapse.federation.federation_client.FederationClient.get_public_rooms for
more information.
"""
if search_filter: if search_filter:
# this uses MSC2197 (Search Filtering over Federation) # this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms") path = _create_v1_path("/publicRooms")
data = {"include_all_networks": "true" if include_all_networks else "false"} data = {
"include_all_networks": "true" if include_all_networks else "false"
} # type: Dict[str, Any]
if third_party_instance_id: if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id data["third_party_instance_id"] = third_party_instance_id
if limit: if limit:
@ -347,9 +355,19 @@ class TransportLayerClient(object):
data["filter"] = search_filter data["filter"] = search_filter
response = yield self.client.post_json( try:
destination=remote_server, path=path, data=data, ignore_backoff=True response = yield self.client.post_json(
) destination=remote_server, path=path, data=data, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
else: else:
path = _create_v1_path("/publicRooms") path = _create_v1_path("/publicRooms")
@ -363,9 +381,19 @@ class TransportLayerClient(object):
if since_token: if since_token:
args["since"] = [since_token] args["since"] = [since_token]
response = yield self.client.get_json( try:
destination=remote_server, path=path, args=args, ignore_backoff=True response = yield self.client.get_json(
) destination=remote_server, path=path, args=args, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
return response return response

View File

@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks async def remove_user_from_group(
def remove_user_from_group(self, group_id, user_id, requester_user_id, content): self, group_id, user_id, requester_user_id, content
):
"""Remove a user from the group; either a user is leaving or an admin """Remove a user from the group; either a user is leaving or an admin
kicked them. kicked them.
""" """
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False is_kick = False
if requester_user_id != user_id: if requester_user_id != user_id:
is_admin = yield self.store.is_user_admin_in_group( is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
if not is_admin: if not is_admin:
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_kick = True is_kick = True
yield self.store.remove_user_from_group(group_id, user_id) await self.store.remove_user_from_group(group_id, user_id)
if is_kick: if is_kick:
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
yield self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {} get_domain_from_id(user_id), group_id, user_id, {}
) )
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id) await self.store.maybe_delete_remote_profile_cache(user_id)
# Delete group if the last user has left # Delete group if the last user has left
users = yield self.store.get_users_in_group(group_id, include_private=True) users = await self.store.get_users_in_group(group_id, include_private=True)
if not users: if not users:
yield self.store.delete_group(group_id) await self.store.delete_group(group_id)
return {} return {}
@defer.inlineCallbacks async def create_group(self, group_id, requester_user_id, content):
def create_group(self, group_id, requester_user_id, content): group = await self.check_group_is_ours(group_id, requester_user_id)
group = yield self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id) logger.info("Attempting to create group with ID: %r", group_id)
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if group: if group:
raise SynapseError(400, "Group already exists") raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin( is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id) UserID.from_string(requester_user_id)
) )
if not is_admin: if not is_admin:
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
long_description = profile.get("long_description") long_description = profile.get("long_description")
user_profile = content.get("user_profile", {}) user_profile = content.get("user_profile", {})
yield self.store.create_group( await self.store.create_group(
group_id, group_id,
requester_user_id, requester_user_id,
name=name, name=name,
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(requester_user_id): if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"] remote_attestation = content["attestation"]
yield self.attestations.verify_attestation( await self.attestations.verify_attestation(
remote_attestation, user_id=requester_user_id, group_id=group_id remote_attestation, user_id=requester_user_id, group_id=group_id
) )
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = None local_attestation = None
remote_attestation = None remote_attestation = None
yield self.store.add_user_to_group( await self.store.add_user_to_group(
group_id, group_id,
requester_user_id, requester_user_id,
is_admin=True, is_admin=True,
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
) )
if not self.hs.is_mine_id(requester_user_id): if not self.hs.is_mine_id(requester_user_id):
yield self.store.add_remote_profile_cache( await self.store.add_remote_profile_cache(
requester_user_id, requester_user_id,
displayname=user_profile.get("displayname"), displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"), avatar_url=user_profile.get("avatar_url"),
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id} return {"group_id": group_id}
@defer.inlineCallbacks async def delete_group(self, group_id, requester_user_id):
def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members. """Deletes a group, kicking out all current members.
Only group admins or server admins can call this request Only group admins or server admins can call this request
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
Deferred Deferred
""" """
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True) await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups. # Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id) is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin: if not is_admin:
is_admin = yield self.auth.is_server_admin( is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id) UserID.from_string(requester_user_id)
) )
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(403, "User is not an admin") raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it # Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(group_id, include_private=True) users = await self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks async def _kick_user_from_group(user_id):
def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler() groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {}) await groups_local.user_removed_from_group(group_id, user_id, {})
else: else:
yield self.transport_client.remove_user_from_group_notification( await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {} get_domain_from_id(user_id), group_id, user_id, {}
) )
yield self.store.maybe_delete_remote_profile_cache(user_id) await self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of: # We kick users out in the order of:
# 1. Non-admins # 1. Non-admins
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else: else:
non_admins.append(u["user_id"]) non_admins.append(u["user_id"])
yield concurrently_execute(_kick_user_from_group, non_admins, 10) await concurrently_execute(_kick_user_from_group, non_admins, 10)
yield concurrently_execute(_kick_user_from_group, admins, 10) await concurrently_execute(_kick_user_from_group, admins, 10)
yield _kick_user_from_group(requester_user_id) await _kick_user_from_group(requester_user_id)
yield self.store.delete_group(group_id) await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content): def _parse_join_policy_from_contents(content):

View File

@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)) retry_after_ms=int(1000 * (time_allowed - time_now))
) )
@defer.inlineCallbacks async def maybe_kick_guest_users(self, event, context=None):
def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it. # Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller. # Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess: if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden") guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join": if guest_access != "can_join":
if context: if context:
current_state_ids = yield context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state = yield self.store.get_events( current_state = await self.store.get_events(
list(current_state_ids.values()) list(current_state_ids.values())
) )
else: else:
current_state = yield self.state_handler.get_current_state( current_state = await self.state_handler.get_current_state(
event.room_id event.room_id
) )
current_state = list(current_state.values()) current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) await self.kick_guest_users(current_state)
@defer.inlineCallbacks async def kick_guest_users(self, current_state):
def kick_guest_users(self, current_state):
for member_event in current_state: for member_event in current_state:
try: try:
if member_event.type != EventTypes.Member: if member_event.type != EventTypes.Member:
@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver. # homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True) requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
yield handler.update_membership( await handler.update_membership(
requester, requester,
target_user, target_user,
member_event.room_id, member_event.room_id,

View File

@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator room_alias, room_id, servers, creator=creator
) )
@defer.inlineCallbacks async def create_association(
def create_association(
self, self,
requester: Requester, requester: Requester,
room_alias: RoomAlias, room_alias: RoomAlias,
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
else: else:
# Server admins are not subject to the same constraints as normal # Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room). # users when creating an alias (e.g. being in the room).
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
if (self.require_membership and check_membership) and not is_admin: if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = yield self.store.get_rooms_for_user(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user: if room_id not in rooms_for_user:
raise AuthError( raise AuthError(
403, "You must be in the room to create an alias for it" 403, "You must be in the room to create an alias for it"
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule? # per alias creation rule?
raise SynapseError(403, "Not allowed to create alias") raise SynapseError(403, "Not allowed to create alias")
can_create = yield self.can_modify_alias(room_alias, user_id=user_id) can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create: if not can_create:
raise AuthError( raise AuthError(
400, 400,
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
yield self._create_association(room_alias, room_id, servers, creator=user_id) await self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks async def delete_association(self, requester: Requester, room_alias: RoomAlias):
def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory """Remove an alias from the directory
(this is only meant for human users; AS users should call (this is only meant for human users; AS users should call
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id) can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown room alias") raise NotFoundError("Unknown room alias")
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete: if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.") raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id) can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete: if not can_delete:
raise SynapseError( raise SynapseError(
400, 400,
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
room_id = yield self._delete_association(room_alias) room_id = await self._delete_association(room_alias)
try: try:
yield self._update_canonical_alias(requester, user_id, room_id, room_alias) await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e: except AuthError as e:
logger.info("Failed to update alias events: %s", e) logger.info("Failed to update alias events: %s", e)
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND, Codes.NOT_FOUND,
) )
@defer.inlineCallbacks async def _update_canonical_alias(
def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
): ):
""" """
Send an updated canonical alias event if the removed alias was set as Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field. the canonical alias or listed in the alt_aliases field.
""" """
alias_event = yield self.state.get_current_state( alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, "" room_id, EventTypes.CanonicalAlias, ""
) )
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"] del content["alt_aliases"]
if send_update: if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
return defer.succeed(True) return defer.succeed(True)
@defer.inlineCallbacks async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias. """Determine whether a user can delete an alias.
One of the following must be true: One of the following must be true:
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room. for the current room.
""" """
creator = yield self.store.get_room_alias_creator(alias.to_string()) creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id: if creator is not None and creator == user_id:
return True return True
# Resolve the alias to the corresponding room. # Resolve the alias to the corresponding room.
room_mapping = yield self.get_association(alias) room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"] room_id = room_mapping["room_id"]
if not room_id: if not room_id:
return False return False
res = yield self.auth.check_can_change_room_list( res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id) room_id, UserID.from_string(user_id)
) )
return res return res
@defer.inlineCallbacks async def edit_published_room_list(
def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str self, requester: Requester, room_id: str, visibility: str
): ):
"""Edit the entry of the room in the published room list. """Edit the entry of the room in the published room list.
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list" 403, "This user is not permitted to publish rooms to the room list"
) )
room = yield self.store.get_room(room_id) room = await self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Unknown room") raise SynapseError(400, "Unknown room")
can_change_room_list = yield self.auth.check_can_change_room_list( can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user room_id, requester.user
) )
if not can_change_room_list: if not can_change_room_list:
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public" making_public = visibility == "public"
if making_public: if making_public:
room_aliases = yield self.store.get_aliases_for_room(room_id) room_aliases = await self.store.get_aliases_for_room(room_id)
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id) canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias: if canonical_alias:
room_aliases.append(canonical_alias) room_aliases.append(canonical_alias)
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule? # per alias creation rule?
raise SynapseError(403, "Not allowed to publish room") raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public) await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks @defer.inlineCallbacks
def edit_published_appservice_room_list( def edit_published_appservice_room_list(

View File

@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
} }
@defer.inlineCallbacks
@log_function @log_function
def exchange_third_party_invite( async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed self, sender_user_id, target_user_id, room_id, signed
): ):
third_party_invite = {"signed": signed} third_party_invite = {"signed": signed}
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id, "state_key": target_user_id,
} }
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if await self.auth.check_host_in_room(room_id, self.hs.hostname):
room_version = yield self.store.get_room_version_id(room_id) room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict) builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder) EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
event_allowed = yield self.third_party_event_rules.check_event_allowed( event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN 403, "This event is not allowed in this context", Codes.FORBIDDEN
) )
event, context = yield self.add_display_name_to_third_party_invite( event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context room_version, event_dict, event, context
) )
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname event.internal_metadata.send_on_behalf_of = self.hs.hostname
try: try:
yield self.auth.check_from_context(room_version, event, context) await self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e) logger.warning("Denying new third party invite %r because %s", event, e)
raise e raise e
yield self._check_signature(event, context) await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler() member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context) await member_handler.send_membership_event(None, event, context)
else: else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)} destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite( await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict destinations, room_id, event_dict
) )

View File

@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy") set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks async def create_group(self, group_id, user_id, content):
def create_group(self, group_id, user_id, content):
"""Create a group """Create a group
""" """
logger.info("Asking to create group with ID: %r", group_id) logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
res = yield self.groups_server_handler.create_group( res = await self.groups_server_handler.create_group(
group_id, user_id, content group_id, user_id, content
) )
local_attestation = None local_attestation = None
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id) local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation content["attestation"] = local_attestation
content["user_profile"] = yield self.profile_handler.get_profile(user_id) content["user_profile"] = await self.profile_handler.get_profile(user_id)
try: try:
res = yield self.transport_client.create_group( res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content get_domain_from_id(group_id), group_id, user_id, content
) )
except HttpResponseException as e: except HttpResponseException as e:
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server") raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"] remote_attestation = res["attestation"]
yield self.attestations.verify_attestation( await self.attestations.verify_attestation(
remote_attestation, remote_attestation,
group_id=group_id, group_id=group_id,
user_id=user_id, user_id=user_id,
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
) )
is_publicised = content.get("publicise", False) is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, group_id,
user_id, user_id,
membership="join", membership="join",
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile} return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks async def remove_user_from_group(
def remove_user_from_group(self, group_id, user_id, requester_user_id, content): self, group_id, user_id, requester_user_id, content
):
"""Remove a user from a group """Remove a user from a group
""" """
if user_id == requester_user_id: if user_id == requester_user_id:
token = yield self.store.register_user_group_membership( token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave" group_id, user_id, membership="leave"
) )
self.notifier.on_new_event("groups_key", token, users=[user_id]) self.notifier.on_new_event("groups_key", token, users=[user_id])
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down. # retry if the group server is currently down.
if self.is_mine_id(group_id): if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group( res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
else: else:
content["requester_user_id"] = requester_user_id content["requester_user_id"] = requester_user_id
try: try:
res = yield self.transport_client.remove_user_from_group( res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id), get_domain_from_id(group_id),
group_id, group_id,
requester_user_id, requester_user_id,

View File

@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks async def send_nonmember_event(self, requester, event, context, ratelimit=True):
def send_nonmember_event(self, requester, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state(): if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context) prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None: if prev_state is not None:
logger.info( logger.info(
"Not bothering to persist state event %s duplicated by %s", "Not bothering to persist state event %s duplicated by %s",
@ -656,7 +655,7 @@ class EventCreationHandler(object):
) )
return prev_state return prev_state
yield self.handle_new_client_event( await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit requester=requester, event=event, context=context, ratelimit=ratelimit
) )
@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event return prev_event
return return
@defer.inlineCallbacks async def create_and_send_nonmember_event(
def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None self, requester, event_dict, ratelimit=True, txn_id=None
): ):
""" """
@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing # a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution # extremities to pile up, which in turn leads to state resolution
# taking longer. # taking longer.
with (yield self.limiter.queue(event_dict["room_id"])): with (await self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event( event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
) )
@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN) raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event( await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit requester, event, context, ratelimit=ratelimit
) )
return event return event
@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context) return (event, context)
@measure_func("handle_new_client_event") @measure_func("handle_new_client_event")
@defer.inlineCallbacks async def handle_new_client_event(
def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ):
"""Processes a new event. This includes checking auth, persisting it, """Processes a new event. This includes checking auth, persisting it,
@ -794,9 +791,9 @@ class EventCreationHandler(object):
): ):
room_version = event.content.get("room_version", RoomVersions.V1.identifier) room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else: else:
room_version = yield self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed( event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context event, context
) )
if not event_allowed: if not event_allowed:
@ -805,7 +802,7 @@ class EventCreationHandler(object):
) )
try: try:
yield self.auth.check_from_context(room_version, event, context) await self.auth.check_from_context(room_version, event, context)
except AuthError as err: except AuthError as err:
logger.warning("Denying new event %r because %s", event, err) logger.warning("Denying new event %r because %s", event, err)
raise err raise err
@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.action_generator.handle_push_actions_for_event(event, context) await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead. # hack around with a try/finally instead.
@ -826,7 +823,7 @@ class EventCreationHandler(object):
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker_app:
yield self.send_event_to_master( await self.send_event_to_master(
event_id=event.event_id, event_id=event.event_id,
store=self.store, store=self.store,
requester=requester, requester=requester,
@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True success = True
return return
yield self.persist_and_notify_client_event( await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS, Codes.BAD_ALIAS,
) )
@defer.inlineCallbacks async def persist_and_notify_client_event(
def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ):
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not). # user is actually admin or not).
is_admin_redaction = False is_admin_redaction = False
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender original_event and event.sender != original_event.sender
) )
yield self.base_handler.ratelimit( await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction requester, is_admin_redaction=is_admin_redaction
) )
yield self.base_handler.maybe_kick_guest_users(event, context) await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases. # Validate a newly added alias or newly added alt_aliases.
@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state") original_event_id = event.unsigned.get("replaces_state")
if original_event_id: if original_event_id:
original_event = yield self.store.get_event(original_event_id) original_event = await self.store.get_event(original_event_id)
if original_event: if original_event:
original_alias = original_event.content.get("alias", None) original_alias = original_event.content.get("alias", None)
@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None) room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias: if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias( await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id directory_handler, room_alias_str, event.room_id
) )
@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases) new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases: if new_alt_aliases:
for alias_str in new_alt_aliases: for alias_str in new_alt_aliases:
yield self._validate_canonical_alias( await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id directory_handler, alias_str, event.room_id
) )
@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e): def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [ state_to_include_ids = [
e_id e_id
@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender) or k == (EventTypes.Member, event.sender)
] ]
state_to_include = yield self.store.get_events(state_to_include_ids) state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [ event.unsigned["invite_room_state"] = [
{ {
@ -996,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need # way? If we have been invited by a remote server, we need
# to get them to sign the event. # to get them to sign the event.
returned_invite = yield defer.ensureDeferred( returned_invite = await federation_handler.send_invite(
federation_handler.send_invite(invitee.domain, event) invitee.domain, event
) )
event.unsigned.pop("room_state", None) event.unsigned.pop("room_state", None)
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False, get_prev_content=False,
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id: if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room") raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction( if event_auth.check_redaction(
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids: if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden") raise AuthError(403, "Changing the room create event is forbidden")
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event( event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context event, context=context
) )
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry. # If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event) self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify(): def _notify():
try: try:
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception: except Exception:
logger.exception("Error bumping presence active time") logger.exception("Error bumping presence active time")
@defer.inlineCallbacks async def _send_dummy_events_to_fill_extremities(self):
def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large """Background task to send dummy events into rooms that have a large
number of extremities number of extremities
""" """
self._expire_rooms_to_exclude_from_dummy_event_insertion() self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities( room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10, min_count=10,
limit=5, limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(), room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send # For each room we need to find a joined member we can use to send
# the dummy event with. # the dummy event with.
latest_event_ids = yield self.store.get_prev_events_for_room(room_id) latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = yield self.state.get_current_users_in_room( members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids room_id, latest_event_ids=latest_event_ids
) )
dummy_event_sent = False dummy_event_sent = False
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue continue
requester = create_requester(user_id) requester = create_requester(user_id)
try: try:
event, context = yield self.create_event( event, context = await self.create_event(
requester, requester,
{ {
"type": "org.matrix.dummy_event", "type": "org.matrix.dummy_event",
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False event.internal_metadata.proactively_send = False
yield self.send_nonmember_event( await self.send_nonmember_event(
requester, event, context, ratelimit=False requester, event, context, ratelimit=False
) )
dummy_event_sent = True dummy_event_sent = True

View File

@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"] return result["displayname"]
@defer.inlineCallbacks async def set_displayname(
def set_displayname(self, target_user, requester, new_displayname, by_admin=False): self, target_user, requester, new_displayname, by_admin=False
):
"""Set the displayname of a user """Set the displayname of a user
Args: Args:
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname: if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name: if profile.display_name:
raise SynapseError( raise SynapseError(
400, 400,
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
yield self.store.set_profile_displayname(target_user.localpart, new_displayname) await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile target_user.to_string(), profile
) )
yield self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"] return result["avatar_url"]
@defer.inlineCallbacks async def set_avatar_url(
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False): self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url: if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url: if profile.avatar_url:
raise SynapseError( raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change( await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile target_user.to_string(), profile
) )
yield self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_profile_query(self, args): def on_profile_query(self, args):
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response return response
@defer.inlineCallbacks async def _update_join_states(self, requester, target_user):
def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
return return
yield self.ratelimit(requester) await self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(target_user.to_string()) room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids: for room_id in room_ids:
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
try: try:
# Assume the target_user isn't a guest, # Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data. # because we don't let guests set profile or avatar data.
yield handler.update_membership( await handler.update_membership(
requester, requester,
target_user, target_user,
room_id, room_id,

View File

@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
localpart : The local part of the user ID to register. If None, localpart: The local part of the user ID to register. If None,
one will be generated. one will be generated.
password (unicode) : The password to assign to this user so they can password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from user_type (str|None): type of user. One of the values from
@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1 fail_count += 1
if not self.hs.config.user_consent_at_registration: if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id) yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else: else:
logger.info( logger.info(
"Skipping auto-join for %s because consent is required at registration", "Skipping auto-join for %s because consent is required at registration",
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id return user_id
@defer.inlineCallbacks async def _auto_join_rooms(self, user_id):
def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place """Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created. if the user is the first to be created.
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be # that an auto-generated support or bot user is not a real user and will never be
# the user to create the room # the user to create the room
should_auto_create_rooms = False should_auto_create_rooms = False
is_real_user = yield self.store.is_real_user(user_id) is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user: if self.hs.config.autocreate_auto_join_rooms and is_real_user:
count = yield self.store.count_real_users() count = await self.store.count_real_users()
should_auto_create_rooms = count == 1 should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms: for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r) logger.info("Auto-joining %s to %s", user_id, r)
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency # getting the RoomCreationHandler during init gives a dependency
# loop # loop
yield self.hs.get_room_creation_handler().create_room( await self.hs.get_room_creation_handler().create_room(
fake_requester, fake_requester,
config={ config={
"preset": "public_chat", "preset": "public_chat",
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
else: else:
yield self._join_user_to_room(fake_requester, r) await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e: except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though # Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do. # moving away from bare excepts is a good thing to do.
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e: except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e) logger.error("Failed to join new user to %r: %r", r, e)
@defer.inlineCallbacks async def post_consent_actions(self, user_id):
def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent """A series of registration actions that can only be carried out once consent
has been granted has been granted
Args: Args:
user_id (str): The user to join user_id (str): The user to join
""" """
yield self._auto_join_rooms(user_id) await self._auto_join_rooms(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token): def appservice_register(self, user_localpart, as_token):
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1 self._next_generated_user_id += 1
return str(id) return str(id)
@defer.inlineCallbacks async def _join_user_to_room(self, requester, room_identifier):
def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler() room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier): if RoomID.is_valid(room_identifier):
room_id = room_identifier room_id = room_identifier
elif RoomAlias.is_valid(room_identifier): elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier) room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias( room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias room_alias
) )
room_id = room_id.to_string() room_id = room_id.to_string()
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)
) )
yield room_member_handler.update_membership( await room_member_handler.update_membership(
requester=requester, requester=requester,
target=requester.user, target=requester.user,
room_id=room_id, room_id=room_id,
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token) return (device_id, access_token)
@defer.inlineCallbacks async def post_registration_actions(self, user_id, auth_result, access_token):
def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration """A user has completed registration
Args: Args:
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled. device, or None if `inhibit_login` enabled.
""" """
if self.hs.config.worker_app: if self.hs.config.worker_app:
yield self._post_registration_client( await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token user_id=user_id, auth_result=auth_result, access_token=access_token
) )
return return
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved( if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid self.hs.config.mau_limits_reserved_threepids, threepid
): ):
yield self.store.upsert_monthly_active_user(user_id) await self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(user_id, threepid, access_token) await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result: if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN] threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(user_id, threepid) await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result: if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(user_id, self.hs.config.user_consent_version) await self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks async def _on_user_consented(self, user_id, consent_version):
def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration """A user consented to the terms on registration
Args: Args:
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to. consented to.
""" """
logger.info("%s has consented to the privacy policy", user_id) logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(user_id, consent_version) await self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id) await self.post_consent_actions(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token): def _register_email_threepid(self, user_id, threepid, token):

View File

@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks async def _upgrade_room(
def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion self, requester: Requester, old_room_id: str, new_version: RoomVersion
): ):
user_id = requester.user.to_string() user_id = requester.user.to_string()
# start by allocating a new room id # start by allocating a new room id
r = yield self.store.get_room(old_room_id) r = await self.store.get_room(old_room_id)
if r is None: if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,)) raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id( new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version, creator_id=user_id, is_public=r["is_public"], room_version=new_version,
) )
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
( (
tombstone_event, tombstone_event,
tombstone_context, tombstone_context,
) = yield self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Tombstone, "type": EventTypes.Tombstone,
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
}, },
token_id=requester.access_token_id, token_id=requester.access_token_id,
) )
old_room_version = yield self.store.get_room_version_id(old_room_id) old_room_version = await self.store.get_room_version_id(old_room_id)
yield self.auth.check_from_context( await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context old_room_version, tombstone_event, tombstone_context
) )
yield self.clone_existing_room( await self.clone_existing_room(
requester, requester,
old_room_id=old_room_id, old_room_id=old_room_id,
new_room_id=new_room_id, new_room_id=new_room_id,
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
) )
# now send the tombstone # now send the tombstone
yield self.event_creation_handler.send_nonmember_event( await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context requester, tombstone_event, tombstone_context
) )
old_room_state = yield tombstone_context.get_current_state_ids() old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases # update any aliases
yield self._move_aliases_to_new_room( await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state requester, old_room_id, new_room_id, old_room_state
) )
# Copy over user push rules, tags and migrate room directory state # Copy over user push rules, tags and migrate room directory state
yield self.room_member_handler.transfer_room_state_on_room_upgrade( await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id old_room_id, new_room_id
) )
# finally, shut down the PLs in the old room, and update them in the new # finally, shut down the PLs in the old room, and update them in the new
# room. # room.
yield self._update_upgraded_room_pls( await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state, requester, old_room_id, new_room_id, old_room_state,
) )
return new_room_id return new_room_id
@defer.inlineCallbacks async def _update_upgraded_room_pls(
def _update_upgraded_room_pls(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
) )
return return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id) old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required # we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally # to send regular events and invites to 'Moderator' level. That's normally
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
if updated: if updated:
try: try:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.PowerLevels, "type": EventTypes.PowerLevels,
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e: except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e) logger.warning("Unable to update PLs in old room: %s", e)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.PowerLevels, "type": EventTypes.PowerLevels,
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
@defer.inlineCallbacks async def clone_existing_room(
def clone_existing_room(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable # Check if old room was non-federatable
# Get old room's create event # Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id) old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room # Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True): if not old_room_create_event.content.get("m.federate", True):
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""), (EventTypes.PowerLevels, ""),
) )
old_room_state_ids = yield self.store.get_filtered_current_state_ids( old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy) old_room_id, StateFilter.from_types(types_to_copy)
) )
# map from event_id to BaseEvent # map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values()) old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids): for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id) old_event = old_room_state_events.get(old_event_id)
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level: if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level power_levels["users"][user_id] = needed_power_level
yield self._send_events_for_new_room( await self._send_events_for_new_room(
requester, requester,
new_room_id, new_room_id,
# we expect to override all the presets with initial_state, so this is # we expect to override all the presets with initial_state, so this is
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
) )
# Transfer membership events # Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids( old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
# map from event_id to BaseEvent # map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events( old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values() old_room_member_state_ids.values()
) )
for k, old_event in iteritems(old_room_member_state_events): for k, old_event in iteritems(old_room_member_state_events):
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content "membership" in old_event.content
and old_event.content["membership"] == "ban" and old_event.content["membership"] == "ban"
): ):
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(old_event["state_key"]), UserID.from_string(old_event["state_key"]),
new_room_id, new_room_id,
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins # XXX invites/joins
# XXX 3pid invites # XXX 3pid invites
@defer.inlineCallbacks async def _move_aliases_to_new_room(
def _move_aliases_to_new_room(
self, self,
requester: Requester, requester: Requester,
old_room_id: str, old_room_id: str,
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
): ):
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id) aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias. # check to see if we have a canonical alias.
canonical_alias_event = None canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, "")) canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id: if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id) canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending # first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end). # the room_aliases event until the end).
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases: for alias_str in aliases:
alias = RoomAlias.from_string(alias_str) alias = RoomAlias.from_string(alias_str)
try: try:
yield directory_handler.delete_association(requester, alias) await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str) removed_aliases.append(alias_str)
except SynapseError as e: except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e) logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room. # we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases: for alias in removed_aliases:
try: try:
yield directory_handler.create_association( await directory_handler.create_association(
requester, requester,
RoomAlias.from_string(alias), RoomAlias.from_string(alias),
new_room_id, new_room_id,
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information. # alias event for the new room with a copy of the information.
try: try:
if canonical_alias_event: if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point. # we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e) logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks async def create_room(
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None): self, requester, config, ratelimit=True, creator_join_profile=None
):
""" Creates a new room. """ Creates a new room.
Args: Args:
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
if ( if (
self._server_notices_mxid is not None self._server_notices_mxid is not None
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms # allow the server notices mxid to create rooms
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create # Check whether the third party rules allows/changes the room create
# request. # request.
event_allowed = yield self.third_party_event_rules.on_create_room( event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin requester, config, is_requester_admin=is_requester_admin
) )
if not event_allowed: if not event_allowed:
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit: if ratelimit:
yield self.ratelimit(requester) await self.ratelimit(requester)
room_version_id = config.get( room_version_id = config.get(
"room_version", self.config.default_room_version.identifier "room_version", self.config.default_room_version.identifier
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias") raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname) room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = yield self.store.get_association_from_room_alias(room_alias) mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping: if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
except Exception: except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,)) raise SynapseError(400, "Invalid user_id: %s" % (i,))
yield self.event_creation_handler.assert_accepted_privacy_policy(requester) await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override") power_level_content_override = config.get("power_level_content_override")
if ( if (
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None) visibility = config.get("visibility", None)
is_public = visibility == "public" is_public = visibility == "public"
room_id = yield self._generate_room_id( room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version, creator_id=user_id, is_public=is_public, room_version=room_version,
) )
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
if room_alias: if room_alias:
yield directory_handler.create_association( await directory_handler.create_association(
requester=requester, requester=requester,
room_id=room_id, room_id=room_id,
room_alias=room_alias, room_alias=room_alias,
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content # override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room( await self._send_events_for_new_room(
requester, requester,
room_id, room_id,
preset_config=preset_config, preset_config=preset_config,
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct: if is_direct:
content["is_direct"] = is_direct content["is_direct"] = is_direct
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(invitee), UserID.from_string(invitee),
room_id, room_id,
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"] address = invite_3pid["address"]
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite( await self.hs.get_room_member_handler().do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
medium, medium,
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
return result return result
@defer.inlineCallbacks async def _send_events_for_new_room(
def _send_events_for_new_room(
self, self,
creator, # A Requester object. creator, # A Requester object.
room_id, room_id,
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
return e return e
@defer.inlineCallbacks async def send(etype, content, **kwargs):
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype) logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False creator, event, ratelimit=False
) )
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""} event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id}) creation_content.update({"creator": creator_id})
yield send(etype=EventTypes.Create, content=creation_content) await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member) logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
creator, creator,
creator.user, creator.user,
room_id, room_id,
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room. # of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None: if pl_content is not None:
yield send(etype=EventTypes.PowerLevels, content=pl_content) await send(etype=EventTypes.PowerLevels, content=pl_content)
else: else:
power_level_content = { power_level_content = {
"users": {creator_id: 100}, "users": {creator_id: 100},
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override: if power_level_content_override:
power_level_content.update(power_level_content_override) power_level_content.update(power_level_content_override)
yield send(etype=EventTypes.PowerLevels, content=power_level_content) await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send( await send(
etype=EventTypes.CanonicalAlias, etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()}, content={"alias": room_alias.to_string()},
) )
if (EventTypes.JoinRules, "") not in initial_state: if (EventTypes.JoinRules, "") not in initial_state:
yield send( await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
) )
if (EventTypes.RoomHistoryVisibility, "") not in initial_state: if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send( await send(
etype=EventTypes.RoomHistoryVisibility, etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}, content={"history_visibility": config["history_visibility"]},
) )
if config["guest_can_join"]: if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state: if (EventTypes.GuestAccess, "") not in initial_state:
yield send( await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
) )
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content) await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_room_id( def _generate_room_id(

View File

@ -142,8 +142,7 @@ class RoomMemberHandler(object):
""" """
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks async def _local_membership_update(
def _local_membership_update(
self, self,
requester, requester,
target, target,
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event( duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context event, context
) )
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
return duplicate return duplicate
yield self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit requester, event, context, extra_users=[target], ratelimit=ratelimit
) )
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info. # info.
newly_joined = True newly_joined = True
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield self._user_joined_room(target, room_id) await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id) await self._user_left_room(target, room_id)
return event return event
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items(): for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks async def update_membership(
def update_membership(
self, self,
requester, requester,
target, target,
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
): ):
key = (room_id,) key = (room_id,)
with (yield self.member_linearizer.queue(key)): with (await self.member_linearizer.queue(key)):
result = yield self._update_membership( result = await self._update_membership(
requester, requester,
target, target,
room_id, room_id,
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result return result
@defer.inlineCallbacks async def _update_membership(
def _update_membership(
self, self,
requester, requester,
target, target,
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid # if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join. # invite into a normal invite before we can handle the join.
if third_party_signed is not None: if third_party_signed is not None:
yield self.federation_handler.exchange_third_party_invite( await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"], third_party_signed["sender"],
target.to_string(), target.to_string(),
room_id, room_id,
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = [] remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"): if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id) is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin: if not is_requester_admin:
if self.config.block_non_admin_invites: if self.config.block_non_admin_invites:
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite: if block_invite:
raise SynapseError(403, "Invites have been disabled on this server") raise SynapseError(403, "Invites have been disabled on this server")
latest_event_ids = yield self.store.get_prev_events_for_room(room_id) latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids room_id, latest_event_ids=latest_event_ids
) )
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise # transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string())) old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id: if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True) old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban": if action == "unban" and old_membership != "ban":
raise SynapseError( raise SynapseError(
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE and effective_membership_state == Membership.LEAVE
): ):
is_blocked = yield self._is_server_notice_room(room_id) is_blocked = await self._is_server_notice_room(room_id)
if is_blocked: if is_blocked:
raise SynapseError( raise SynapseError(
http_client.FORBIDDEN, http_client.FORBIDDEN,
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick": if action == "kick":
raise AuthError(403, "The target user is not in the room") raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids) is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN: if effective_membership_state == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids) guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if not is_host_in_room: if not is_host_in_room:
inviter = yield self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter): if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain) remote_room_hosts.append(inviter.domain)
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler profile = self.profile_handler
if not content_specified: if not content_specified:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = await profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
remote_join_response = yield self._remote_join( remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content requester, remote_room_hosts, room_id, target, content
) )
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE: elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited # perhaps we've been invited
inviter = yield self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter: if not inviter:
raise SynapseError(404, "Not a known room") raise SynapseError(404, "Not a known room")
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else: else:
# send the rejection to the inviter's HS. # send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain] remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite( res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content, requester, remote_room_hosts, room_id, target, content,
) )
return res return res
res = yield self._local_membership_update( res = await self._local_membership_update(
requester=requester, requester=requester,
target=target, target=target,
room_id=room_id, room_id=room_id,
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
) )
continue continue
@defer.inlineCallbacks async def send_membership_event(self, requester, event, context, ratelimit=True):
def send_membership_event(self, requester, event, context, ratelimit=True):
""" """
Change the membership status of a user in a room. Change the membership status of a user in a room.
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
else: else:
requester = types.create_requester(target_user) requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event( prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context event, context
) )
if prev_event is not None: if prev_event is not None:
return return
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if requester.is_guest: if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids) guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join: if not guest_can_join:
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN): if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id) is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit requester, event, context, extra_users=[target_user], ratelimit=ratelimit
) )
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
# info. # info.
newly_joined = True newly_joined = True
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield self._user_joined_room(target_user, room_id) await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target_user, room_id) await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _can_guest_join(self, current_state_ids): def _can_guest_join(self, current_state_ids):
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
if invite: if invite:
return UserID.from_string(invite.sender) return UserID.from_string(invite.sender)
@defer.inlineCallbacks async def do_3pid_invite(
def do_3pid_invite(
self, self,
room_id, room_id,
inviter, inviter,
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None, id_access_token=None,
): ):
if self.config.block_non_admin_invites: if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin: if not is_requester_admin:
raise SynapseError( raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN 403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we # We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events. # can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester) await self.base_handler.ratelimit(requester)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited( can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id medium, address, room_id
) )
if not can_invite: if not can_invite:
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server" 403, "Looking up third-party identifiers is denied from this server"
) )
invitee = yield self.identity_handler.lookup_3pid( invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token id_server, medium, address, id_access_token
) )
if invitee: if invitee:
yield self.update_membership( await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
) )
else: else:
yield self._make_and_store_3pid_invite( await self._make_and_store_3pid_invite(
requester, requester,
id_server, id_server,
medium, medium,
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
@defer.inlineCallbacks async def _make_and_store_3pid_invite(
def _make_and_store_3pid_invite(
self, self,
requester, requester,
id_server, id_server,
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
txn_id, txn_id,
id_access_token=None, id_access_token=None,
): ):
room_state = yield self.state_handler.get_current_state(room_id) room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = "" inviter_display_name = ""
inviter_avatar_url = "" inviter_avatar_url = ""
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
public_keys, public_keys,
fallback_public_key, fallback_public_key,
display_name, display_name,
) = yield self.identity_handler.ask_id_server_for_third_party_invite( ) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester, requester=requester,
id_server=id_server, id_server=id_server,
medium=medium, medium=medium,
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
yield self.event_creation_handler.create_and_send_nonmember_event( await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity return complexity["v1"] > max_complexity
@defer.inlineCallbacks async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join
""" """
# filter ourselves out of remote_room_hosts: do_invite_join ignores it # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled: if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity # Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex( too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts room_id, remote_room_hosts
) )
if too_complex is True: if too_complex is True:
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking # join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
yield defer.ensureDeferred( await self.federation_handler.do_invite_join(
self.federation_handler.do_invite_join( remote_room_hosts, room_id, user.to_string(), content
remote_room_hosts, room_id, user.to_string(), content
)
) )
yield self._user_joined_room(user, room_id) await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the # Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before. # complexity of it before.
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return return
# Check again, but with the local state events # Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id) too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False: if too_complex is False:
# We're under the limit. # We're under the limit.
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave. # The room is too large. Leave.
requester = types.create_requester(user, None, False, None) requester = types.create_requester(user, None, False, None)
yield self.update_membership( await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave" requester=requester, target=user, room_id=room_id, action="leave"
) )
raise SynapseError( raise SynapseError(
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id): def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room """Implements RoomMemberHandler._user_joined_room
""" """
return user_joined_room(self.distributor, target, room_id) return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id): def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room """Implements RoomMemberHandler._user_left_room
""" """
return user_left_room(self.distributor, target, room_id) return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def forget(self, user, room_id): def forget(self, user, room_id):

View File

@ -27,6 +27,7 @@ import inspect
import logging import logging
import threading import threading
import types import types
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal from typing_extensions import Literal
@ -287,6 +288,46 @@ class LoggingContext(object):
return str(self.request) return str(self.request)
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
This exists for backwards compatibility. ``current_context()`` should be
called directly.
Returns:
LoggingContext: the current logging context
"""
warnings.warn(
"synapse.logging.context.LoggingContext.current_context() is deprecated "
"in favor of synapse.logging.context.current_context().",
DeprecationWarning,
stacklevel=2,
)
return current_context()
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
This exists for backwards compatibility. ``set_current_context()`` should be
called directly.
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
warnings.warn(
"synapse.logging.context.LoggingContext.set_current_context() is deprecated "
"in favor of synapse.logging.context.set_current_context().",
DeprecationWarning,
stacklevel=2,
)
return set_current_context(context)
def __enter__(self) -> "LoggingContext": def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
old_context = set_current_context(self) old_context = set_current_context(self)

View File

@ -273,10 +273,9 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id] "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
) )
@defer.inlineCallbacks async def _notify_app_services(self, room_stream_id):
def _notify_app_services(self, room_stream_id):
try: try:
yield self.appservice_handler.notify_interested_services(room_stream_id) await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
@ -475,20 +474,18 @@ class Notifier(object):
return result return result
@defer.inlineCallbacks async def _get_room_ids(self, user, explicit_room_id):
def _get_room_ids(self, user, explicit_room_id): joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id: if explicit_room_id:
if explicit_room_id in joined_room_ids: if explicit_room_id in joined_room_ids:
return [explicit_room_id], True return [explicit_room_id], True
if (yield self._is_world_readable(explicit_room_id)): if await self._is_world_readable(explicit_room_id):
return [explicit_room_id], False return [explicit_room_id], False
raise AuthError(403, "Non-joined access not allowed") raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True return joined_room_ids, True
@defer.inlineCallbacks async def _is_world_readable(self, room_id):
def _is_world_readable(self, room_id): state = await self.state_handler.get_current_state(
state = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, "" room_id, EventTypes.RoomHistoryVisibility, ""
) )
if state and "history_visibility" in state.content: if state and "history_visibility" in state.content:

View File

@ -16,6 +16,7 @@
import abc import abc
import logging import logging
import re import re
from inspect import signature
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from six import raise_from from six import raise_from
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server. must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`. Requests can be sent by calling the client returned by `make_client`.
Requests are sent to master process by default, but can be sent to other
named processes by specifying an `instance_name` keyword argument.
Attributes: Attributes:
NAME (str): A name for the endpoint, added to the path as well as used NAME (str): A name for the endpoint, added to the path as well as used
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) )
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved paramater name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved paramater name"
assert self.METHOD in ("PUT", "POST", "GET") assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod @abc.abstractmethod
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request") @trace(opname="outgoing_replication_request")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_request(**kwargs): def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")
data = yield cls._serialize_payload(**kwargs) data = yield cls._serialize_payload(**kwargs)
url_args = [ url_args = [

View File

@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs): def __init__(self, hs):
super().__init__(hs) super().__init__(hs)
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication steamer (if we try and make # We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop). # them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams() self.streams = hs.get_replication_streamer().get_streams()
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
upto_token = parse_integer(request, "upto_token", required=True) upto_token = parse_integer(request, "upto_token", required=True)
updates, upto_token, limited = await stream.get_updates_since( updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token self._instance_name, from_token, upto_token
) )
return ( return (

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Optional from typing import Optional
import six import six
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs self.hs = hs
def stream_positions(self) -> Dict[str, int]:
"""
Get the current positions of all the streams this store wants to subscribe to
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self): def get_cache_stream_token(self):
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token() return self._cache_id_gen.get_current_token()

View File

@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data": if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token) self._account_data_id_gen.advance(token)

View File

@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000, expiry_ms=30 * 60 * 1000,
) )
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device": if stream_name == "to_device":
self._device_inbox_id_gen.advance(token) self._device_inbox_id_gen.advance(token)

View File

@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max "DeviceListFederationStreamChangeCache", device_list_max
) )
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
# The user signature stream uses the same stream ID generator as the
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)

View File

@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token() return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events": if stream_name == "events":
self._stream_id_gen.advance(token) self._stream_id_gen.advance(token)

View File

@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self): def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token() return self._group_updates_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups": if stream_name == "groups":
self._group_updates_id_gen.advance(token) self._group_updates_id_gen.advance(token)

View File

@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence": if stream_name == "presence":
self._presence_id_gen.advance(token) self._presence_id_gen.advance(token)

View File

@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self): def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token() return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules": if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token) self._push_rules_stream_id_gen.advance(token)

View File

@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
) )
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self): def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()

View File

@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type)) self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,)) self._get_linearized_receipts_for_room.invalidate_many((room_id,))

View File

@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms": if stream_name == "public_rooms":
self._public_room_id_gen.advance(token) self._public_room_id_gen.advance(token)

View File

@ -16,7 +16,7 @@
""" """
import logging import logging
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
@ -86,37 +86,22 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore): def __init__(self, store: BaseSlavedStore):
self.store = store self.store = store
async def on_rdata(self, stream_name: str, token: int, rows: list): async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to By default this just pokes the slave store. Can be overridden in subclasses to
handle more. handle more.
Args: Args:
stream_name (str): name of the replication stream for this batch of rows stream_name: name of the replication stream for this batch of rows
token (int): stream token for this batch of rows instance_name: the instance that wrote the rows.
rows (list): a list of Stream.ROW_TYPE objects as returned by token: stream token for this batch of rows
Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
self.store.process_replication_rows(stream_name, token, rows) self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int): async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, []) self.store.process_replication_rows(stream_name, token, [])

View File

@ -281,19 +281,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, []) rows = self._pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows) await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list): async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token. """Called to handle a batch of replication data with a given stream token.
Args: Args:
stream_name: name of the replication stream for this batch of rows stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row. Stream.parse_row.
""" """
logger.debug("Received rdata %s -> %s", stream_name, token) logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows) await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand): async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name: if cmd.instance_name == self._instance_name:
@ -321,14 +326,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, []) self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to. # Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get( current_token = stream.current_token()
stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", stream_name,
)
return
# If the position token matches our current token then we're up to # If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates # date and there's nothing to do. Otherwise, fetch all updates
@ -345,7 +343,9 @@ class ReplicationCommandHandler:
updates, updates,
current_token, current_token,
missing_updates, missing_updates,
) = await stream.get_updates_since(current_token, cmd.token) ) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
# TODO: add some tests for this # TODO: add some tests for this
@ -354,7 +354,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates): for token, rows in _batch_updates(updates):
await self.on_rdata( await self.on_rdata(
stream_name, token, [stream.parse_row(row) for row in rows], stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
) )
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token) logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)

View File

@ -68,6 +68,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionMade(self): def connectionMade(self):
logger.info("Connected to redis") logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe) run_as_background_process("subscribe-replication", self._send_subscribe)
self.handler.new_connection(self) self.handler.new_connection(self)
@ -136,6 +137,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason): def connectionLost(self, reason):
logger.info("Lost connection to redis") logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self) self.handler.lost_connection(self)
def send_command(self, cmd: Command): def send_command(self, cmd: Command):
@ -203,5 +205,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.handler = self.handler p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name p.stream_name = self.stream_name
p.password = self.password
return p return p

View File

@ -16,7 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr import attr
@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# #
# The arguments are: # The arguments are:
# #
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the # * from_token: the previous stream token: the starting point for fetching the
# updates # updates
# * to_token: the new stream token: the point to get updates up to # * to_token: the new stream token: the point to get updates up to
@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and # If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch. # it will be called again to get the next batch.
# #
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]] UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object): class Stream(object):
@ -93,6 +94,7 @@ class Stream(object):
def __init__( def __init__(
self, self,
local_instance_name: str,
current_token_function: Callable[[], Token], current_token_function: Callable[[], Token],
update_function: UpdateFunction, update_function: UpdateFunction,
): ):
@ -108,9 +110,11 @@ class Stream(object):
stream tokens. See the UpdateFunction type definition for more info. stream tokens. See the UpdateFunction type definition for more info.
Args: Args:
local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above update_function: callback go get stream updates, as above
""" """
self.local_instance_name = local_instance_name
self.current_token = current_token_function self.current_token = current_token_function
self.update_function = update_function self.update_function = update_function
@ -135,14 +139,14 @@ class Stream(object):
""" """
current_token = self.current_token() current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since( updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token self.local_instance_name, self.last_token, current_token
) )
self.last_token = current_token self.last_token = current_token
return updates, current_token, limited return updates, current_token, limited
async def get_updates_since( async def get_updates_since(
self, from_token: Token, upto_token: Token self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult: ) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should """Like get_updates except allows specifying from when we should
stream updates stream updates
@ -160,19 +164,19 @@ class Stream(object):
return [], upto_token, False return [], upto_token, False
updates, upto_token, limited = await self.update_function( updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT, instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
) )
return updates, upto_token, limited return updates, upto_token, limited
def db_query_to_update_function( def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]] query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction: ) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it """Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class suitable for use as an `update_function` for the Stream class
""" """
async def update_function(from_token, upto_token, limit): async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit) rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows] updates = [(row[0], row[1:]) for row in rows]
limited = False limited = False
@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs) client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function( async def update_function(
from_token: int, upto_token: int, limit: int instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult: ) -> StreamUpdateResult:
result = await client( result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token, instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
) )
return result["updates"], result["upto_token"], result["limited"] return result["updates"], result["upto_token"], result["limited"]
@ -226,6 +233,7 @@ class BackfillStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token, store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows), db_query_to_update_function(store.get_all_new_backfill_event_rows),
) )
@ -261,7 +269,9 @@ class PresenceStream(Stream):
# Query master process # Query master process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__(store.get_current_presence_token, update_function) super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
)
class TypingStream(Stream): class TypingStream(Stream):
@ -284,7 +294,9 @@ class TypingStream(Stream):
# Query master process # Query master process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__(typing_handler.get_current_token, update_function) super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
)
class ReceiptsStream(Stream): class ReceiptsStream(Stream):
@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id, store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts), db_query_to_update_function(store.get_all_updated_receipts),
) )
@ -322,14 +335,16 @@ class PushRulesStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
super(PushRulesStream, self).__init__( super(PushRulesStream, self).__init__(
self._current_token, self._update_function hs.get_instance_name(), self._current_token, self._update_function
) )
def _current_token(self) -> int: def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token() push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token return push_rules_token
async def _update_function(self, from_token: Token, to_token: Token, limit: int): async def _update_function(
self, instance_name: str, from_token: Token, to_token: Token, limit: int
):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False limited = False
@ -356,6 +371,7 @@ class PushersStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token, store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows), db_query_to_update_function(store.get_all_updated_pushers_rows),
) )
@ -387,6 +403,7 @@ class CachesStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token, store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches), db_query_to_update_function(store.get_all_updated_caches),
) )
@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id, store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms), db_query_to_update_function(store.get_all_new_public_rooms),
) )
@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_device_stream_token, store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes), db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
) )
@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token, store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages), db_query_to_update_function(store.get_all_new_device_messages),
) )
@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id, store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags), db_query_to_update_function(store.get_all_updated_tags),
) )
@ -487,6 +508,7 @@ class AccountDataStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id, self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function), db_query_to_update_function(self._update_function),
) )
@ -517,6 +539,7 @@ class GroupServerStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_group_stream_token, store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes), db_query_to_update_function(store.get_all_groups_changes),
) )
@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(),
store.get_device_stream_token, store.get_device_stream_token,
db_query_to_update_function( db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes store.get_all_user_signature_changes_for_remotes

View File

@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self._store = hs.get_datastore() self._store = hs.get_datastore()
super().__init__( super().__init__(
self._store.get_current_events_token, self._update_function, hs.get_instance_name(),
self._store.get_current_events_token,
self._update_function,
) )
async def _update_function( async def _update_function(
self, from_token: Token, current_token: Token, target_row_count: int self,
instance_name: str,
from_token: Token,
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult: ) -> StreamUpdateResult:
# the events stream merges together three separate sources: # the events stream merges together three separate sources:

View File

@ -48,8 +48,8 @@ class FederationStream(Stream):
current_token = lambda: 0 current_token = lambda: 0
update_function = self._stub_update_function update_function = self._stub_update_function
super().__init__(current_token, update_function) super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod @staticmethod
async def _stub_update_function(from_token, upto_token, limit): async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False return [], upto_token, False

View File

@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError from synapse.config import ConfigError
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config) self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks async def maybe_send_server_notice_to_user(self, user_id):
def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so """Check if we need to send a notice to this user, and does so if so
Args: Args:
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return return
self._users_in_progress.add(user_id) self._users_in_progress.add(user_id)
try: try:
u = yield self._store.get_user_by_id(user_id) u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests: if u["is_guest"] and not self._send_to_guests:
# don't send to guests # don't send to guests
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst( content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri} self._server_notice_content, {"consent_uri": consent_uri}
) )
yield self._server_notices_manager.send_notice(user_id, content) await self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent( await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version user_id, self._current_consent_version
) )
except SynapseError as e: except SynapseError as e:

View File

@ -16,8 +16,6 @@ import logging
from six import iteritems from six import iteritems
from twisted.internet import defer
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
LimitBlockingTypes, LimitBlockingTypes,
@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
@defer.inlineCallbacks async def maybe_send_server_notice_to_user(self, user_id):
def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in """Check if we need to send a notice to this user, this will be true in
two cases. two cases.
1. The server has reached its limit does not reflect this 1. The server has reached its limit does not reflect this
@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled # Don't try and send server notices unless they've been enabled
return return
timestamp = yield self._store.user_last_seen_monthly_active(user_id) timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None: if timestamp is None:
# This user will be blocked from receiving the notice anyway. # This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here # In practice, not sure we can ever get here
return return
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user( room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id user_id
) )
@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room") logger.warning("Failed to get server notices room")
return return
yield self._check_and_set_tags(user_id, room_id) await self._check_and_set_tags(user_id, room_id)
# Determine current state of room # Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id) currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None limit_msg = None
limit_type = None limit_type = None
@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking # Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen # if you have it, but in this case are checking what would happen
# to other users if they were to arrive. # to other users if they were to arrive.
yield self._auth.check_auth_blocking() await self._auth.check_auth_blocking()
except ResourceLimitError as e: except ResourceLimitError as e:
limit_msg = e.msg limit_msg = e.msg
limit_type = e.limit_type limit_type = e.limit_type
@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled: # We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return # reset room if necessary and return
if currently_blocked: if currently_blocked:
self._remove_limit_block_notification(user_id, ref_events) await self._remove_limit_block_notification(user_id, ref_events)
return return
if currently_blocked and not limit_msg: if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be. # Room is notifying of a block, when it ought not to be.
yield self._remove_limit_block_notification(user_id, ref_events) await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg: elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be. # Room is not notifying of a block, when it ought to be.
yield self._apply_limit_block_notification( await self._apply_limit_block_notification(
user_id, limit_msg, limit_type user_id, limit_msg, limit_type
) )
except SynapseError as e: except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e) logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks async def _remove_limit_block_notification(self, user_id, ref_events):
def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server """Utility method to remove limit block notifications from the server
notices room. notices room.
@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved. limit blocking and need to be preserved.
""" """
content = {"pinned": ref_events} content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice( await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, "" user_id, content, EventTypes.Pinned, ""
) )
@defer.inlineCallbacks async def _apply_limit_block_notification(
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type): self, user_id, event_body, event_limit_type
):
"""Utility method to apply limit block notifications in the server """Utility method to apply limit block notifications in the server
notices room. notices room.
@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact, "admin_contact": self._config.admin_contact,
"limit_type": event_limit_type, "limit_type": event_limit_type,
} }
event = yield self._server_notices_manager.send_notice( event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message user_id, content, EventTypes.Message
) )
content = {"pinned": [event.event_id]} content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice( await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, "" user_id, content, EventTypes.Pinned, ""
) )
@defer.inlineCallbacks async def _check_and_set_tags(self, user_id, room_id):
def _check_and_set_tags(self, user_id, room_id):
""" """
Since server notices rooms were originally not with tags, Since server notices rooms were originally not with tags,
important to check that tags have been set correctly important to check that tags have been set correctly
@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object):
user_id(str): the user in question user_id(str): the user in question
room_id(str): the server notices room for that user room_id(str): the server notices room for that user
""" """
tags = yield self._store.get_tags_for_room(user_id, room_id) tags = await self._store.get_tags_for_room(user_id, room_id)
need_to_set_tag = True need_to_set_tag = True
if tags: if tags:
if SERVER_NOTICE_ROOM_TAG in tags: if SERVER_NOTICE_ROOM_TAG in tags:
# tag already present, nothing to do here # tag already present, nothing to do here
need_to_set_tag = False need_to_set_tag = False
if need_to_set_tag: if need_to_set_tag:
max_id = yield self._store.add_tag_to_room( max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
) )
self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@defer.inlineCallbacks async def _is_room_currently_blocked(self, room_id):
def _is_room_currently_blocked(self, room_id):
""" """
Determines if the room is currently blocked Determines if the room is currently blocked
@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room room_id(str): The room id of the server notices room
Returns: Returns:
Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block This list can be used as a convenience in the case where the block
@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object):
currently_blocked = False currently_blocked = False
pinned_state_event = None pinned_state_event = None
try: try:
pinned_state_event = yield self._state.get_current_state( pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned room_id, event_type=EventTypes.Pinned
) )
except AuthError: except AuthError:
@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object):
if pinned_state_event is not None: if pinned_state_event is not None:
referenced_events = list(pinned_state_event.content.get("pinned", [])) referenced_events = list(pinned_state_event.content.get("pinned", []))
events = yield self._store.get_events(referenced_events) events = await self._store.get_events(referenced_events)
for event_id, event in iteritems(events): for event_id, event in iteritems(events):
if event.type != EventTypes.Message: if event.type != EventTypes.Message:
continue continue

View File

@ -14,11 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
""" """
return self._config.server_notices_mxid is not None return self._config.server_notices_mxid is not None
@defer.inlineCallbacks async def send_notice(
def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None self, user_id, event_content, type=EventTypes.Message, state_key=None
): ):
"""Send a notice to the given user """Send a notice to the given user
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns: Returns:
Deferred[FrozenEvent] Deferred[FrozenEvent]
""" """
room_id = yield self.get_or_create_notice_room_for_user(user_id) room_id = await self.get_or_create_notice_room_for_user(user_id)
yield self.maybe_invite_user_to_room(user_id, room_id) await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid) requester = create_requester(system_mxid)
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None: if state_key is not None:
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event( res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False requester, event_dict, ratelimit=False
) )
return res return res
@cachedInlineCallbacks() @cached()
def get_or_create_notice_room_for_user(self, user_id): async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user """Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't If we have not yet created a notice room for this user, create it, but don't
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users" assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_local_user_where_membership_is( rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN] user_id, [Membership.INVITE, Membership.JOIN]
) )
for room in rooms: for room in rooms:
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow # be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it # manages to invite the system user to a room, that doesn't make it
# the server notices room. # the server notices room.
user_ids = yield self._store.get_users_in_room(room.room_id) user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids: if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice # we found a room which our user shares with the system notice
# user # user
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
} }
requester = create_requester(self.server_notices_mxid) requester = create_requester(self.server_notices_mxid)
info = yield self._room_creation_handler.create_room( info = await self._room_creation_handler.create_room(
requester, requester,
config={ config={
"preset": RoomCreationPreset.PRIVATE_CHAT, "preset": RoomCreationPreset.PRIVATE_CHAT,
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
) )
room_id = info["room_id"] room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room( max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {} user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
) )
self._notifier.on_new_event("account_data_key", max_id, users=[user_id]) self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id) logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id return room_id
@defer.inlineCallbacks async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already """Invite the given user to the given server room, unless the user has already
joined or been invited to it. joined or been invited to it.
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If # Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them. # that's the case, there is no need to re-invite them.
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is( joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN] user_id, [Membership.INVITE, Membership.JOIN]
) )
for room in joined_rooms: for room in joined_rooms:
if room.room_id == room_id: if room.room_id == room_id:
return return
yield self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
requester=requester, requester=requester,
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,

View File

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.server_notices.consent_server_notices import ConsentServerNotices from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import ( from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices, ResourceLimitsServerNotices,
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs), ResourceLimitsServerNotices(hs),
) )
@defer.inlineCallbacks async def on_user_syncing(self, user_id):
def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation. """Called when the user performs a sync operation.
Args: Args:
user_id (str): mxid of user who synced user_id (str): mxid of user who synced
""" """
for sn in self._server_notices: for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id) await sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks async def on_user_ip(self, user_id):
def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request. """Called on the master when a worker process saw a client request.
Args: Args:
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as # we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing # in on_user_syncing
for sn in self._server_notices: for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id) await sn.maybe_send_server_notice_to_user(user_id)

View File

@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user", desc="delete_account_validity_for_user",
) )
@defer.inlineCallbacks async def is_server_admin(self, user):
def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver. """Determines if a user is an admin of this homeserver.
Args: Args:
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool): Returns (bool):
true iff the user is a server admin, false otherwise. true iff the user is a server admin, false otherwise.
""" """
res = yield self.db.simple_select_one_onecol( res = await self.db.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user.to_string()}, keyvalues={"name": user.to_string()},
retcol="admin", retcol="admin",

View File

@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
'populate_stats_cleanup' 'populate_stats_cleanup'
); );
-- this relies on current_state_events.membership having been populated, so add
-- a dependency on current_state_events_membership.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_rooms', '{}', ''); ('populate_stats_process_rooms', '{}', 'current_state_events_membership');
-- this also relies on current_state_events.membership having been populated, but
-- we get that as a side-effect of depending on populate_stats_process_rooms.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms'); ('populate_stats_process_users', '{}', 'populate_stats_process_rooms');

View File

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
from tests.test_utils.event_injection import create_event
class TestEventContext(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)
def test_serialize_deserialize_msg(self):
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
event, context = create_event(
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_no_prev(self):
"""Test that an EventContext for a state event (with not previous entry)
is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
state_key="",
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_prev(self):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.room.member",
sender=self.user_id,
state_key=self.user_id,
content={"membership": "leave"},
)
self._check_serialize_deserialize(event, context)
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self.storage, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event
)
self.assertEqual(context.prev_group, d_context.prev_group)
self.assertEqual(context.delta_ids, d_context.delta_ids)
self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual(
self.get_success(context.get_current_state_ids()),
self.get_success(d_context.get_current_state_ids()),
)
self.assertEqual(
self.get_success(context.get_prev_state_ids()),
self.get_success(d_context.get_prev_state_ids()),
)

View File

@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname( yield defer.ensureDeferred(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
) )
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), (
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank Jr.", "Frank Jr.",
) )
# Set displayname again # Set displayname again
yield self.handler.set_displayname( yield defer.ensureDeferred(
self.frank, synapse.types.create_requester(self.frank), "Frank" self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
) )
self.assertEquals( self.assertEquals(
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
) )
# Setting displayname a second time is forbidden # Setting displayname a second time is forbidden
d = self.handler.set_displayname( d = defer.ensureDeferred(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
) )
yield self.assertFailure(d, SynapseError) yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname( d = defer.ensureDeferred(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr." self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
)
) )
yield self.assertFailure(d, AuthError) yield self.assertFailure(d, AuthError)
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield defer.ensureDeferred(
self.frank, self.handler.set_avatar_url(
synapse.types.create_requester(self.frank), self.frank,
"http://my.server/pic.gif", synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
) )
self.assertEquals( self.assertEquals(
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set avatar again # Set avatar again
yield self.handler.set_avatar_url( yield defer.ensureDeferred(
self.frank, self.handler.set_avatar_url(
synapse.types.create_requester(self.frank), self.frank,
"http://my.server/me.png", synapse.types.create_requester(self.frank),
"http://my.server/me.png",
)
) )
self.assertEquals( self.assertEquals(
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set avatar a second time is forbidden # Set avatar a second time is forbidden
d = self.handler.set_avatar_url( d = defer.ensureDeferred(
self.frank, self.handler.set_avatar_url(
synapse.types.create_requester(self.frank), self.frank,
"http://my.server/pic.gif", synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
) )
yield self.assertFailure(d, SynapseError) yield self.assertFailure(d, SynapseError)

View File

@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=False) self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support")) user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=1) self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=True) self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=2) self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=True) self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError self.handler.register_user(localpart=invalid_user_id), SynapseError
) )
@defer.inlineCallbacks async def get_or_create_user(
def get_or_create_user(self, requester, localpart, displayname, password_hash=None): self, requester, localpart, displayname, password_hash=None
):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking() await self.hs.get_auth().check_auth_blocking()
need_register = True need_register = True
try: try:
yield self.handler.check_username(localpart) await self.handler.check_username(localpart)
except SynapseError as e: except SynapseError as e:
if e.errcode == Codes.USER_IN_USE: if e.errcode == Codes.USER_IN_USE:
need_register = False need_register = False
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id) token = self.macaroon_generator.generate_access_token(user_id)
if need_register: if need_register:
yield self.handler.register_with_store( await self.handler.register_with_store(
user_id=user_id, user_id=user_id,
password_hash=password_hash, password_hash=password_hash,
create_profile_with_displayname=user.localpart, create_profile_with_displayname=user.localpart,
) )
else: else:
yield defer.ensureDeferred( await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
yield self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None user_id=user_id, token=token, device_id=None, valid_until_ms=None
) )
if displayname is not None: if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname) # logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname( await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True user, requester, displayname, by_admin=True
) )

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import attr import attr
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None self._server_transport = None
def _build_replication_data_handler(self): def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore()) return TestReplicationDataHandler(self.worker_hs)
def reconnect(self): def reconnect(self):
if self._client_transport: if self._client_transport:
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler): class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows""" """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, store: BaseSlavedStore): def __init__(self, hs: HomeServer):
super().__init__(store) super().__init__(hs)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples # list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self): async def on_rdata(self, stream_name, instance_name, token, rows):
return self.stream_positions await super().on_rdata(stream_name, instance_name, token, rows)
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows: for r in rows:
self.received_rdata_rows.append((stream_name, token, r)) self.received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s() @attr.s()
class OneShotRequestFactory: class OneShotRequestFactory:

View File

@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.user_tok = self.login("u1", "pass") self.user_tok = self.login("u1", "pass")
self.reconnect() self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok) self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear() self.test_handler.received_rdata_rows.clear()
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# we should have received all the expected rows in the right order # we should have received all the expected rows in the right order (as
received_rows = self.test_handler.received_rdata_rows # well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
for event in events: for event in events:
stream_name, token, row = received_rows.pop(0) stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name) self.assertEqual("events", stream_name)
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# now we should have received all the expected rows in the right order. # we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
# #
# we expect: # we expect:
# #
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted. # of the states that got reverted.
# - two rows for state2 # - two rows for state2
received_rows = self.test_handler.received_rdata_rows received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
# first check the first two rows, which should be state1 # first check the first two rows, which should be state1
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
self.replicate() self.replicate()
# we should have received all the expected rows in the right order # we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = self.test_handler.received_rdata_rows received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertGreaterEqual(len(received_rows), len(events)) self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS): for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for # for each user, we expect the PL event row, followed by state rows for

View File

@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self): def test_receipt(self):
self.reconnect() self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt # tell the master to send a new receipt
self.get_success( self.get_success(
self.hs.get_datastore().insert_receipt( self.hs.get_datastore().insert_receipt(
@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# there should be one RDATA command # there should be one RDATA command
self.test_handler.on_rdata.assert_called_once() self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts") self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# We should now have caught up and get the missing data # We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once() self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts") self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3) self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))

View File

@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect() self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0) self.reactor.advance(0)
@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing") self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once() self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow row = rdata_rows[0] # type: TypingStream.TypingStreamRow
@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(int(request.args[b"from_token"][0]), token) self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once() self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] row = rdata_rows[0]

View File

@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000) return_value=defer.succeed(1000)
) )
self._send_notice = self._rlsn._server_notices_manager.send_notice self._rlsn._server_notices_manager.send_notice = Mock(
self._rlsn._server_notices_manager.send_notice = Mock() return_value=defer.succeed(Mock())
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None)) )
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
self._send_notice = self._rlsn._server_notices_manager.send_notice self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test" self.user_id = "@user_id:test"
# self.server_notices_mxid = "@server:test"
# self.server_notices_mxid_display_name = None
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue="" return_value=defer.succeed("!something:localhost")
) )
self._rlsn._store.add_tag_to_room = Mock() self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value={}) self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
self.hs.config.admin_contact = "mailto:user@test.com" self.hs.config.admin_contact = "mailto:user@test.com"
def test_maybe_send_server_notice_to_user_flag_off(self): def test_maybe_send_server_notice_to_user_flag_off(self):
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed""" """Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock( mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
) )
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event # Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once() self._send_notice.assert_called_once()
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP) Test when user has blocked notice, but notice ought to be there (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo") return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
) )
mock_event = Mock( mock_event = Mock(
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock( self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event}) return_value=defer.succeed({"123": mock_event})
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called() self._send_notice.assert_not_called()
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, but should have one Test when user does not have blocked notice, but should have one
""" """
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo") return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
Test when user does not have blocked notice, nor should they (NOOP) Test when user does not have blocked notice, nor should they (NOOP)
""" """
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever Test when user is not part of the MAU cohort - this should not ever
happen - but ... happen - but ...
""" """
self._rlsn._auth.check_auth_blocking = Mock() self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock( self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None) return_value=defer.succeed(None)
) )
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room an alert message is not sent into the room
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
) ),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self.assertTrue(self._send_notice.call_count == 0) self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self): def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
""" """
Test that when a server is disabled, that MAU limit alerting is ignored. Test that when a server is disabled, that MAU limit alerting is ignored.
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
) ),
) )
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
""" """
self.hs.config.mau_limit_alerting = False self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock( self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError( side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
) ),
) )
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, [])) return_value=defer.succeed((True, []))
) )
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def test_server_notice_only_sent_once(self): def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000) self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000) self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
# Call the function multiple times to ensure we only send the notice once # Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View File

@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None) our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room( room = ensureDeferred(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
) )
self.reactor.advance(0.1) self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"] self.room_id = self.successResultOf(room)["room_id"]

View File

@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional, Tuple
import synapse.server import synapse.server
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection from synapse.types import Collection
from tests.test_utils import get_awaitable_result from tests.test_utils import get_awaitable_result
@ -75,6 +76,23 @@ def inject_event(
""" """
test_reactor = hs.get_reactor() test_reactor = hs.get_reactor()
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
test_reactor = hs.get_reactor()
if room_version is None: if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"]) d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0) test_reactor.advance(0)
@ -89,8 +107,4 @@ def inject_event(
test_reactor.advance(0) test_reactor.advance(0)
event, context = get_awaitable_result(d) event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context) return event, context
test_reactor.advance(0)
get_awaitable_result(d)
return event