mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Merge branch 'release-v1.13.0' into erikj/faster_device_lists_fetch
This commit is contained in:
commit
13dd458b8d
11
MANIFEST.in
11
MANIFEST.in
@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
|
|||||||
recursive-include synapse/static *.html
|
recursive-include synapse/static *.html
|
||||||
recursive-include synapse/static *.js
|
recursive-include synapse/static *.js
|
||||||
|
|
||||||
exclude Dockerfile
|
exclude .codecov.yml
|
||||||
|
exclude .coveragerc
|
||||||
exclude .dockerignore
|
exclude .dockerignore
|
||||||
exclude test_postgresql.sh
|
|
||||||
exclude .editorconfig
|
exclude .editorconfig
|
||||||
|
exclude Dockerfile
|
||||||
|
exclude mypy.ini
|
||||||
exclude sytest-blacklist
|
exclude sytest-blacklist
|
||||||
|
exclude test_postgresql.sh
|
||||||
|
|
||||||
include pyproject.toml
|
include pyproject.toml
|
||||||
recursive-include changelog.d *
|
recursive-include changelog.d *
|
||||||
|
|
||||||
prune .buildkite
|
prune .buildkite
|
||||||
prune .circleci
|
prune .circleci
|
||||||
prune .codecov.yml
|
|
||||||
prune .coveragerc
|
|
||||||
prune .github
|
prune .github
|
||||||
|
prune contrib
|
||||||
prune debian
|
prune debian
|
||||||
prune demo/etc
|
prune demo/etc
|
||||||
prune docker
|
prune docker
|
||||||
prune mypy.ini
|
|
||||||
prune snap
|
prune snap
|
||||||
prune stubs
|
prune stubs
|
||||||
|
31
UPGRADE.rst
31
UPGRADE.rst
@ -75,6 +75,37 @@ for example:
|
|||||||
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
|
||||||
|
|
||||||
|
Upgrading to v1.13.0
|
||||||
|
====================
|
||||||
|
|
||||||
|
Incorrect database migration in old synapse versions
|
||||||
|
----------------------------------------------------
|
||||||
|
|
||||||
|
A bug was introduced in Synapse 1.4.0 which could cause the room directory to
|
||||||
|
be incomplete or empty if Synapse was upgraded directly from v1.2.1 or earlier,
|
||||||
|
to versions between v1.4.0 and v1.12.x.
|
||||||
|
|
||||||
|
This will *not* be a problem for Synapse installations which were:
|
||||||
|
* created at v1.4.0 or later,
|
||||||
|
* upgraded via v1.3.x, or
|
||||||
|
* upgraded straight from v1.2.1 or earlier to v1.13.0 or later.
|
||||||
|
|
||||||
|
If completeness of the room directory is a concern, installations which are
|
||||||
|
affected can be repaired as follows:
|
||||||
|
|
||||||
|
1. Run the following sql from a `psql` or `sqlite3` console:
|
||||||
|
|
||||||
|
.. code:: sql
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||||
|
('populate_stats_process_rooms', '{}', 'current_state_events_membership');
|
||||||
|
|
||||||
|
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||||
|
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
|
||||||
|
|
||||||
|
2. Restart synapse.
|
||||||
|
|
||||||
|
|
||||||
Upgrading to v1.12.0
|
Upgrading to v1.12.0
|
||||||
====================
|
====================
|
||||||
|
|
||||||
|
1
changelog.d/7172.misc
Normal file
1
changelog.d/7172.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Use `stream.current_token()` and remove `stream_positions()`.
|
1
changelog.d/7363.misc
Normal file
1
changelog.d/7363.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.
|
1
changelog.d/7368.bugfix
Normal file
1
changelog.d/7368.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve error responses when accessing remote public room lists.
|
1
changelog.d/7369.misc
Normal file
1
changelog.d/7369.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Thread through instance name to replication client.
|
1
changelog.d/7387.bugfix
Normal file
1
changelog.d/7387.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix a bug which would cause the room durectory to be incorrectly populated if Synapse was upgraded directly from v1.2.1 or earlier to v1.4.0 or later. Note that this fix does not apply retrospectively; see the [upgrade notes](UPGRADE.rst#upgrading-to-v1130) for more information.
|
1
changelog.d/7393.bugfix
Normal file
1
changelog.d/7393.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix bug in `EventContext.deserialize`.
|
1
changelog.d/7394.misc
Normal file
1
changelog.d/7394.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert synapse.server_notices to async/await.
|
1
changelog.d/7395.misc
Normal file
1
changelog.d/7395.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert synapse.notifier to async/await.
|
1
changelog.d/7401.feature
Normal file
1
changelog.d/7401.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for running replication over Redis when using workers.
|
1
changelog.d/7404.misc
Normal file
1
changelog.d/7404.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix issues with the Python package manifest.
|
1
changelog.d/7408.misc
Normal file
1
changelog.d/7408.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Clean up some LoggingContext code.
|
@ -22,7 +22,10 @@ class RedisProtocol:
|
|||||||
def publish(self, channel: str, message: bytes): ...
|
def publish(self, channel: str, message: bytes): ...
|
||||||
|
|
||||||
class SubscriberProtocol:
|
class SubscriberProtocol:
|
||||||
|
password: Optional[str]
|
||||||
def subscribe(self, channels: Union[str, List[str]]): ...
|
def subscribe(self, channels: Union[str, List[str]]): ...
|
||||||
|
def connectionMade(self): ...
|
||||||
|
def connectionLost(self, reason): ...
|
||||||
|
|
||||||
def lazyConnection(
|
def lazyConnection(
|
||||||
host: str = ...,
|
host: str = ...,
|
||||||
|
@ -537,8 +537,7 @@ class Auth(object):
|
|||||||
|
|
||||||
return defer.succeed(auth_ids)
|
return defer.succeed(auth_ids)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||||
def check_can_change_room_list(self, room_id: str, user: UserID):
|
|
||||||
"""Determine whether the user is allowed to edit the room's entry in the
|
"""Determine whether the user is allowed to edit the room's entry in the
|
||||||
published room list.
|
published room list.
|
||||||
|
|
||||||
@ -547,17 +546,17 @@ class Auth(object):
|
|||||||
user
|
user
|
||||||
"""
|
"""
|
||||||
|
|
||||||
is_admin = yield self.is_server_admin(user)
|
is_admin = await self.is_server_admin(user)
|
||||||
if is_admin:
|
if is_admin:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_in_room(room_id, user_id)
|
await self.check_user_in_room(room_id, user_id)
|
||||||
|
|
||||||
# We currently require the user is a "moderator" in the room. We do this
|
# We currently require the user is a "moderator" in the room. We do this
|
||||||
# by checking if they would (theoretically) be able to change the
|
# by checking if they would (theoretically) be able to change the
|
||||||
# m.room.canonical_alias events
|
# m.room.canonical_alias events
|
||||||
power_level_event = yield self.state.get_current_state(
|
power_level_event = await self.state.get_current_state(
|
||||||
room_id, EventTypes.PowerLevels, ""
|
room_id, EventTypes.PowerLevels, ""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -413,12 +413,6 @@ class GenericWorkerTyping(object):
|
|||||||
# map room IDs to sets of users currently typing
|
# map room IDs to sets of users currently typing
|
||||||
self._room_typing = {}
|
self._room_typing = {}
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
# We must update this typing token from the response of the previous
|
|
||||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
|
||||||
# value which we *must* use for the next replication request.
|
|
||||||
return {"typing": self._latest_room_serial}
|
|
||||||
|
|
||||||
def process_replication_rows(self, token, rows):
|
def process_replication_rows(self, token, rows):
|
||||||
if self._latest_room_serial > token:
|
if self._latest_room_serial > token:
|
||||||
# The master has gone backwards. To prevent inconsistent data, just
|
# The master has gone backwards. To prevent inconsistent data, just
|
||||||
@ -652,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||||||
else:
|
else:
|
||||||
self.send_handler = None
|
self.send_handler = None
|
||||||
|
|
||||||
async def on_rdata(self, stream_name, token, rows):
|
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||||
await super(GenericWorkerReplicationHandler, self).on_rdata(
|
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||||
stream_name, token, rows
|
await self._process_and_notify(stream_name, instance_name, token, rows)
|
||||||
)
|
|
||||||
await self.process_and_notify(stream_name, token, rows)
|
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
async def _process_and_notify(self, stream_name, instance_name, token, rows):
|
||||||
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
|
|
||||||
args.update(self.typing_handler.stream_positions())
|
|
||||||
if self.send_handler:
|
|
||||||
args.update(self.send_handler.stream_positions())
|
|
||||||
return args
|
|
||||||
|
|
||||||
async def process_and_notify(self, stream_name, token, rows):
|
|
||||||
try:
|
try:
|
||||||
if self.send_handler:
|
if self.send_handler:
|
||||||
await self.send_handler.process_replication_rows(
|
await self.send_handler.process_replication_rows(
|
||||||
@ -799,9 +784,6 @@ class FederationSenderHandler(object):
|
|||||||
def wake_destination(self, server: str):
|
def wake_destination(self, server: str):
|
||||||
self.federation_sender.wake_destination(server)
|
self.federation_sender.wake_destination(server)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
return {"federation": self.federation_position}
|
|
||||||
|
|
||||||
async def process_replication_rows(self, stream_name, token, rows):
|
async def process_replication_rows(self, stream_name, token, rows):
|
||||||
# The federation stream contains things that we want to send out, e.g.
|
# The federation stream contains things that we want to send out, e.g.
|
||||||
# presence, typing, etc.
|
# presence, typing, etc.
|
||||||
|
@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
|
|||||||
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
||||||
self.state_group
|
self.state_group
|
||||||
)
|
)
|
||||||
if self._prev_state_id and self._event_state_key is not None:
|
if self._event_state_key is not None:
|
||||||
self._prev_state_ids = dict(self._current_state_ids)
|
self._prev_state_ids = dict(self._current_state_ids)
|
||||||
|
|
||||||
key = (self._event_type, self._event_state_key)
|
key = (self._event_type, self._event_state_key)
|
||||||
self._prev_state_ids[key] = self._prev_state_id
|
if self._prev_state_id:
|
||||||
|
self._prev_state_ids[key] = self._prev_state_id
|
||||||
|
else:
|
||||||
|
self._prev_state_ids.pop(key, None)
|
||||||
else:
|
else:
|
||||||
self._prev_state_ids = self._current_state_ids
|
self._prev_state_ids = self._current_state_ids
|
||||||
|
|
||||||
|
@ -883,18 +883,37 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
def get_public_rooms(
|
def get_public_rooms(
|
||||||
self,
|
self,
|
||||||
destination,
|
remote_server: str,
|
||||||
limit=None,
|
limit: Optional[int] = None,
|
||||||
since_token=None,
|
since_token: Optional[str] = None,
|
||||||
search_filter=None,
|
search_filter: Optional[Dict] = None,
|
||||||
include_all_networks=False,
|
include_all_networks: bool = False,
|
||||||
third_party_instance_id=None,
|
third_party_instance_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if destination == self.server_name:
|
"""Get the list of public rooms from a remote homeserver
|
||||||
return
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote_server: The name of the remote server
|
||||||
|
limit: Maximum amount of rooms to return
|
||||||
|
since_token: Used for result pagination
|
||||||
|
search_filter: A filter dictionary to send the remote homeserver
|
||||||
|
and filter the result set
|
||||||
|
include_all_networks: Whether to include results from all third party instances
|
||||||
|
third_party_instance_id: Whether to only include results from a specific third
|
||||||
|
party instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[Dict[str, Any]]: The response from the remote server, or None if
|
||||||
|
`remote_server` is the same as the local server_name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HttpResponseException: There was an exception returned from the remote server
|
||||||
|
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
|
||||||
|
requests over federation
|
||||||
|
|
||||||
|
"""
|
||||||
return self.transport_layer.get_public_rooms(
|
return self.transport_layer.get_public_rooms(
|
||||||
destination,
|
remote_server,
|
||||||
limit,
|
limit,
|
||||||
since_token,
|
since_token,
|
||||||
search_filter,
|
search_filter,
|
||||||
@ -957,14 +976,13 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
return signed_events
|
return signed_events
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||||
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
if destination == self.server_name:
|
if destination == self.server_name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.transport_layer.exchange_third_party_invite(
|
await self.transport_layer.exchange_third_party_invite(
|
||||||
destination=destination, room_id=room_id, event_dict=event_dict
|
destination=destination, room_id=room_id, event_dict=event_dict
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
@ -15,13 +15,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from six.moves import urllib
|
from six.moves import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_UNSTABLE_PREFIX,
|
FEDERATION_UNSTABLE_PREFIX,
|
||||||
FEDERATION_V1_PREFIX,
|
FEDERATION_V1_PREFIX,
|
||||||
@ -326,18 +327,25 @@ class TransportLayerClient(object):
|
|||||||
@log_function
|
@log_function
|
||||||
def get_public_rooms(
|
def get_public_rooms(
|
||||||
self,
|
self,
|
||||||
remote_server,
|
remote_server: str,
|
||||||
limit,
|
limit: Optional[int] = None,
|
||||||
since_token,
|
since_token: Optional[str] = None,
|
||||||
search_filter=None,
|
search_filter: Optional[Dict] = None,
|
||||||
include_all_networks=False,
|
include_all_networks: bool = False,
|
||||||
third_party_instance_id=None,
|
third_party_instance_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
"""Get the list of public rooms from a remote homeserver
|
||||||
|
|
||||||
|
See synapse.federation.federation_client.FederationClient.get_public_rooms for
|
||||||
|
more information.
|
||||||
|
"""
|
||||||
if search_filter:
|
if search_filter:
|
||||||
# this uses MSC2197 (Search Filtering over Federation)
|
# this uses MSC2197 (Search Filtering over Federation)
|
||||||
path = _create_v1_path("/publicRooms")
|
path = _create_v1_path("/publicRooms")
|
||||||
|
|
||||||
data = {"include_all_networks": "true" if include_all_networks else "false"}
|
data = {
|
||||||
|
"include_all_networks": "true" if include_all_networks else "false"
|
||||||
|
} # type: Dict[str, Any]
|
||||||
if third_party_instance_id:
|
if third_party_instance_id:
|
||||||
data["third_party_instance_id"] = third_party_instance_id
|
data["third_party_instance_id"] = third_party_instance_id
|
||||||
if limit:
|
if limit:
|
||||||
@ -347,9 +355,19 @@ class TransportLayerClient(object):
|
|||||||
|
|
||||||
data["filter"] = search_filter
|
data["filter"] = search_filter
|
||||||
|
|
||||||
response = yield self.client.post_json(
|
try:
|
||||||
destination=remote_server, path=path, data=data, ignore_backoff=True
|
response = yield self.client.post_json(
|
||||||
)
|
destination=remote_server, path=path, data=data, ignore_backoff=True
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if e.code == 403:
|
||||||
|
raise SynapseError(
|
||||||
|
403,
|
||||||
|
"You are not allowed to view the public rooms list of %s"
|
||||||
|
% (remote_server,),
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
path = _create_v1_path("/publicRooms")
|
path = _create_v1_path("/publicRooms")
|
||||||
|
|
||||||
@ -363,9 +381,19 @@ class TransportLayerClient(object):
|
|||||||
if since_token:
|
if since_token:
|
||||||
args["since"] = [since_token]
|
args["since"] = [since_token]
|
||||||
|
|
||||||
response = yield self.client.get_json(
|
try:
|
||||||
destination=remote_server, path=path, args=args, ignore_backoff=True
|
response = yield self.client.get_json(
|
||||||
)
|
destination=remote_server, path=path, args=args, ignore_backoff=True
|
||||||
|
)
|
||||||
|
except HttpResponseException as e:
|
||||||
|
if e.code == 403:
|
||||||
|
raise SynapseError(
|
||||||
|
403,
|
||||||
|
"You are not allowed to view the public rooms list of %s"
|
||||||
|
% (remote_server,),
|
||||||
|
errcode=Codes.FORBIDDEN,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def remove_user_from_group(
|
||||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
self, group_id, user_id, requester_user_id, content
|
||||||
|
):
|
||||||
"""Remove a user from the group; either a user is leaving or an admin
|
"""Remove a user from the group; either a user is leaving or an admin
|
||||||
kicked them.
|
kicked them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
is_kick = False
|
is_kick = False
|
||||||
if requester_user_id != user_id:
|
if requester_user_id != user_id:
|
||||||
is_admin = yield self.store.is_user_admin_in_group(
|
is_admin = await self.store.is_user_admin_in_group(
|
||||||
group_id, requester_user_id
|
group_id, requester_user_id
|
||||||
)
|
)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
|
|
||||||
is_kick = True
|
is_kick = True
|
||||||
|
|
||||||
yield self.store.remove_user_from_group(group_id, user_id)
|
await self.store.remove_user_from_group(group_id, user_id)
|
||||||
|
|
||||||
if is_kick:
|
if is_kick:
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
groups_local = self.hs.get_groups_local_handler()
|
groups_local = self.hs.get_groups_local_handler()
|
||||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
await groups_local.user_removed_from_group(group_id, user_id, {})
|
||||||
else:
|
else:
|
||||||
yield self.transport_client.remove_user_from_group_notification(
|
await self.transport_client.remove_user_from_group_notification(
|
||||||
get_domain_from_id(user_id), group_id, user_id, {}
|
get_domain_from_id(user_id), group_id, user_id, {}
|
||||||
)
|
)
|
||||||
|
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
|
||||||
# Delete group if the last user has left
|
# Delete group if the last user has left
|
||||||
users = yield self.store.get_users_in_group(group_id, include_private=True)
|
users = await self.store.get_users_in_group(group_id, include_private=True)
|
||||||
if not users:
|
if not users:
|
||||||
yield self.store.delete_group(group_id)
|
await self.store.delete_group(group_id)
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_group(self, group_id, requester_user_id, content):
|
||||||
def create_group(self, group_id, requester_user_id, content):
|
group = await self.check_group_is_ours(group_id, requester_user_id)
|
||||||
group = yield self.check_group_is_ours(group_id, requester_user_id)
|
|
||||||
|
|
||||||
logger.info("Attempting to create group with ID: %r", group_id)
|
logger.info("Attempting to create group with ID: %r", group_id)
|
||||||
|
|
||||||
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
if group:
|
if group:
|
||||||
raise SynapseError(400, "Group already exists")
|
raise SynapseError(400, "Group already exists")
|
||||||
|
|
||||||
is_admin = yield self.auth.is_server_admin(
|
is_admin = await self.auth.is_server_admin(
|
||||||
UserID.from_string(requester_user_id)
|
UserID.from_string(requester_user_id)
|
||||||
)
|
)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
long_description = profile.get("long_description")
|
long_description = profile.get("long_description")
|
||||||
user_profile = content.get("user_profile", {})
|
user_profile = content.get("user_profile", {})
|
||||||
|
|
||||||
yield self.store.create_group(
|
await self.store.create_group(
|
||||||
group_id,
|
group_id,
|
||||||
requester_user_id,
|
requester_user_id,
|
||||||
name=name,
|
name=name,
|
||||||
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
if not self.hs.is_mine_id(requester_user_id):
|
if not self.hs.is_mine_id(requester_user_id):
|
||||||
remote_attestation = content["attestation"]
|
remote_attestation = content["attestation"]
|
||||||
|
|
||||||
yield self.attestations.verify_attestation(
|
await self.attestations.verify_attestation(
|
||||||
remote_attestation, user_id=requester_user_id, group_id=group_id
|
remote_attestation, user_id=requester_user_id, group_id=group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
local_attestation = None
|
local_attestation = None
|
||||||
remote_attestation = None
|
remote_attestation = None
|
||||||
|
|
||||||
yield self.store.add_user_to_group(
|
await self.store.add_user_to_group(
|
||||||
group_id,
|
group_id,
|
||||||
requester_user_id,
|
requester_user_id,
|
||||||
is_admin=True,
|
is_admin=True,
|
||||||
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not self.hs.is_mine_id(requester_user_id):
|
if not self.hs.is_mine_id(requester_user_id):
|
||||||
yield self.store.add_remote_profile_cache(
|
await self.store.add_remote_profile_cache(
|
||||||
requester_user_id,
|
requester_user_id,
|
||||||
displayname=user_profile.get("displayname"),
|
displayname=user_profile.get("displayname"),
|
||||||
avatar_url=user_profile.get("avatar_url"),
|
avatar_url=user_profile.get("avatar_url"),
|
||||||
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
|
|
||||||
return {"group_id": group_id}
|
return {"group_id": group_id}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def delete_group(self, group_id, requester_user_id):
|
||||||
def delete_group(self, group_id, requester_user_id):
|
|
||||||
"""Deletes a group, kicking out all current members.
|
"""Deletes a group, kicking out all current members.
|
||||||
|
|
||||||
Only group admins or server admins can call this request
|
Only group admins or server admins can call this request
|
||||||
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
Deferred
|
Deferred
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||||
|
|
||||||
# Only server admins or group admins can delete groups.
|
# Only server admins or group admins can delete groups.
|
||||||
|
|
||||||
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
|
is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
|
||||||
|
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
is_admin = yield self.auth.is_server_admin(
|
is_admin = await self.auth.is_server_admin(
|
||||||
UserID.from_string(requester_user_id)
|
UserID.from_string(requester_user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
raise SynapseError(403, "User is not an admin")
|
raise SynapseError(403, "User is not an admin")
|
||||||
|
|
||||||
# Before deleting the group lets kick everyone out of it
|
# Before deleting the group lets kick everyone out of it
|
||||||
users = yield self.store.get_users_in_group(group_id, include_private=True)
|
users = await self.store.get_users_in_group(group_id, include_private=True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _kick_user_from_group(user_id):
|
||||||
def _kick_user_from_group(user_id):
|
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
groups_local = self.hs.get_groups_local_handler()
|
groups_local = self.hs.get_groups_local_handler()
|
||||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
await groups_local.user_removed_from_group(group_id, user_id, {})
|
||||||
else:
|
else:
|
||||||
yield self.transport_client.remove_user_from_group_notification(
|
await self.transport_client.remove_user_from_group_notification(
|
||||||
get_domain_from_id(user_id), group_id, user_id, {}
|
get_domain_from_id(user_id), group_id, user_id, {}
|
||||||
)
|
)
|
||||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||||
|
|
||||||
# We kick users out in the order of:
|
# We kick users out in the order of:
|
||||||
# 1. Non-admins
|
# 1. Non-admins
|
||||||
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
else:
|
else:
|
||||||
non_admins.append(u["user_id"])
|
non_admins.append(u["user_id"])
|
||||||
|
|
||||||
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
|
await concurrently_execute(_kick_user_from_group, non_admins, 10)
|
||||||
yield concurrently_execute(_kick_user_from_group, admins, 10)
|
await concurrently_execute(_kick_user_from_group, admins, 10)
|
||||||
yield _kick_user_from_group(requester_user_id)
|
await _kick_user_from_group(requester_user_id)
|
||||||
|
|
||||||
yield self.store.delete_group(group_id)
|
await self.store.delete_group(group_id)
|
||||||
|
|
||||||
|
|
||||||
def _parse_join_policy_from_contents(content):
|
def _parse_join_policy_from_contents(content):
|
||||||
|
@ -126,30 +126,28 @@ class BaseHandler(object):
|
|||||||
retry_after_ms=int(1000 * (time_allowed - time_now))
|
retry_after_ms=int(1000 * (time_allowed - time_now))
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def maybe_kick_guest_users(self, event, context=None):
|
||||||
def maybe_kick_guest_users(self, event, context=None):
|
|
||||||
# Technically this function invalidates current_state by changing it.
|
# Technically this function invalidates current_state by changing it.
|
||||||
# Hopefully this isn't that important to the caller.
|
# Hopefully this isn't that important to the caller.
|
||||||
if event.type == EventTypes.GuestAccess:
|
if event.type == EventTypes.GuestAccess:
|
||||||
guest_access = event.content.get("guest_access", "forbidden")
|
guest_access = event.content.get("guest_access", "forbidden")
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
if context:
|
if context:
|
||||||
current_state_ids = yield context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
current_state = yield self.store.get_events(
|
current_state = await self.store.get_events(
|
||||||
list(current_state_ids.values())
|
list(current_state_ids.values())
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
current_state = yield self.state_handler.get_current_state(
|
current_state = await self.state_handler.get_current_state(
|
||||||
event.room_id
|
event.room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
current_state = list(current_state.values())
|
current_state = list(current_state.values())
|
||||||
|
|
||||||
logger.info("maybe_kick_guest_users %r", current_state)
|
logger.info("maybe_kick_guest_users %r", current_state)
|
||||||
yield self.kick_guest_users(current_state)
|
await self.kick_guest_users(current_state)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def kick_guest_users(self, current_state):
|
||||||
def kick_guest_users(self, current_state):
|
|
||||||
for member_event in current_state:
|
for member_event in current_state:
|
||||||
try:
|
try:
|
||||||
if member_event.type != EventTypes.Member:
|
if member_event.type != EventTypes.Member:
|
||||||
@ -180,7 +178,7 @@ class BaseHandler(object):
|
|||||||
# homeserver.
|
# homeserver.
|
||||||
requester = synapse.types.create_requester(target_user, is_guest=True)
|
requester = synapse.types.create_requester(target_user, is_guest=True)
|
||||||
handler = self.hs.get_room_member_handler()
|
handler = self.hs.get_room_member_handler()
|
||||||
yield handler.update_membership(
|
await handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
target_user,
|
target_user,
|
||||||
member_event.room_id,
|
member_event.room_id,
|
||||||
|
@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
room_alias, room_id, servers, creator=creator
|
room_alias, room_id, servers, creator=creator
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_association(
|
||||||
def create_association(
|
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
room_alias: RoomAlias,
|
room_alias: RoomAlias,
|
||||||
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
# Server admins are not subject to the same constraints as normal
|
# Server admins are not subject to the same constraints as normal
|
||||||
# users when creating an alias (e.g. being in the room).
|
# users when creating an alias (e.g. being in the room).
|
||||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
is_admin = await self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
if (self.require_membership and check_membership) and not is_admin:
|
if (self.require_membership and check_membership) and not is_admin:
|
||||||
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
|
rooms_for_user = await self.store.get_rooms_for_user(user_id)
|
||||||
if room_id not in rooms_for_user:
|
if room_id not in rooms_for_user:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "You must be in the room to create an alias for it"
|
403, "You must be in the room to create an alias for it"
|
||||||
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
# per alias creation rule?
|
# per alias creation rule?
|
||||||
raise SynapseError(403, "Not allowed to create alias")
|
raise SynapseError(403, "Not allowed to create alias")
|
||||||
|
|
||||||
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
|
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||||
if not can_create:
|
if not can_create:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
400,
|
400,
|
||||||
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
|
|||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._create_association(room_alias, room_id, servers, creator=user_id)
|
await self._create_association(room_alias, room_id, servers, creator=user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
||||||
def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
|
||||||
"""Remove an alias from the directory
|
"""Remove an alias from the directory
|
||||||
|
|
||||||
(this is only meant for human users; AS users should call
|
(this is only meant for human users; AS users should call
|
||||||
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
can_delete = yield self._user_can_delete_alias(room_alias, user_id)
|
can_delete = await self._user_can_delete_alias(room_alias, user_id)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise NotFoundError("Unknown room alias")
|
raise NotFoundError("Unknown room alias")
|
||||||
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
if not can_delete:
|
if not can_delete:
|
||||||
raise AuthError(403, "You don't have permission to delete the alias.")
|
raise AuthError(403, "You don't have permission to delete the alias.")
|
||||||
|
|
||||||
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
|
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||||
if not can_delete:
|
if not can_delete:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
|
|||||||
errcode=Codes.EXCLUSIVE,
|
errcode=Codes.EXCLUSIVE,
|
||||||
)
|
)
|
||||||
|
|
||||||
room_id = yield self._delete_association(room_alias)
|
room_id = await self._delete_association(room_alias)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.info("Failed to update alias events: %s", e)
|
logger.info("Failed to update alias events: %s", e)
|
||||||
|
|
||||||
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
|
|||||||
Codes.NOT_FOUND,
|
Codes.NOT_FOUND,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _update_canonical_alias(
|
||||||
def _update_canonical_alias(
|
|
||||||
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Send an updated canonical alias event if the removed alias was set as
|
Send an updated canonical alias event if the removed alias was set as
|
||||||
the canonical alias or listed in the alt_aliases field.
|
the canonical alias or listed in the alt_aliases field.
|
||||||
"""
|
"""
|
||||||
alias_event = yield self.state.get_current_state(
|
alias_event = await self.state.get_current_state(
|
||||||
room_id, EventTypes.CanonicalAlias, ""
|
room_id, EventTypes.CanonicalAlias, ""
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
del content["alt_aliases"]
|
del content["alt_aliases"]
|
||||||
|
|
||||||
if send_update:
|
if send_update:
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.CanonicalAlias,
|
"type": EventTypes.CanonicalAlias,
|
||||||
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
# either no interested services, or no service with an exclusive lock
|
# either no interested services, or no service with an exclusive lock
|
||||||
return defer.succeed(True)
|
return defer.succeed(True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
||||||
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
|
||||||
"""Determine whether a user can delete an alias.
|
"""Determine whether a user can delete an alias.
|
||||||
|
|
||||||
One of the following must be true:
|
One of the following must be true:
|
||||||
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
|
|||||||
for the current room.
|
for the current room.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
creator = yield self.store.get_room_alias_creator(alias.to_string())
|
creator = await self.store.get_room_alias_creator(alias.to_string())
|
||||||
|
|
||||||
if creator is not None and creator == user_id:
|
if creator is not None and creator == user_id:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Resolve the alias to the corresponding room.
|
# Resolve the alias to the corresponding room.
|
||||||
room_mapping = yield self.get_association(alias)
|
room_mapping = await self.get_association(alias)
|
||||||
room_id = room_mapping["room_id"]
|
room_id = room_mapping["room_id"]
|
||||||
if not room_id:
|
if not room_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
res = yield self.auth.check_can_change_room_list(
|
res = await self.auth.check_can_change_room_list(
|
||||||
room_id, UserID.from_string(user_id)
|
room_id, UserID.from_string(user_id)
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def edit_published_room_list(
|
||||||
def edit_published_room_list(
|
|
||||||
self, requester: Requester, room_id: str, visibility: str
|
self, requester: Requester, room_id: str, visibility: str
|
||||||
):
|
):
|
||||||
"""Edit the entry of the room in the published room list.
|
"""Edit the entry of the room in the published room list.
|
||||||
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
|
|||||||
403, "This user is not permitted to publish rooms to the room list"
|
403, "This user is not permitted to publish rooms to the room list"
|
||||||
)
|
)
|
||||||
|
|
||||||
room = yield self.store.get_room(room_id)
|
room = await self.store.get_room(room_id)
|
||||||
if room is None:
|
if room is None:
|
||||||
raise SynapseError(400, "Unknown room")
|
raise SynapseError(400, "Unknown room")
|
||||||
|
|
||||||
can_change_room_list = yield self.auth.check_can_change_room_list(
|
can_change_room_list = await self.auth.check_can_change_room_list(
|
||||||
room_id, requester.user
|
room_id, requester.user
|
||||||
)
|
)
|
||||||
if not can_change_room_list:
|
if not can_change_room_list:
|
||||||
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
|
|||||||
|
|
||||||
making_public = visibility == "public"
|
making_public = visibility == "public"
|
||||||
if making_public:
|
if making_public:
|
||||||
room_aliases = yield self.store.get_aliases_for_room(room_id)
|
room_aliases = await self.store.get_aliases_for_room(room_id)
|
||||||
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
|
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
|
||||||
if canonical_alias:
|
if canonical_alias:
|
||||||
room_aliases.append(canonical_alias)
|
room_aliases.append(canonical_alias)
|
||||||
|
|
||||||
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
|
|||||||
# per alias creation rule?
|
# per alias creation rule?
|
||||||
raise SynapseError(403, "Not allowed to publish room")
|
raise SynapseError(403, "Not allowed to publish room")
|
||||||
|
|
||||||
yield self.store.set_room_is_public(room_id, making_public)
|
await self.store.set_room_is_public(room_id, making_public)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def edit_published_appservice_room_list(
|
def edit_published_appservice_room_list(
|
||||||
|
@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
|
|||||||
"missing": [e.event_id for e in missing_locals],
|
"missing": [e.event_id for e in missing_locals],
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
@log_function
|
@log_function
|
||||||
def exchange_third_party_invite(
|
async def exchange_third_party_invite(
|
||||||
self, sender_user_id, target_user_id, room_id, signed
|
self, sender_user_id, target_user_id, room_id, signed
|
||||||
):
|
):
|
||||||
third_party_invite = {"signed": signed}
|
third_party_invite = {"signed": signed}
|
||||||
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
|
|||||||
"state_key": target_user_id,
|
"state_key": target_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
|
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
|
||||||
room_version = yield self.store.get_room_version_id(room_id)
|
room_version = await self.store.get_room_version_id(room_id)
|
||||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||||
|
|
||||||
EventValidator().validate_builder(builder)
|
EventValidator().validate_builder(builder)
|
||||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
event, context = await self.event_creation_handler.create_new_client_event(
|
||||||
builder=builder
|
builder=builder
|
||||||
)
|
)
|
||||||
|
|
||||||
event_allowed = yield self.third_party_event_rules.check_event_allowed(
|
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
if not event_allowed:
|
if not event_allowed:
|
||||||
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
|
|||||||
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield self.add_display_name_to_third_party_invite(
|
event, context = await self.add_display_name_to_third_party_invite(
|
||||||
room_version, event_dict, event, context
|
room_version, event_dict, event, context
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
|
|||||||
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.auth.check_from_context(room_version, event, context)
|
await self.auth.check_from_context(room_version, event, context)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Denying new third party invite %r because %s", event, e)
|
logger.warning("Denying new third party invite %r because %s", event, e)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
yield self._check_signature(event, context)
|
await self._check_signature(event, context)
|
||||||
|
|
||||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||||
member_handler = self.hs.get_room_member_handler()
|
member_handler = self.hs.get_room_member_handler()
|
||||||
yield member_handler.send_membership_event(None, event, context)
|
await member_handler.send_membership_event(None, event, context)
|
||||||
else:
|
else:
|
||||||
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
|
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
|
||||||
yield self.federation_client.forward_third_party_invite(
|
await self.federation_client.forward_third_party_invite(
|
||||||
destinations, room_id, event_dict
|
destinations, room_id, event_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_group(self, group_id, user_id, content):
|
||||||
def create_group(self, group_id, user_id, content):
|
|
||||||
"""Create a group
|
"""Create a group
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger.info("Asking to create group with ID: %r", group_id)
|
logger.info("Asking to create group with ID: %r", group_id)
|
||||||
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
res = yield self.groups_server_handler.create_group(
|
res = await self.groups_server_handler.create_group(
|
||||||
group_id, user_id, content
|
group_id, user_id, content
|
||||||
)
|
)
|
||||||
local_attestation = None
|
local_attestation = None
|
||||||
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||||
content["attestation"] = local_attestation
|
content["attestation"] = local_attestation
|
||||||
|
|
||||||
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
content["user_profile"] = await self.profile_handler.get_profile(user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = yield self.transport_client.create_group(
|
res = await self.transport_client.create_group(
|
||||||
get_domain_from_id(group_id), group_id, user_id, content
|
get_domain_from_id(group_id), group_id, user_id, content
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
raise SynapseError(502, "Failed to contact group server")
|
raise SynapseError(502, "Failed to contact group server")
|
||||||
|
|
||||||
remote_attestation = res["attestation"]
|
remote_attestation = res["attestation"]
|
||||||
yield self.attestations.verify_attestation(
|
await self.attestations.verify_attestation(
|
||||||
remote_attestation,
|
remote_attestation,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
is_publicised = content.get("publicise", False)
|
is_publicised = content.get("publicise", False)
|
||||||
token = yield self.store.register_user_group_membership(
|
token = await self.store.register_user_group_membership(
|
||||||
group_id,
|
group_id,
|
||||||
user_id,
|
user_id,
|
||||||
membership="join",
|
membership="join",
|
||||||
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
|
|
||||||
return {"state": "invite", "user_profile": user_profile}
|
return {"state": "invite", "user_profile": user_profile}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def remove_user_from_group(
|
||||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
self, group_id, user_id, requester_user_id, content
|
||||||
|
):
|
||||||
"""Remove a user from a group
|
"""Remove a user from a group
|
||||||
"""
|
"""
|
||||||
if user_id == requester_user_id:
|
if user_id == requester_user_id:
|
||||||
token = yield self.store.register_user_group_membership(
|
token = await self.store.register_user_group_membership(
|
||||||
group_id, user_id, membership="leave"
|
group_id, user_id, membership="leave"
|
||||||
)
|
)
|
||||||
self.notifier.on_new_event("groups_key", token, users=[user_id])
|
self.notifier.on_new_event("groups_key", token, users=[user_id])
|
||||||
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||||||
# retry if the group server is currently down.
|
# retry if the group server is currently down.
|
||||||
|
|
||||||
if self.is_mine_id(group_id):
|
if self.is_mine_id(group_id):
|
||||||
res = yield self.groups_server_handler.remove_user_from_group(
|
res = await self.groups_server_handler.remove_user_from_group(
|
||||||
group_id, user_id, requester_user_id, content
|
group_id, user_id, requester_user_id, content
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
content["requester_user_id"] = requester_user_id
|
content["requester_user_id"] = requester_user_id
|
||||||
try:
|
try:
|
||||||
res = yield self.transport_client.remove_user_from_group(
|
res = await self.transport_client.remove_user_from_group(
|
||||||
get_domain_from_id(group_id),
|
get_domain_from_id(group_id),
|
||||||
group_id,
|
group_id,
|
||||||
requester_user_id,
|
requester_user_id,
|
||||||
|
@ -626,8 +626,7 @@ class EventCreationHandler(object):
|
|||||||
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
|
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
|
||||||
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
||||||
def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
|
||||||
"""
|
"""
|
||||||
Persists and notifies local clients and federation of an event.
|
Persists and notifies local clients and federation of an event.
|
||||||
|
|
||||||
@ -647,7 +646,7 @@ class EventCreationHandler(object):
|
|||||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state = yield self.deduplicate_state_event(event, context)
|
prev_state = await self.deduplicate_state_event(event, context)
|
||||||
if prev_state is not None:
|
if prev_state is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Not bothering to persist state event %s duplicated by %s",
|
"Not bothering to persist state event %s duplicated by %s",
|
||||||
@ -656,7 +655,7 @@ class EventCreationHandler(object):
|
|||||||
)
|
)
|
||||||
return prev_state
|
return prev_state
|
||||||
|
|
||||||
yield self.handle_new_client_event(
|
await self.handle_new_client_event(
|
||||||
requester=requester, event=event, context=context, ratelimit=ratelimit
|
requester=requester, event=event, context=context, ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -683,8 +682,7 @@ class EventCreationHandler(object):
|
|||||||
return prev_event
|
return prev_event
|
||||||
return
|
return
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_and_send_nonmember_event(
|
||||||
def create_and_send_nonmember_event(
|
|
||||||
self, requester, event_dict, ratelimit=True, txn_id=None
|
self, requester, event_dict, ratelimit=True, txn_id=None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -698,8 +696,8 @@ class EventCreationHandler(object):
|
|||||||
# a situation where event persistence can't keep up, causing
|
# a situation where event persistence can't keep up, causing
|
||||||
# extremities to pile up, which in turn leads to state resolution
|
# extremities to pile up, which in turn leads to state resolution
|
||||||
# taking longer.
|
# taking longer.
|
||||||
with (yield self.limiter.queue(event_dict["room_id"])):
|
with (await self.limiter.queue(event_dict["room_id"])):
|
||||||
event, context = yield self.create_event(
|
event, context = await self.create_event(
|
||||||
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
|
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -709,7 +707,7 @@ class EventCreationHandler(object):
|
|||||||
spam_error = "Spam is not permitted here"
|
spam_error = "Spam is not permitted here"
|
||||||
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
|
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
|
||||||
|
|
||||||
yield self.send_nonmember_event(
|
await self.send_nonmember_event(
|
||||||
requester, event, context, ratelimit=ratelimit
|
requester, event, context, ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
return event
|
return event
|
||||||
@ -770,8 +768,7 @@ class EventCreationHandler(object):
|
|||||||
return (event, context)
|
return (event, context)
|
||||||
|
|
||||||
@measure_func("handle_new_client_event")
|
@measure_func("handle_new_client_event")
|
||||||
@defer.inlineCallbacks
|
async def handle_new_client_event(
|
||||||
def handle_new_client_event(
|
|
||||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
self, requester, event, context, ratelimit=True, extra_users=[]
|
||||||
):
|
):
|
||||||
"""Processes a new event. This includes checking auth, persisting it,
|
"""Processes a new event. This includes checking auth, persisting it,
|
||||||
@ -794,9 +791,9 @@ class EventCreationHandler(object):
|
|||||||
):
|
):
|
||||||
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
|
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
|
||||||
else:
|
else:
|
||||||
room_version = yield self.store.get_room_version_id(event.room_id)
|
room_version = await self.store.get_room_version_id(event.room_id)
|
||||||
|
|
||||||
event_allowed = yield self.third_party_event_rules.check_event_allowed(
|
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
if not event_allowed:
|
if not event_allowed:
|
||||||
@ -805,7 +802,7 @@ class EventCreationHandler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.auth.check_from_context(room_version, event, context)
|
await self.auth.check_from_context(room_version, event, context)
|
||||||
except AuthError as err:
|
except AuthError as err:
|
||||||
logger.warning("Denying new event %r because %s", event, err)
|
logger.warning("Denying new event %r because %s", event, err)
|
||||||
raise err
|
raise err
|
||||||
@ -818,7 +815,7 @@ class EventCreationHandler(object):
|
|||||||
logger.exception("Failed to encode content: %r", event.content)
|
logger.exception("Failed to encode content: %r", event.content)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
yield self.action_generator.handle_push_actions_for_event(event, context)
|
await self.action_generator.handle_push_actions_for_event(event, context)
|
||||||
|
|
||||||
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
|
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
|
||||||
# hack around with a try/finally instead.
|
# hack around with a try/finally instead.
|
||||||
@ -826,7 +823,7 @@ class EventCreationHandler(object):
|
|||||||
try:
|
try:
|
||||||
# If we're a worker we need to hit out to the master.
|
# If we're a worker we need to hit out to the master.
|
||||||
if self.config.worker_app:
|
if self.config.worker_app:
|
||||||
yield self.send_event_to_master(
|
await self.send_event_to_master(
|
||||||
event_id=event.event_id,
|
event_id=event.event_id,
|
||||||
store=self.store,
|
store=self.store,
|
||||||
requester=requester,
|
requester=requester,
|
||||||
@ -838,7 +835,7 @@ class EventCreationHandler(object):
|
|||||||
success = True
|
success = True
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.persist_and_notify_client_event(
|
await self.persist_and_notify_client_event(
|
||||||
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -883,8 +880,7 @@ class EventCreationHandler(object):
|
|||||||
Codes.BAD_ALIAS,
|
Codes.BAD_ALIAS,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def persist_and_notify_client_event(
|
||||||
def persist_and_notify_client_event(
|
|
||||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
self, requester, event, context, ratelimit=True, extra_users=[]
|
||||||
):
|
):
|
||||||
"""Called when we have fully built the event, have already
|
"""Called when we have fully built the event, have already
|
||||||
@ -901,7 +897,7 @@ class EventCreationHandler(object):
|
|||||||
# user is actually admin or not).
|
# user is actually admin or not).
|
||||||
is_admin_redaction = False
|
is_admin_redaction = False
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
original_event = yield self.store.get_event(
|
original_event = await self.store.get_event(
|
||||||
event.redacts,
|
event.redacts,
|
||||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
@ -913,11 +909,11 @@ class EventCreationHandler(object):
|
|||||||
original_event and event.sender != original_event.sender
|
original_event and event.sender != original_event.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.base_handler.ratelimit(
|
await self.base_handler.ratelimit(
|
||||||
requester, is_admin_redaction=is_admin_redaction
|
requester, is_admin_redaction=is_admin_redaction
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
await self.base_handler.maybe_kick_guest_users(event, context)
|
||||||
|
|
||||||
if event.type == EventTypes.CanonicalAlias:
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
# Validate a newly added alias or newly added alt_aliases.
|
# Validate a newly added alias or newly added alt_aliases.
|
||||||
@ -927,7 +923,7 @@ class EventCreationHandler(object):
|
|||||||
|
|
||||||
original_event_id = event.unsigned.get("replaces_state")
|
original_event_id = event.unsigned.get("replaces_state")
|
||||||
if original_event_id:
|
if original_event_id:
|
||||||
original_event = yield self.store.get_event(original_event_id)
|
original_event = await self.store.get_event(original_event_id)
|
||||||
|
|
||||||
if original_event:
|
if original_event:
|
||||||
original_alias = original_event.content.get("alias", None)
|
original_alias = original_event.content.get("alias", None)
|
||||||
@ -937,7 +933,7 @@ class EventCreationHandler(object):
|
|||||||
room_alias_str = event.content.get("alias", None)
|
room_alias_str = event.content.get("alias", None)
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
if room_alias_str and room_alias_str != original_alias:
|
if room_alias_str and room_alias_str != original_alias:
|
||||||
yield self._validate_canonical_alias(
|
await self._validate_canonical_alias(
|
||||||
directory_handler, room_alias_str, event.room_id
|
directory_handler, room_alias_str, event.room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -957,7 +953,7 @@ class EventCreationHandler(object):
|
|||||||
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
|
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
|
||||||
if new_alt_aliases:
|
if new_alt_aliases:
|
||||||
for alias_str in new_alt_aliases:
|
for alias_str in new_alt_aliases:
|
||||||
yield self._validate_canonical_alias(
|
await self._validate_canonical_alias(
|
||||||
directory_handler, alias_str, event.room_id
|
directory_handler, alias_str, event.room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -969,7 +965,7 @@ class EventCreationHandler(object):
|
|||||||
def is_inviter_member_event(e):
|
def is_inviter_member_event(e):
|
||||||
return e.type == EventTypes.Member and e.sender == event.sender
|
return e.type == EventTypes.Member and e.sender == event.sender
|
||||||
|
|
||||||
current_state_ids = yield context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
|
|
||||||
state_to_include_ids = [
|
state_to_include_ids = [
|
||||||
e_id
|
e_id
|
||||||
@ -978,7 +974,7 @@ class EventCreationHandler(object):
|
|||||||
or k == (EventTypes.Member, event.sender)
|
or k == (EventTypes.Member, event.sender)
|
||||||
]
|
]
|
||||||
|
|
||||||
state_to_include = yield self.store.get_events(state_to_include_ids)
|
state_to_include = await self.store.get_events(state_to_include_ids)
|
||||||
|
|
||||||
event.unsigned["invite_room_state"] = [
|
event.unsigned["invite_room_state"] = [
|
||||||
{
|
{
|
||||||
@ -996,8 +992,8 @@ class EventCreationHandler(object):
|
|||||||
# way? If we have been invited by a remote server, we need
|
# way? If we have been invited by a remote server, we need
|
||||||
# to get them to sign the event.
|
# to get them to sign the event.
|
||||||
|
|
||||||
returned_invite = yield defer.ensureDeferred(
|
returned_invite = await federation_handler.send_invite(
|
||||||
federation_handler.send_invite(invitee.domain, event)
|
invitee.domain, event
|
||||||
)
|
)
|
||||||
event.unsigned.pop("room_state", None)
|
event.unsigned.pop("room_state", None)
|
||||||
|
|
||||||
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
|
|||||||
event.signatures.update(returned_invite.signatures)
|
event.signatures.update(returned_invite.signatures)
|
||||||
|
|
||||||
if event.type == EventTypes.Redaction:
|
if event.type == EventTypes.Redaction:
|
||||||
original_event = yield self.store.get_event(
|
original_event = await self.store.get_event(
|
||||||
event.redacts,
|
event.redacts,
|
||||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
|
|||||||
if original_event.room_id != event.room_id:
|
if original_event.room_id != event.room_id:
|
||||||
raise SynapseError(400, "Cannot redact event from a different room")
|
raise SynapseError(400, "Cannot redact event from a different room")
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
auth_events_ids = yield self.auth.compute_auth_events(
|
auth_events_ids = await self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events = yield self.store.get_events(auth_events_ids)
|
auth_events = await self.store.get_events(auth_events_ids)
|
||||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||||
|
|
||||||
room_version = yield self.store.get_room_version_id(event.room_id)
|
room_version = await self.store.get_room_version_id(event.room_id)
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
|
||||||
if event_auth.check_redaction(
|
if event_auth.check_redaction(
|
||||||
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
|
|||||||
event.internal_metadata.recheck_redaction = False
|
event.internal_metadata.recheck_redaction = False
|
||||||
|
|
||||||
if event.type == EventTypes.Create:
|
if event.type == EventTypes.Create:
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
if prev_state_ids:
|
if prev_state_ids:
|
||||||
raise AuthError(403, "Changing the room create event is forbidden")
|
raise AuthError(403, "Changing the room create event is forbidden")
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
|
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
|
||||||
event, context=context
|
event, context=context
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
|
|||||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||||
self._message_handler.maybe_schedule_expiry(event)
|
self._message_handler.maybe_schedule_expiry(event)
|
||||||
|
|
||||||
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
|
await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
|
||||||
|
|
||||||
def _notify():
|
def _notify():
|
||||||
try:
|
try:
|
||||||
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error bumping presence active time")
|
logger.exception("Error bumping presence active time")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _send_dummy_events_to_fill_extremities(self):
|
||||||
def _send_dummy_events_to_fill_extremities(self):
|
|
||||||
"""Background task to send dummy events into rooms that have a large
|
"""Background task to send dummy events into rooms that have a large
|
||||||
number of extremities
|
number of extremities
|
||||||
"""
|
"""
|
||||||
self._expire_rooms_to_exclude_from_dummy_event_insertion()
|
self._expire_rooms_to_exclude_from_dummy_event_insertion()
|
||||||
room_ids = yield self.store.get_rooms_with_many_extremities(
|
room_ids = await self.store.get_rooms_with_many_extremities(
|
||||||
min_count=10,
|
min_count=10,
|
||||||
limit=5,
|
limit=5,
|
||||||
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
|
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
|
||||||
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
|
|||||||
# For each room we need to find a joined member we can use to send
|
# For each room we need to find a joined member we can use to send
|
||||||
# the dummy event with.
|
# the dummy event with.
|
||||||
|
|
||||||
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
|
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||||
|
|
||||||
members = yield self.state.get_current_users_in_room(
|
members = await self.state.get_current_users_in_room(
|
||||||
room_id, latest_event_ids=latest_event_ids
|
room_id, latest_event_ids=latest_event_ids
|
||||||
)
|
)
|
||||||
dummy_event_sent = False
|
dummy_event_sent = False
|
||||||
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
|
|||||||
continue
|
continue
|
||||||
requester = create_requester(user_id)
|
requester = create_requester(user_id)
|
||||||
try:
|
try:
|
||||||
event, context = yield self.create_event(
|
event, context = await self.create_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": "org.matrix.dummy_event",
|
"type": "org.matrix.dummy_event",
|
||||||
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
|
|||||||
|
|
||||||
event.internal_metadata.proactively_send = False
|
event.internal_metadata.proactively_send = False
|
||||||
|
|
||||||
yield self.send_nonmember_event(
|
await self.send_nonmember_event(
|
||||||
requester, event, context, ratelimit=False
|
requester, event, context, ratelimit=False
|
||||||
)
|
)
|
||||||
dummy_event_sent = True
|
dummy_event_sent = True
|
||||||
|
@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
|
|
||||||
return result["displayname"]
|
return result["displayname"]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def set_displayname(
|
||||||
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
self, target_user, requester, new_displayname, by_admin=False
|
||||||
|
):
|
||||||
"""Set the displayname of a user
|
"""Set the displayname of a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
raise AuthError(400, "Cannot set another user's displayname")
|
raise AuthError(400, "Cannot set another user's displayname")
|
||||||
|
|
||||||
if not by_admin and not self.hs.config.enable_set_displayname:
|
if not by_admin and not self.hs.config.enable_set_displayname:
|
||||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
if profile.display_name:
|
if profile.display_name:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(target_user)
|
||||||
|
|
||||||
yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
yield self.user_directory_handler.handle_local_profile_change(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
target_user.to_string(), profile
|
target_user.to_string(), profile
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_avatar_url(self, target_user):
|
def get_avatar_url(self, target_user):
|
||||||
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
|
|
||||||
return result["avatar_url"]
|
return result["avatar_url"]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def set_avatar_url(
|
||||||
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
|
self, target_user, requester, new_avatar_url, by_admin=False
|
||||||
|
):
|
||||||
"""target_user is the user whose avatar_url is to be changed;
|
"""target_user is the user whose avatar_url is to be changed;
|
||||||
auth_user is the user attempting to make this change."""
|
auth_user is the user attempting to make this change."""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||||
|
|
||||||
if not by_admin and not self.hs.config.enable_set_avatar_url:
|
if not by_admin and not self.hs.config.enable_set_avatar_url:
|
||||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
if profile.avatar_url:
|
if profile.avatar_url:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||||
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(target_user)
|
||||||
|
|
||||||
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
yield self.user_directory_handler.handle_local_profile_change(
|
await self.user_directory_handler.handle_local_profile_change(
|
||||||
target_user.to_string(), profile
|
target_user.to_string(), profile
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_profile_query(self, args):
|
def on_profile_query(self, args):
|
||||||
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _update_join_states(self, requester, target_user):
|
||||||
def _update_join_states(self, requester, target_user):
|
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.ratelimit(requester)
|
await self.ratelimit(requester)
|
||||||
|
|
||||||
room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
|
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
handler = self.hs.get_room_member_handler()
|
handler = self.hs.get_room_member_handler()
|
||||||
try:
|
try:
|
||||||
# Assume the target_user isn't a guest,
|
# Assume the target_user isn't a guest,
|
||||||
# because we don't let guests set profile or avatar data.
|
# because we don't let guests set profile or avatar data.
|
||||||
yield handler.update_membership(
|
await handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
target_user,
|
target_user,
|
||||||
room_id,
|
room_id,
|
||||||
|
@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
localpart : The local part of the user ID to register. If None,
|
localpart: The local part of the user ID to register. If None,
|
||||||
one will be generated.
|
one will be generated.
|
||||||
password (unicode) : The password to assign to this user so they can
|
password (unicode): The password to assign to this user so they can
|
||||||
login again. This can be None which means they cannot login again
|
login again. This can be None which means they cannot login again
|
||||||
via a password (e.g. the user is an application service user).
|
via a password (e.g. the user is an application service user).
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type (str|None): type of user. One of the values from
|
||||||
@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
fail_count += 1
|
fail_count += 1
|
||||||
|
|
||||||
if not self.hs.config.user_consent_at_registration:
|
if not self.hs.config.user_consent_at_registration:
|
||||||
yield self._auto_join_rooms(user_id)
|
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Skipping auto-join for %s because consent is required at registration",
|
"Skipping auto-join for %s because consent is required at registration",
|
||||||
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _auto_join_rooms(self, user_id):
|
||||||
def _auto_join_rooms(self, user_id):
|
|
||||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||||
if the user is the first to be created.
|
if the user is the first to be created.
|
||||||
|
|
||||||
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
# that an auto-generated support or bot user is not a real user and will never be
|
# that an auto-generated support or bot user is not a real user and will never be
|
||||||
# the user to create the room
|
# the user to create the room
|
||||||
should_auto_create_rooms = False
|
should_auto_create_rooms = False
|
||||||
is_real_user = yield self.store.is_real_user(user_id)
|
is_real_user = await self.store.is_real_user(user_id)
|
||||||
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
|
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
|
||||||
count = yield self.store.count_real_users()
|
count = await self.store.count_real_users()
|
||||||
should_auto_create_rooms = count == 1
|
should_auto_create_rooms = count == 1
|
||||||
for r in self.hs.config.auto_join_rooms:
|
for r in self.hs.config.auto_join_rooms:
|
||||||
logger.info("Auto-joining %s to %s", user_id, r)
|
logger.info("Auto-joining %s to %s", user_id, r)
|
||||||
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
# getting the RoomCreationHandler during init gives a dependency
|
# getting the RoomCreationHandler during init gives a dependency
|
||||||
# loop
|
# loop
|
||||||
yield self.hs.get_room_creation_handler().create_room(
|
await self.hs.get_room_creation_handler().create_room(
|
||||||
fake_requester,
|
fake_requester,
|
||||||
config={
|
config={
|
||||||
"preset": "public_chat",
|
"preset": "public_chat",
|
||||||
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._join_user_to_room(fake_requester, r)
|
await self._join_user_to_room(fake_requester, r)
|
||||||
except ConsentNotGivenError as e:
|
except ConsentNotGivenError as e:
|
||||||
# Technically not necessary to pull out this error though
|
# Technically not necessary to pull out this error though
|
||||||
# moving away from bare excepts is a good thing to do.
|
# moving away from bare excepts is a good thing to do.
|
||||||
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to join new user to %r: %r", r, e)
|
logger.error("Failed to join new user to %r: %r", r, e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def post_consent_actions(self, user_id):
|
||||||
def post_consent_actions(self, user_id):
|
|
||||||
"""A series of registration actions that can only be carried out once consent
|
"""A series of registration actions that can only be carried out once consent
|
||||||
has been granted
|
has been granted
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user to join
|
user_id (str): The user to join
|
||||||
"""
|
"""
|
||||||
yield self._auto_join_rooms(user_id)
|
await self._auto_join_rooms(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def appservice_register(self, user_localpart, as_token):
|
def appservice_register(self, user_localpart, as_token):
|
||||||
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self._next_generated_user_id += 1
|
self._next_generated_user_id += 1
|
||||||
return str(id)
|
return str(id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _join_user_to_room(self, requester, room_identifier):
|
||||||
def _join_user_to_room(self, requester, room_identifier):
|
|
||||||
room_member_handler = self.hs.get_room_member_handler()
|
room_member_handler = self.hs.get_room_member_handler()
|
||||||
if RoomID.is_valid(room_identifier):
|
if RoomID.is_valid(room_identifier):
|
||||||
room_id = room_identifier
|
room_id = room_identifier
|
||||||
elif RoomAlias.is_valid(room_identifier):
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
room_alias = RoomAlias.from_string(room_identifier)
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
|
room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
|
||||||
room_alias
|
room_alias
|
||||||
)
|
)
|
||||||
room_id = room_id.to_string()
|
room_id = room_id.to_string()
|
||||||
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield room_member_handler.update_membership(
|
await room_member_handler.update_membership(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
target=requester.user,
|
target=requester.user,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
return (device_id, access_token)
|
return (device_id, access_token)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def post_registration_actions(self, user_id, auth_result, access_token):
|
||||||
def post_registration_actions(self, user_id, auth_result, access_token):
|
|
||||||
"""A user has completed registration
|
"""A user has completed registration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
device, or None if `inhibit_login` enabled.
|
device, or None if `inhibit_login` enabled.
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
yield self._post_registration_client(
|
await self._post_registration_client(
|
||||||
user_id=user_id, auth_result=auth_result, access_token=access_token
|
user_id=user_id, auth_result=auth_result, access_token=access_token
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
|
|||||||
if is_threepid_reserved(
|
if is_threepid_reserved(
|
||||||
self.hs.config.mau_limits_reserved_threepids, threepid
|
self.hs.config.mau_limits_reserved_threepids, threepid
|
||||||
):
|
):
|
||||||
yield self.store.upsert_monthly_active_user(user_id)
|
await self.store.upsert_monthly_active_user(user_id)
|
||||||
|
|
||||||
yield self._register_email_threepid(user_id, threepid, access_token)
|
await self._register_email_threepid(user_id, threepid, access_token)
|
||||||
|
|
||||||
if auth_result and LoginType.MSISDN in auth_result:
|
if auth_result and LoginType.MSISDN in auth_result:
|
||||||
threepid = auth_result[LoginType.MSISDN]
|
threepid = auth_result[LoginType.MSISDN]
|
||||||
yield self._register_msisdn_threepid(user_id, threepid)
|
await self._register_msisdn_threepid(user_id, threepid)
|
||||||
|
|
||||||
if auth_result and LoginType.TERMS in auth_result:
|
if auth_result and LoginType.TERMS in auth_result:
|
||||||
yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _on_user_consented(self, user_id, consent_version):
|
||||||
def _on_user_consented(self, user_id, consent_version):
|
|
||||||
"""A user consented to the terms on registration
|
"""A user consented to the terms on registration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
consented to.
|
consented to.
|
||||||
"""
|
"""
|
||||||
logger.info("%s has consented to the privacy policy", user_id)
|
logger.info("%s has consented to the privacy policy", user_id)
|
||||||
yield self.store.user_set_consent_version(user_id, consent_version)
|
await self.store.user_set_consent_version(user_id, consent_version)
|
||||||
yield self.post_consent_actions(user_id)
|
await self.post_consent_actions(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _register_email_threepid(self, user_id, threepid, token):
|
def _register_email_threepid(self, user_id, threepid, token):
|
||||||
|
@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _upgrade_room(
|
||||||
def _upgrade_room(
|
|
||||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||||
):
|
):
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
# start by allocating a new room id
|
# start by allocating a new room id
|
||||||
r = yield self.store.get_room(old_room_id)
|
r = await self.store.get_room(old_room_id)
|
||||||
if r is None:
|
if r is None:
|
||||||
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
||||||
new_room_id = yield self._generate_room_id(
|
new_room_id = await self._generate_room_id(
|
||||||
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
|
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
(
|
(
|
||||||
tombstone_event,
|
tombstone_event,
|
||||||
tombstone_context,
|
tombstone_context,
|
||||||
) = yield self.event_creation_handler.create_event(
|
) = await self.event_creation_handler.create_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Tombstone,
|
"type": EventTypes.Tombstone,
|
||||||
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
},
|
},
|
||||||
token_id=requester.access_token_id,
|
token_id=requester.access_token_id,
|
||||||
)
|
)
|
||||||
old_room_version = yield self.store.get_room_version_id(old_room_id)
|
old_room_version = await self.store.get_room_version_id(old_room_id)
|
||||||
yield self.auth.check_from_context(
|
await self.auth.check_from_context(
|
||||||
old_room_version, tombstone_event, tombstone_context
|
old_room_version, tombstone_event, tombstone_context
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.clone_existing_room(
|
await self.clone_existing_room(
|
||||||
requester,
|
requester,
|
||||||
old_room_id=old_room_id,
|
old_room_id=old_room_id,
|
||||||
new_room_id=new_room_id,
|
new_room_id=new_room_id,
|
||||||
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# now send the tombstone
|
# now send the tombstone
|
||||||
yield self.event_creation_handler.send_nonmember_event(
|
await self.event_creation_handler.send_nonmember_event(
|
||||||
requester, tombstone_event, tombstone_context
|
requester, tombstone_event, tombstone_context
|
||||||
)
|
)
|
||||||
|
|
||||||
old_room_state = yield tombstone_context.get_current_state_ids()
|
old_room_state = await tombstone_context.get_current_state_ids()
|
||||||
|
|
||||||
# update any aliases
|
# update any aliases
|
||||||
yield self._move_aliases_to_new_room(
|
await self._move_aliases_to_new_room(
|
||||||
requester, old_room_id, new_room_id, old_room_state
|
requester, old_room_id, new_room_id, old_room_state
|
||||||
)
|
)
|
||||||
|
|
||||||
# Copy over user push rules, tags and migrate room directory state
|
# Copy over user push rules, tags and migrate room directory state
|
||||||
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
|
await self.room_member_handler.transfer_room_state_on_room_upgrade(
|
||||||
old_room_id, new_room_id
|
old_room_id, new_room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# finally, shut down the PLs in the old room, and update them in the new
|
# finally, shut down the PLs in the old room, and update them in the new
|
||||||
# room.
|
# room.
|
||||||
yield self._update_upgraded_room_pls(
|
await self._update_upgraded_room_pls(
|
||||||
requester, old_room_id, new_room_id, old_room_state,
|
requester, old_room_id, new_room_id, old_room_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_room_id
|
return new_room_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _update_upgraded_room_pls(
|
||||||
def _update_upgraded_room_pls(
|
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
old_room_id: str,
|
old_room_id: str,
|
||||||
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
|
old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
|
||||||
|
|
||||||
# we try to stop regular users from speaking by setting the PL required
|
# we try to stop regular users from speaking by setting the PL required
|
||||||
# to send regular events and invites to 'Moderator' level. That's normally
|
# to send regular events and invites to 'Moderator' level. That's normally
|
||||||
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
if updated:
|
if updated:
|
||||||
try:
|
try:
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.PowerLevels,
|
"type": EventTypes.PowerLevels,
|
||||||
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Unable to update PLs in old room: %s", e)
|
logger.warning("Unable to update PLs in old room: %s", e)
|
||||||
|
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.PowerLevels,
|
"type": EventTypes.PowerLevels,
|
||||||
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
ratelimit=False,
|
ratelimit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def clone_existing_room(
|
||||||
def clone_existing_room(
|
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
old_room_id: str,
|
old_room_id: str,
|
||||||
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# Check if old room was non-federatable
|
# Check if old room was non-federatable
|
||||||
|
|
||||||
# Get old room's create event
|
# Get old room's create event
|
||||||
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
|
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
|
||||||
|
|
||||||
# Check if the create event specified a non-federatable room
|
# Check if the create event specified a non-federatable room
|
||||||
if not old_room_create_event.content.get("m.federate", True):
|
if not old_room_create_event.content.get("m.federate", True):
|
||||||
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
(EventTypes.PowerLevels, ""),
|
(EventTypes.PowerLevels, ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
|
old_room_state_ids = await self.store.get_filtered_current_state_ids(
|
||||||
old_room_id, StateFilter.from_types(types_to_copy)
|
old_room_id, StateFilter.from_types(types_to_copy)
|
||||||
)
|
)
|
||||||
# map from event_id to BaseEvent
|
# map from event_id to BaseEvent
|
||||||
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
|
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
|
||||||
|
|
||||||
for k, old_event_id in iteritems(old_room_state_ids):
|
for k, old_event_id in iteritems(old_room_state_ids):
|
||||||
old_event = old_room_state_events.get(old_event_id)
|
old_event = old_room_state_events.get(old_event_id)
|
||||||
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
if current_power_level < needed_power_level:
|
if current_power_level < needed_power_level:
|
||||||
power_levels["users"][user_id] = needed_power_level
|
power_levels["users"][user_id] = needed_power_level
|
||||||
|
|
||||||
yield self._send_events_for_new_room(
|
await self._send_events_for_new_room(
|
||||||
requester,
|
requester,
|
||||||
new_room_id,
|
new_room_id,
|
||||||
# we expect to override all the presets with initial_state, so this is
|
# we expect to override all the presets with initial_state, so this is
|
||||||
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Transfer membership events
|
# Transfer membership events
|
||||||
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
|
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
|
||||||
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||||
)
|
)
|
||||||
|
|
||||||
# map from event_id to BaseEvent
|
# map from event_id to BaseEvent
|
||||||
old_room_member_state_events = yield self.store.get_events(
|
old_room_member_state_events = await self.store.get_events(
|
||||||
old_room_member_state_ids.values()
|
old_room_member_state_ids.values()
|
||||||
)
|
)
|
||||||
for k, old_event in iteritems(old_room_member_state_events):
|
for k, old_event in iteritems(old_room_member_state_events):
|
||||||
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"membership" in old_event.content
|
"membership" in old_event.content
|
||||||
and old_event.content["membership"] == "ban"
|
and old_event.content["membership"] == "ban"
|
||||||
):
|
):
|
||||||
yield self.room_member_handler.update_membership(
|
await self.room_member_handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
UserID.from_string(old_event["state_key"]),
|
UserID.from_string(old_event["state_key"]),
|
||||||
new_room_id,
|
new_room_id,
|
||||||
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# XXX invites/joins
|
# XXX invites/joins
|
||||||
# XXX 3pid invites
|
# XXX 3pid invites
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _move_aliases_to_new_room(
|
||||||
def _move_aliases_to_new_room(
|
|
||||||
self,
|
self,
|
||||||
requester: Requester,
|
requester: Requester,
|
||||||
old_room_id: str,
|
old_room_id: str,
|
||||||
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
):
|
):
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
|
|
||||||
aliases = yield self.store.get_aliases_for_room(old_room_id)
|
aliases = await self.store.get_aliases_for_room(old_room_id)
|
||||||
|
|
||||||
# check to see if we have a canonical alias.
|
# check to see if we have a canonical alias.
|
||||||
canonical_alias_event = None
|
canonical_alias_event = None
|
||||||
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
|
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
|
||||||
if canonical_alias_event_id:
|
if canonical_alias_event_id:
|
||||||
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
|
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
|
||||||
|
|
||||||
# first we try to remove the aliases from the old room (we suppress sending
|
# first we try to remove the aliases from the old room (we suppress sending
|
||||||
# the room_aliases event until the end).
|
# the room_aliases event until the end).
|
||||||
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
for alias_str in aliases:
|
for alias_str in aliases:
|
||||||
alias = RoomAlias.from_string(alias_str)
|
alias = RoomAlias.from_string(alias_str)
|
||||||
try:
|
try:
|
||||||
yield directory_handler.delete_association(requester, alias)
|
await directory_handler.delete_association(requester, alias)
|
||||||
removed_aliases.append(alias_str)
|
removed_aliases.append(alias_str)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
|
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
|
||||||
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# we can now add any aliases we successfully removed to the new room.
|
# we can now add any aliases we successfully removed to the new room.
|
||||||
for alias in removed_aliases:
|
for alias in removed_aliases:
|
||||||
try:
|
try:
|
||||||
yield directory_handler.create_association(
|
await directory_handler.create_association(
|
||||||
requester,
|
requester,
|
||||||
RoomAlias.from_string(alias),
|
RoomAlias.from_string(alias),
|
||||||
new_room_id,
|
new_room_id,
|
||||||
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# alias event for the new room with a copy of the information.
|
# alias event for the new room with a copy of the information.
|
||||||
try:
|
try:
|
||||||
if canonical_alias_event:
|
if canonical_alias_event:
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.CanonicalAlias,
|
"type": EventTypes.CanonicalAlias,
|
||||||
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# we returned the new room to the client at this point.
|
# we returned the new room to the client at this point.
|
||||||
logger.error("Unable to send updated alias events in new room: %s", e)
|
logger.error("Unable to send updated alias events in new room: %s", e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_room(
|
||||||
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
|
self, requester, config, ratelimit=True, creator_join_profile=None
|
||||||
|
):
|
||||||
""" Creates a new room.
|
""" Creates a new room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
yield self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(user_id)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self._server_notices_mxid is not None
|
self._server_notices_mxid is not None
|
||||||
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# allow the server notices mxid to create rooms
|
# allow the server notices mxid to create rooms
|
||||||
is_requester_admin = True
|
is_requester_admin = True
|
||||||
else:
|
else:
|
||||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
# Check whether the third party rules allows/changes the room create
|
# Check whether the third party rules allows/changes the room create
|
||||||
# request.
|
# request.
|
||||||
event_allowed = yield self.third_party_event_rules.on_create_room(
|
event_allowed = await self.third_party_event_rules.on_create_room(
|
||||||
requester, config, is_requester_admin=is_requester_admin
|
requester, config, is_requester_admin=is_requester_admin
|
||||||
)
|
)
|
||||||
if not event_allowed:
|
if not event_allowed:
|
||||||
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
raise SynapseError(403, "You are not permitted to create rooms")
|
raise SynapseError(403, "You are not permitted to create rooms")
|
||||||
|
|
||||||
if ratelimit:
|
if ratelimit:
|
||||||
yield self.ratelimit(requester)
|
await self.ratelimit(requester)
|
||||||
|
|
||||||
room_version_id = config.get(
|
room_version_id = config.get(
|
||||||
"room_version", self.config.default_room_version.identifier
|
"room_version", self.config.default_room_version.identifier
|
||||||
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
raise SynapseError(400, "Invalid characters in room alias")
|
raise SynapseError(400, "Invalid characters in room alias")
|
||||||
|
|
||||||
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
|
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
|
||||||
mapping = yield self.store.get_association_from_room_alias(room_alias)
|
mapping = await self.store.get_association_from_room_alias(room_alias)
|
||||||
|
|
||||||
if mapping:
|
if mapping:
|
||||||
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
||||||
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "Invalid user_id: %s" % (i,))
|
raise SynapseError(400, "Invalid user_id: %s" % (i,))
|
||||||
|
|
||||||
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
|
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
|
||||||
|
|
||||||
power_level_content_override = config.get("power_level_content_override")
|
power_level_content_override = config.get("power_level_content_override")
|
||||||
if (
|
if (
|
||||||
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
visibility = config.get("visibility", None)
|
visibility = config.get("visibility", None)
|
||||||
is_public = visibility == "public"
|
is_public = visibility == "public"
|
||||||
|
|
||||||
room_id = yield self._generate_room_id(
|
room_id = await self._generate_room_id(
|
||||||
creator_id=user_id, is_public=is_public, room_version=room_version,
|
creator_id=user_id, is_public=is_public, room_version=room_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
if room_alias:
|
if room_alias:
|
||||||
yield directory_handler.create_association(
|
await directory_handler.create_association(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
room_alias=room_alias,
|
room_alias=room_alias,
|
||||||
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# override any attempt to set room versions via the creation_content
|
# override any attempt to set room versions via the creation_content
|
||||||
creation_content["room_version"] = room_version.identifier
|
creation_content["room_version"] = room_version.identifier
|
||||||
|
|
||||||
yield self._send_events_for_new_room(
|
await self._send_events_for_new_room(
|
||||||
requester,
|
requester,
|
||||||
room_id,
|
room_id,
|
||||||
preset_config=preset_config,
|
preset_config=preset_config,
|
||||||
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
if "name" in config:
|
if "name" in config:
|
||||||
name = config["name"]
|
name = config["name"]
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Name,
|
"type": EventTypes.Name,
|
||||||
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
if "topic" in config:
|
if "topic" in config:
|
||||||
topic = config["topic"]
|
topic = config["topic"]
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Topic,
|
"type": EventTypes.Topic,
|
||||||
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
if is_direct:
|
if is_direct:
|
||||||
content["is_direct"] = is_direct
|
content["is_direct"] = is_direct
|
||||||
|
|
||||||
yield self.room_member_handler.update_membership(
|
await self.room_member_handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
UserID.from_string(invitee),
|
UserID.from_string(invitee),
|
||||||
room_id,
|
room_id,
|
||||||
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
id_access_token = invite_3pid.get("id_access_token") # optional
|
id_access_token = invite_3pid.get("id_access_token") # optional
|
||||||
address = invite_3pid["address"]
|
address = invite_3pid["address"]
|
||||||
medium = invite_3pid["medium"]
|
medium = invite_3pid["medium"]
|
||||||
yield self.hs.get_room_member_handler().do_3pid_invite(
|
await self.hs.get_room_member_handler().do_3pid_invite(
|
||||||
room_id,
|
room_id,
|
||||||
requester.user,
|
requester.user,
|
||||||
medium,
|
medium,
|
||||||
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _send_events_for_new_room(
|
||||||
def _send_events_for_new_room(
|
|
||||||
self,
|
self,
|
||||||
creator, # A Requester object.
|
creator, # A Requester object.
|
||||||
room_id,
|
room_id,
|
||||||
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
return e
|
return e
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def send(etype, content, **kwargs):
|
||||||
def send(etype, content, **kwargs):
|
|
||||||
event = create(etype, content, **kwargs)
|
event = create(etype, content, **kwargs)
|
||||||
logger.debug("Sending %s in new room", etype)
|
logger.debug("Sending %s in new room", etype)
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
creator, event, ratelimit=False
|
creator, event, ratelimit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||||
|
|
||||||
creation_content.update({"creator": creator_id})
|
creation_content.update({"creator": creator_id})
|
||||||
yield send(etype=EventTypes.Create, content=creation_content)
|
await send(etype=EventTypes.Create, content=creation_content)
|
||||||
|
|
||||||
logger.debug("Sending %s in new room", EventTypes.Member)
|
logger.debug("Sending %s in new room", EventTypes.Member)
|
||||||
yield self.room_member_handler.update_membership(
|
await self.room_member_handler.update_membership(
|
||||||
creator,
|
creator,
|
||||||
creator.user,
|
creator.user,
|
||||||
room_id,
|
room_id,
|
||||||
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
# of the first events that get sent into a room.
|
# of the first events that get sent into a room.
|
||||||
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
|
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
|
||||||
if pl_content is not None:
|
if pl_content is not None:
|
||||||
yield send(etype=EventTypes.PowerLevels, content=pl_content)
|
await send(etype=EventTypes.PowerLevels, content=pl_content)
|
||||||
else:
|
else:
|
||||||
power_level_content = {
|
power_level_content = {
|
||||||
"users": {creator_id: 100},
|
"users": {creator_id: 100},
|
||||||
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
if power_level_content_override:
|
if power_level_content_override:
|
||||||
power_level_content.update(power_level_content_override)
|
power_level_content.update(power_level_content_override)
|
||||||
|
|
||||||
yield send(etype=EventTypes.PowerLevels, content=power_level_content)
|
await send(etype=EventTypes.PowerLevels, content=power_level_content)
|
||||||
|
|
||||||
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
|
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
|
||||||
yield send(
|
await send(
|
||||||
etype=EventTypes.CanonicalAlias,
|
etype=EventTypes.CanonicalAlias,
|
||||||
content={"alias": room_alias.to_string()},
|
content={"alias": room_alias.to_string()},
|
||||||
)
|
)
|
||||||
|
|
||||||
if (EventTypes.JoinRules, "") not in initial_state:
|
if (EventTypes.JoinRules, "") not in initial_state:
|
||||||
yield send(
|
await send(
|
||||||
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
|
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
|
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
|
||||||
yield send(
|
await send(
|
||||||
etype=EventTypes.RoomHistoryVisibility,
|
etype=EventTypes.RoomHistoryVisibility,
|
||||||
content={"history_visibility": config["history_visibility"]},
|
content={"history_visibility": config["history_visibility"]},
|
||||||
)
|
)
|
||||||
|
|
||||||
if config["guest_can_join"]:
|
if config["guest_can_join"]:
|
||||||
if (EventTypes.GuestAccess, "") not in initial_state:
|
if (EventTypes.GuestAccess, "") not in initial_state:
|
||||||
yield send(
|
await send(
|
||||||
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
|
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
|
||||||
)
|
)
|
||||||
|
|
||||||
for (etype, state_key), content in initial_state.items():
|
for (etype, state_key), content in initial_state.items():
|
||||||
yield send(etype=etype, state_key=state_key, content=content)
|
await send(etype=etype, state_key=state_key, content=content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_room_id(
|
def _generate_room_id(
|
||||||
|
@ -142,8 +142,7 @@ class RoomMemberHandler(object):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _local_membership_update(
|
||||||
def _local_membership_update(
|
|
||||||
self,
|
self,
|
||||||
requester,
|
requester,
|
||||||
target,
|
target,
|
||||||
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
|
|||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
|
|
||||||
event, context = yield self.event_creation_handler.create_event(
|
event, context = await self.event_creation_handler.create_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.Member,
|
"type": EventTypes.Member,
|
||||||
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check if this event matches the previous membership event for the user.
|
# Check if this event matches the previous membership event for the user.
|
||||||
duplicate = yield self.event_creation_handler.deduplicate_state_event(
|
duplicate = await self.event_creation_handler.deduplicate_state_event(
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
if duplicate is not None:
|
if duplicate is not None:
|
||||||
# Discard the new event since this membership change is a no-op.
|
# Discard the new event since this membership change is a no-op.
|
||||||
return duplicate
|
return duplicate
|
||||||
|
|
||||||
yield self.event_creation_handler.handle_new_client_event(
|
await self.event_creation_handler.handle_new_client_event(
|
||||||
requester, event, context, extra_users=[target], ratelimit=ratelimit
|
requester, event, context, extra_users=[target], ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
|
|
||||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||||
|
|
||||||
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
|
|||||||
# info.
|
# info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
yield self._user_joined_room(target, room_id)
|
await self._user_joined_room(target, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
if prev_member_event.membership == Membership.JOIN:
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
yield self._user_left_room(target, room_id)
|
await self._user_left_room(target, room_id)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
|
|||||||
for tag, tag_content in room_tags.items():
|
for tag, tag_content in room_tags.items():
|
||||||
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
|
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def update_membership(
|
||||||
def update_membership(
|
|
||||||
self,
|
self,
|
||||||
requester,
|
requester,
|
||||||
target,
|
target,
|
||||||
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
|
|||||||
):
|
):
|
||||||
key = (room_id,)
|
key = (room_id,)
|
||||||
|
|
||||||
with (yield self.member_linearizer.queue(key)):
|
with (await self.member_linearizer.queue(key)):
|
||||||
result = yield self._update_membership(
|
result = await self._update_membership(
|
||||||
requester,
|
requester,
|
||||||
target,
|
target,
|
||||||
room_id,
|
room_id,
|
||||||
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _update_membership(
|
||||||
def _update_membership(
|
|
||||||
self,
|
self,
|
||||||
requester,
|
requester,
|
||||||
target,
|
target,
|
||||||
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
|
|||||||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||||
# invite into a normal invite before we can handle the join.
|
# invite into a normal invite before we can handle the join.
|
||||||
if third_party_signed is not None:
|
if third_party_signed is not None:
|
||||||
yield self.federation_handler.exchange_third_party_invite(
|
await self.federation_handler.exchange_third_party_invite(
|
||||||
third_party_signed["sender"],
|
third_party_signed["sender"],
|
||||||
target.to_string(),
|
target.to_string(),
|
||||||
room_id,
|
room_id,
|
||||||
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
|
|||||||
remote_room_hosts = []
|
remote_room_hosts = []
|
||||||
|
|
||||||
if effective_membership_state not in ("leave", "ban"):
|
if effective_membership_state not in ("leave", "ban"):
|
||||||
is_blocked = yield self.store.is_room_blocked(room_id)
|
is_blocked = await self.store.is_room_blocked(room_id)
|
||||||
if is_blocked:
|
if is_blocked:
|
||||||
raise SynapseError(403, "This room has been blocked on this server")
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
|
|||||||
is_requester_admin = True
|
is_requester_admin = True
|
||||||
|
|
||||||
else:
|
else:
|
||||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
if not is_requester_admin:
|
if not is_requester_admin:
|
||||||
if self.config.block_non_admin_invites:
|
if self.config.block_non_admin_invites:
|
||||||
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
|
|||||||
if block_invite:
|
if block_invite:
|
||||||
raise SynapseError(403, "Invites have been disabled on this server")
|
raise SynapseError(403, "Invites have been disabled on this server")
|
||||||
|
|
||||||
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
|
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||||
|
|
||||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||||
room_id, latest_event_ids=latest_event_ids
|
room_id, latest_event_ids=latest_event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
|
|||||||
# transitions and generic otherwise
|
# transitions and generic otherwise
|
||||||
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||||
if old_state_id:
|
if old_state_id:
|
||||||
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
old_state = await self.store.get_event(old_state_id, allow_none=True)
|
||||||
old_membership = old_state.content.get("membership") if old_state else None
|
old_membership = old_state.content.get("membership") if old_state else None
|
||||||
if action == "unban" and old_membership != "ban":
|
if action == "unban" and old_membership != "ban":
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
|
|||||||
old_membership == Membership.INVITE
|
old_membership == Membership.INVITE
|
||||||
and effective_membership_state == Membership.LEAVE
|
and effective_membership_state == Membership.LEAVE
|
||||||
):
|
):
|
||||||
is_blocked = yield self._is_server_notice_room(room_id)
|
is_blocked = await self._is_server_notice_room(room_id)
|
||||||
if is_blocked:
|
if is_blocked:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
http_client.FORBIDDEN,
|
http_client.FORBIDDEN,
|
||||||
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
|
|||||||
if action == "kick":
|
if action == "kick":
|
||||||
raise AuthError(403, "The target user is not in the room")
|
raise AuthError(403, "The target user is not in the room")
|
||||||
|
|
||||||
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
is_host_in_room = await self._is_host_in_room(current_state_ids)
|
||||||
|
|
||||||
if effective_membership_state == Membership.JOIN:
|
if effective_membership_state == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(current_state_ids)
|
guest_can_join = await self._can_guest_join(current_state_ids)
|
||||||
if not guest_can_join:
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
if not is_host_in_room:
|
if not is_host_in_room:
|
||||||
inviter = yield self._get_inviter(target.to_string(), room_id)
|
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||||
if inviter and not self.hs.is_mine(inviter):
|
if inviter and not self.hs.is_mine(inviter):
|
||||||
remote_room_hosts.append(inviter.domain)
|
remote_room_hosts.append(inviter.domain)
|
||||||
|
|
||||||
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
|
|||||||
|
|
||||||
profile = self.profile_handler
|
profile = self.profile_handler
|
||||||
if not content_specified:
|
if not content_specified:
|
||||||
content["displayname"] = yield profile.get_displayname(target)
|
content["displayname"] = await profile.get_displayname(target)
|
||||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
content["avatar_url"] = await profile.get_avatar_url(target)
|
||||||
|
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
content["kind"] = "guest"
|
content["kind"] = "guest"
|
||||||
|
|
||||||
remote_join_response = yield self._remote_join(
|
remote_join_response = await self._remote_join(
|
||||||
requester, remote_room_hosts, room_id, target, content
|
requester, remote_room_hosts, room_id, target, content
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
|
|||||||
elif effective_membership_state == Membership.LEAVE:
|
elif effective_membership_state == Membership.LEAVE:
|
||||||
if not is_host_in_room:
|
if not is_host_in_room:
|
||||||
# perhaps we've been invited
|
# perhaps we've been invited
|
||||||
inviter = yield self._get_inviter(target.to_string(), room_id)
|
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||||
if not inviter:
|
if not inviter:
|
||||||
raise SynapseError(404, "Not a known room")
|
raise SynapseError(404, "Not a known room")
|
||||||
|
|
||||||
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
|
|||||||
else:
|
else:
|
||||||
# send the rejection to the inviter's HS.
|
# send the rejection to the inviter's HS.
|
||||||
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
||||||
res = yield self._remote_reject_invite(
|
res = await self._remote_reject_invite(
|
||||||
requester, remote_room_hosts, room_id, target, content,
|
requester, remote_room_hosts, room_id, target, content,
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
res = yield self._local_membership_update(
|
res = await self._local_membership_update(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
target=target,
|
target=target,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def send_membership_event(self, requester, event, context, ratelimit=True):
|
||||||
def send_membership_event(self, requester, event, context, ratelimit=True):
|
|
||||||
"""
|
"""
|
||||||
Change the membership status of a user in a room.
|
Change the membership status of a user in a room.
|
||||||
|
|
||||||
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
|
|||||||
else:
|
else:
|
||||||
requester = types.create_requester(target_user)
|
requester = types.create_requester(target_user)
|
||||||
|
|
||||||
prev_event = yield self.event_creation_handler.deduplicate_state_event(
|
prev_event = await self.event_creation_handler.deduplicate_state_event(
|
||||||
event, context
|
event, context
|
||||||
)
|
)
|
||||||
if prev_event is not None:
|
if prev_event is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
if event.membership == Membership.JOIN:
|
if event.membership == Membership.JOIN:
|
||||||
if requester.is_guest:
|
if requester.is_guest:
|
||||||
guest_can_join = yield self._can_guest_join(prev_state_ids)
|
guest_can_join = await self._can_guest_join(prev_state_ids)
|
||||||
if not guest_can_join:
|
if not guest_can_join:
|
||||||
# This should be an auth check, but guests are a local concept,
|
# This should be an auth check, but guests are a local concept,
|
||||||
# so don't really fit into the general auth process.
|
# so don't really fit into the general auth process.
|
||||||
raise AuthError(403, "Guest access not allowed")
|
raise AuthError(403, "Guest access not allowed")
|
||||||
|
|
||||||
if event.membership not in (Membership.LEAVE, Membership.BAN):
|
if event.membership not in (Membership.LEAVE, Membership.BAN):
|
||||||
is_blocked = yield self.store.is_room_blocked(room_id)
|
is_blocked = await self.store.is_room_blocked(room_id)
|
||||||
if is_blocked:
|
if is_blocked:
|
||||||
raise SynapseError(403, "This room has been blocked on this server")
|
raise SynapseError(403, "This room has been blocked on this server")
|
||||||
|
|
||||||
yield self.event_creation_handler.handle_new_client_event(
|
await self.event_creation_handler.handle_new_client_event(
|
||||||
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
|
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
|
|||||||
# info.
|
# info.
|
||||||
newly_joined = True
|
newly_joined = True
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||||
if newly_joined:
|
if newly_joined:
|
||||||
yield self._user_joined_room(target_user, room_id)
|
await self._user_joined_room(target_user, room_id)
|
||||||
elif event.membership == Membership.LEAVE:
|
elif event.membership == Membership.LEAVE:
|
||||||
if prev_member_event_id:
|
if prev_member_event_id:
|
||||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||||
if prev_member_event.membership == Membership.JOIN:
|
if prev_member_event.membership == Membership.JOIN:
|
||||||
yield self._user_left_room(target_user, room_id)
|
await self._user_left_room(target_user, room_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _can_guest_join(self, current_state_ids):
|
def _can_guest_join(self, current_state_ids):
|
||||||
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
|
|||||||
if invite:
|
if invite:
|
||||||
return UserID.from_string(invite.sender)
|
return UserID.from_string(invite.sender)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def do_3pid_invite(
|
||||||
def do_3pid_invite(
|
|
||||||
self,
|
self,
|
||||||
room_id,
|
room_id,
|
||||||
inviter,
|
inviter,
|
||||||
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
|
|||||||
id_access_token=None,
|
id_access_token=None,
|
||||||
):
|
):
|
||||||
if self.config.block_non_admin_invites:
|
if self.config.block_non_admin_invites:
|
||||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||||
if not is_requester_admin:
|
if not is_requester_admin:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "Invites have been disabled on this server", Codes.FORBIDDEN
|
403, "Invites have been disabled on this server", Codes.FORBIDDEN
|
||||||
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
|
|||||||
|
|
||||||
# We need to rate limit *before* we send out any 3PID invites, so we
|
# We need to rate limit *before* we send out any 3PID invites, so we
|
||||||
# can't just rely on the standard ratelimiting of events.
|
# can't just rely on the standard ratelimiting of events.
|
||||||
yield self.base_handler.ratelimit(requester)
|
await self.base_handler.ratelimit(requester)
|
||||||
|
|
||||||
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
|
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
|
||||||
medium, address, room_id
|
medium, address, room_id
|
||||||
)
|
)
|
||||||
if not can_invite:
|
if not can_invite:
|
||||||
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
|
|||||||
403, "Looking up third-party identifiers is denied from this server"
|
403, "Looking up third-party identifiers is denied from this server"
|
||||||
)
|
)
|
||||||
|
|
||||||
invitee = yield self.identity_handler.lookup_3pid(
|
invitee = await self.identity_handler.lookup_3pid(
|
||||||
id_server, medium, address, id_access_token
|
id_server, medium, address, id_access_token
|
||||||
)
|
)
|
||||||
|
|
||||||
if invitee:
|
if invitee:
|
||||||
yield self.update_membership(
|
await self.update_membership(
|
||||||
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
|
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._make_and_store_3pid_invite(
|
await self._make_and_store_3pid_invite(
|
||||||
requester,
|
requester,
|
||||||
id_server,
|
id_server,
|
||||||
medium,
|
medium,
|
||||||
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
|
|||||||
id_access_token=id_access_token,
|
id_access_token=id_access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _make_and_store_3pid_invite(
|
||||||
def _make_and_store_3pid_invite(
|
|
||||||
self,
|
self,
|
||||||
requester,
|
requester,
|
||||||
id_server,
|
id_server,
|
||||||
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
|
|||||||
txn_id,
|
txn_id,
|
||||||
id_access_token=None,
|
id_access_token=None,
|
||||||
):
|
):
|
||||||
room_state = yield self.state_handler.get_current_state(room_id)
|
room_state = await self.state_handler.get_current_state(room_id)
|
||||||
|
|
||||||
inviter_display_name = ""
|
inviter_display_name = ""
|
||||||
inviter_avatar_url = ""
|
inviter_avatar_url = ""
|
||||||
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
|
|||||||
public_keys,
|
public_keys,
|
||||||
fallback_public_key,
|
fallback_public_key,
|
||||||
display_name,
|
display_name,
|
||||||
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
|
) = await self.identity_handler.ask_id_server_for_third_party_invite(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
id_server=id_server,
|
id_server=id_server,
|
||||||
medium=medium,
|
medium=medium,
|
||||||
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
|
|||||||
id_access_token=id_access_token,
|
id_access_token=id_access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester,
|
requester,
|
||||||
{
|
{
|
||||||
"type": EventTypes.ThirdPartyInvite,
|
"type": EventTypes.ThirdPartyInvite,
|
||||||
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
|
|
||||||
return complexity["v1"] > max_complexity
|
return complexity["v1"] > max_complexity
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||||
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
|
||||||
"""Implements RoomMemberHandler._remote_join
|
"""Implements RoomMemberHandler._remote_join
|
||||||
"""
|
"""
|
||||||
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
|
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
|
||||||
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
|
|
||||||
if self.hs.config.limit_remote_rooms.enabled:
|
if self.hs.config.limit_remote_rooms.enabled:
|
||||||
# Fetch the room complexity
|
# Fetch the room complexity
|
||||||
too_complex = yield self._is_remote_room_too_complex(
|
too_complex = await self._is_remote_room_too_complex(
|
||||||
room_id, remote_room_hosts
|
room_id, remote_room_hosts
|
||||||
)
|
)
|
||||||
if too_complex is True:
|
if too_complex is True:
|
||||||
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
# join dance for now, since we're kinda implicitly checking
|
# join dance for now, since we're kinda implicitly checking
|
||||||
# that we are allowed to join when we decide whether or not we
|
# that we are allowed to join when we decide whether or not we
|
||||||
# need to do the invite/join dance.
|
# need to do the invite/join dance.
|
||||||
yield defer.ensureDeferred(
|
await self.federation_handler.do_invite_join(
|
||||||
self.federation_handler.do_invite_join(
|
remote_room_hosts, room_id, user.to_string(), content
|
||||||
remote_room_hosts, room_id, user.to_string(), content
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
yield self._user_joined_room(user, room_id)
|
await self._user_joined_room(user, room_id)
|
||||||
|
|
||||||
# Check the room we just joined wasn't too large, if we didn't fetch the
|
# Check the room we just joined wasn't too large, if we didn't fetch the
|
||||||
# complexity of it before.
|
# complexity of it before.
|
||||||
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Check again, but with the local state events
|
# Check again, but with the local state events
|
||||||
too_complex = yield self._is_local_room_too_complex(room_id)
|
too_complex = await self._is_local_room_too_complex(room_id)
|
||||||
|
|
||||||
if too_complex is False:
|
if too_complex is False:
|
||||||
# We're under the limit.
|
# We're under the limit.
|
||||||
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
|
|
||||||
# The room is too large. Leave.
|
# The room is too large. Leave.
|
||||||
requester = types.create_requester(user, None, False, None)
|
requester = types.create_requester(user, None, False, None)
|
||||||
yield self.update_membership(
|
await self.update_membership(
|
||||||
requester=requester, target=user, room_id=room_id, action="leave"
|
requester=requester, target=user, room_id=room_id, action="leave"
|
||||||
)
|
)
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
def _user_joined_room(self, target, room_id):
|
def _user_joined_room(self, target, room_id):
|
||||||
"""Implements RoomMemberHandler._user_joined_room
|
"""Implements RoomMemberHandler._user_joined_room
|
||||||
"""
|
"""
|
||||||
return user_joined_room(self.distributor, target, room_id)
|
return defer.succeed(user_joined_room(self.distributor, target, room_id))
|
||||||
|
|
||||||
def _user_left_room(self, target, room_id):
|
def _user_left_room(self, target, room_id):
|
||||||
"""Implements RoomMemberHandler._user_left_room
|
"""Implements RoomMemberHandler._user_left_room
|
||||||
"""
|
"""
|
||||||
return user_left_room(self.distributor, target, room_id)
|
return defer.succeed(user_left_room(self.distributor, target, room_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def forget(self, user, room_id):
|
def forget(self, user, room_id):
|
||||||
|
@ -27,6 +27,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import types
|
import types
|
||||||
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
@ -287,6 +288,46 @@ class LoggingContext(object):
|
|||||||
return str(self.request)
|
return str(self.request)
|
||||||
return "%s@%x" % (self.name, id(self))
|
return "%s@%x" % (self.name, id(self))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def current_context(cls) -> LoggingContextOrSentinel:
|
||||||
|
"""Get the current logging context from thread local storage
|
||||||
|
|
||||||
|
This exists for backwards compatibility. ``current_context()`` should be
|
||||||
|
called directly.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
LoggingContext: the current logging context
|
||||||
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"synapse.logging.context.LoggingContext.current_context() is deprecated "
|
||||||
|
"in favor of synapse.logging.context.current_context().",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return current_context()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_current_context(
|
||||||
|
cls, context: LoggingContextOrSentinel
|
||||||
|
) -> LoggingContextOrSentinel:
|
||||||
|
"""Set the current logging context in thread local storage
|
||||||
|
|
||||||
|
This exists for backwards compatibility. ``set_current_context()`` should be
|
||||||
|
called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context(LoggingContext): The context to activate.
|
||||||
|
Returns:
|
||||||
|
The context that was previously active
|
||||||
|
"""
|
||||||
|
warnings.warn(
|
||||||
|
"synapse.logging.context.LoggingContext.set_current_context() is deprecated "
|
||||||
|
"in favor of synapse.logging.context.set_current_context().",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return set_current_context(context)
|
||||||
|
|
||||||
def __enter__(self) -> "LoggingContext":
|
def __enter__(self) -> "LoggingContext":
|
||||||
"""Enters this logging context into thread local storage"""
|
"""Enters this logging context into thread local storage"""
|
||||||
old_context = set_current_context(self)
|
old_context = set_current_context(self)
|
||||||
|
@ -273,10 +273,9 @@ class Notifier(object):
|
|||||||
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
|
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _notify_app_services(self, room_stream_id):
|
||||||
def _notify_app_services(self, room_stream_id):
|
|
||||||
try:
|
try:
|
||||||
yield self.appservice_handler.notify_interested_services(room_stream_id)
|
await self.appservice_handler.notify_interested_services(room_stream_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error notifying application services of event")
|
logger.exception("Error notifying application services of event")
|
||||||
|
|
||||||
@ -475,20 +474,18 @@ class Notifier(object):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_room_ids(self, user, explicit_room_id):
|
||||||
def _get_room_ids(self, user, explicit_room_id):
|
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
|
||||||
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
|
|
||||||
if explicit_room_id:
|
if explicit_room_id:
|
||||||
if explicit_room_id in joined_room_ids:
|
if explicit_room_id in joined_room_ids:
|
||||||
return [explicit_room_id], True
|
return [explicit_room_id], True
|
||||||
if (yield self._is_world_readable(explicit_room_id)):
|
if await self._is_world_readable(explicit_room_id):
|
||||||
return [explicit_room_id], False
|
return [explicit_room_id], False
|
||||||
raise AuthError(403, "Non-joined access not allowed")
|
raise AuthError(403, "Non-joined access not allowed")
|
||||||
return joined_room_ids, True
|
return joined_room_ids, True
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _is_world_readable(self, room_id):
|
||||||
def _is_world_readable(self, room_id):
|
state = await self.state_handler.get_current_state(
|
||||||
state = yield self.state_handler.get_current_state(
|
|
||||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||||
)
|
)
|
||||||
if state and "history_visibility" in state.content:
|
if state and "history_visibility" in state.content:
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from inspect import signature
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from six import raise_from
|
from six import raise_from
|
||||||
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
|
|||||||
must call `register` to register the path with the HTTP server.
|
must call `register` to register the path with the HTTP server.
|
||||||
|
|
||||||
Requests can be sent by calling the client returned by `make_client`.
|
Requests can be sent by calling the client returned by `make_client`.
|
||||||
|
Requests are sent to master process by default, but can be sent to other
|
||||||
|
named processes by specifying an `instance_name` keyword argument.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
NAME (str): A name for the endpoint, added to the path as well as used
|
NAME (str): A name for the endpoint, added to the path as well as used
|
||||||
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
|
|||||||
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||||
|
# assert here that sub classes don't try and use the name.
|
||||||
|
assert (
|
||||||
|
"instance_name" not in self.PATH_ARGS
|
||||||
|
), "`instance_name` is a reserved paramater name"
|
||||||
|
assert (
|
||||||
|
"instance_name"
|
||||||
|
not in signature(self.__class__._serialize_payload).parameters
|
||||||
|
), "`instance_name` is a reserved paramater name"
|
||||||
|
|
||||||
assert self.METHOD in ("PUT", "POST", "GET")
|
assert self.METHOD in ("PUT", "POST", "GET")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
|
|||||||
|
|
||||||
@trace(opname="outgoing_replication_request")
|
@trace(opname="outgoing_replication_request")
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_request(**kwargs):
|
def send_request(instance_name="master", **kwargs):
|
||||||
|
# Currently we only support sending requests to master process.
|
||||||
|
if instance_name != "master":
|
||||||
|
raise Exception("Unknown instance")
|
||||||
|
|
||||||
data = yield cls._serialize_payload(**kwargs)
|
data = yield cls._serialize_payload(**kwargs)
|
||||||
|
|
||||||
url_args = [
|
url_args = [
|
||||||
|
@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
|
self._instance_name = hs.get_instance_name()
|
||||||
|
|
||||||
# We pull the streams from the replication steamer (if we try and make
|
# We pull the streams from the replication steamer (if we try and make
|
||||||
# them ourselves we end up in an import loop).
|
# them ourselves we end up in an import loop).
|
||||||
self.streams = hs.get_replication_streamer().get_streams()
|
self.streams = hs.get_replication_streamer().get_streams()
|
||||||
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||||||
upto_token = parse_integer(request, "upto_token", required=True)
|
upto_token = parse_integer(request, "upto_token", required=True)
|
||||||
|
|
||||||
updates, upto_token, limited = await stream.get_updates_since(
|
updates, upto_token, limited = await stream.get_updates_since(
|
||||||
from_token, upto_token
|
self._instance_name, from_token, upto_token
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
|||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def stream_positions(self) -> Dict[str, int]:
|
|
||||||
"""
|
|
||||||
Get the current positions of all the streams this store wants to subscribe to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
map from stream name to the most recent update we have for
|
|
||||||
that stream (ie, the point we want to start replicating from)
|
|
||||||
"""
|
|
||||||
pos = {}
|
|
||||||
if self._cache_id_gen:
|
|
||||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
|
||||||
return pos
|
|
||||||
|
|
||||||
def get_cache_stream_token(self):
|
def get_cache_stream_token(self):
|
||||||
if self._cache_id_gen:
|
if self._cache_id_gen:
|
||||||
return self._cache_id_gen.get_current_token()
|
return self._cache_id_gen.get_current_token()
|
||||||
|
@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
|||||||
def get_max_account_data_stream_id(self):
|
def get_max_account_data_stream_id(self):
|
||||||
return self._account_data_id_gen.get_current_token()
|
return self._account_data_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
|
||||||
position = self._account_data_id_gen.get_current_token()
|
|
||||||
result["user_account_data"] = position
|
|
||||||
result["room_account_data"] = position
|
|
||||||
result["tag_account_data"] = position
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "tag_account_data":
|
if stream_name == "tag_account_data":
|
||||||
self._account_data_id_gen.advance(token)
|
self._account_data_id_gen.advance(token)
|
||||||
|
@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
|||||||
expiry_ms=30 * 60 * 1000,
|
expiry_ms=30 * 60 * 1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
|
||||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "to_device":
|
if stream_name == "to_device":
|
||||||
self._device_inbox_id_gen.advance(token)
|
self._device_inbox_id_gen.advance(token)
|
||||||
|
@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
"DeviceListFederationStreamChangeCache", device_list_max
|
"DeviceListFederationStreamChangeCache", device_list_max
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedDeviceStore, self).stream_positions()
|
|
||||||
# The user signature stream uses the same stream ID generator as the
|
|
||||||
# device list stream, so set them both to the device list ID
|
|
||||||
# generator's current token.
|
|
||||||
current_token = self._device_list_id_gen.get_current_token()
|
|
||||||
result[DeviceListsStream.NAME] = current_token
|
|
||||||
result[UserSignatureStream.NAME] = current_token
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == DeviceListsStream.NAME:
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
|
@ -93,12 +93,6 @@ class SlavedEventStore(
|
|||||||
def get_room_min_stream_ordering(self):
|
def get_room_min_stream_ordering(self):
|
||||||
return self._backfill_id_gen.get_current_token()
|
return self._backfill_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedEventStore, self).stream_positions()
|
|
||||||
result["events"] = self._stream_id_gen.get_current_token()
|
|
||||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "events":
|
if stream_name == "events":
|
||||||
self._stream_id_gen.advance(token)
|
self._stream_id_gen.advance(token)
|
||||||
|
@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
|||||||
def get_group_stream_token(self):
|
def get_group_stream_token(self):
|
||||||
return self._group_updates_id_gen.get_current_token()
|
return self._group_updates_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
|
||||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "groups":
|
if stream_name == "groups":
|
||||||
self._group_updates_id_gen.advance(token)
|
self._group_updates_id_gen.advance(token)
|
||||||
|
@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
|||||||
def get_current_presence_token(self):
|
def get_current_presence_token(self):
|
||||||
return self._presence_id_gen.get_current_token()
|
return self._presence_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPresenceStore, self).stream_positions()
|
|
||||||
|
|
||||||
if self.hs.config.use_presence:
|
|
||||||
position = self._presence_id_gen.get_current_token()
|
|
||||||
result["presence"] = position
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "presence":
|
if stream_name == "presence":
|
||||||
self._presence_id_gen.advance(token)
|
self._presence_id_gen.advance(token)
|
||||||
|
@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
|||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self):
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
|
||||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "push_rules":
|
if stream_name == "push_rules":
|
||||||
self._push_rules_stream_id_gen.advance(token)
|
self._push_rules_stream_id_gen.advance(token)
|
||||||
|
@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
|||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedPusherStore, self).stream_positions()
|
|
||||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_pushers_stream_token(self):
|
def get_pushers_stream_token(self):
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
|
@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
|||||||
def get_max_receipt_stream_id(self):
|
def get_max_receipt_stream_id(self):
|
||||||
return self._receipts_id_gen.get_current_token()
|
return self._receipts_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
|
||||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||||
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||||
|
@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
|||||||
def get_current_public_room_stream_id(self):
|
def get_current_public_room_stream_id(self):
|
||||||
return self._public_room_id_gen.get_current_token()
|
return self._public_room_id_gen.get_current_token()
|
||||||
|
|
||||||
def stream_positions(self):
|
|
||||||
result = super(RoomStore, self).stream_positions()
|
|
||||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
|
||||||
return result
|
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "public_rooms":
|
if stream_name == "public_rooms":
|
||||||
self._public_room_id_gen.advance(token)
|
self._public_room_id_gen.advance(token)
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
|
||||||
@ -86,37 +86,22 @@ class ReplicationDataHandler:
|
|||||||
def __init__(self, store: BaseSlavedStore):
|
def __init__(self, store: BaseSlavedStore):
|
||||||
self.store = store
|
self.store = store
|
||||||
|
|
||||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
async def on_rdata(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||||
|
):
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||||
handle more.
|
handle more.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_name (str): name of the replication stream for this batch of rows
|
stream_name: name of the replication stream for this batch of rows
|
||||||
token (int): stream token for this batch of rows
|
instance_name: the instance that wrote the rows.
|
||||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
token: stream token for this batch of rows
|
||||||
Stream.parse_row.
|
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
self.store.process_replication_rows(stream_name, token, rows)
|
self.store.process_replication_rows(stream_name, token, rows)
|
||||||
|
|
||||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
|
||||||
"""Called when a new connection has been established and we need to
|
|
||||||
subscribe to streams.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
map from stream name to the most recent update we have for
|
|
||||||
that stream (ie, the point we want to start replicating from)
|
|
||||||
"""
|
|
||||||
args = self.store.stream_positions()
|
|
||||||
user_account_data = args.pop("user_account_data", None)
|
|
||||||
room_account_data = args.pop("room_account_data", None)
|
|
||||||
if user_account_data:
|
|
||||||
args["account_data"] = user_account_data
|
|
||||||
elif room_account_data:
|
|
||||||
args["account_data"] = room_account_data
|
|
||||||
return args
|
|
||||||
|
|
||||||
async def on_position(self, stream_name: str, token: int):
|
async def on_position(self, stream_name: str, token: int):
|
||||||
self.store.process_replication_rows(stream_name, token, [])
|
self.store.process_replication_rows(stream_name, token, [])
|
||||||
|
|
||||||
|
@ -278,19 +278,24 @@ class ReplicationCommandHandler:
|
|||||||
# Check if this is the last of a batch of updates
|
# Check if this is the last of a batch of updates
|
||||||
rows = self._pending_batches.pop(stream_name, [])
|
rows = self._pending_batches.pop(stream_name, [])
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
await self.on_rdata(stream_name, cmd.token, rows)
|
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
|
||||||
|
|
||||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
async def on_rdata(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||||
|
):
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_name: name of the replication stream for this batch of rows
|
stream_name: name of the replication stream for this batch of rows
|
||||||
|
instance_name: the instance that wrote the rows.
|
||||||
token: stream token for this batch of rows
|
token: stream token for this batch of rows
|
||||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||||
Stream.parse_row.
|
Stream.parse_row.
|
||||||
"""
|
"""
|
||||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
await self._replication_data_handler.on_rdata(
|
||||||
|
stream_name, instance_name, token, rows
|
||||||
|
)
|
||||||
|
|
||||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
@ -314,15 +319,7 @@ class ReplicationCommandHandler:
|
|||||||
self._pending_batches.pop(cmd.stream_name, [])
|
self._pending_batches.pop(cmd.stream_name, [])
|
||||||
|
|
||||||
# Find where we previously streamed up to.
|
# Find where we previously streamed up to.
|
||||||
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
current_token = stream.current_token()
|
||||||
cmd.stream_name
|
|
||||||
)
|
|
||||||
if current_token is None:
|
|
||||||
logger.warning(
|
|
||||||
"Got POSITION for stream we're not subscribed to: %s",
|
|
||||||
cmd.stream_name,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If the position token matches our current token then we're up to
|
# If the position token matches our current token then we're up to
|
||||||
# date and there's nothing to do. Otherwise, fetch all updates
|
# date and there's nothing to do. Otherwise, fetch all updates
|
||||||
@ -333,7 +330,9 @@ class ReplicationCommandHandler:
|
|||||||
updates,
|
updates,
|
||||||
current_token,
|
current_token,
|
||||||
missing_updates,
|
missing_updates,
|
||||||
) = await stream.get_updates_since(current_token, cmd.token)
|
) = await stream.get_updates_since(
|
||||||
|
cmd.instance_name, current_token, cmd.token
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: add some tests for this
|
# TODO: add some tests for this
|
||||||
|
|
||||||
@ -342,7 +341,10 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
for token, rows in _batch_updates(updates):
|
for token, rows in _batch_updates(updates):
|
||||||
await self.on_rdata(
|
await self.on_rdata(
|
||||||
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
|
cmd.stream_name,
|
||||||
|
cmd.instance_name,
|
||||||
|
token,
|
||||||
|
[stream.parse_row(row) for row in rows],
|
||||||
)
|
)
|
||||||
|
|
||||||
# We've now caught up to position sent to us, notify handler.
|
# We've now caught up to position sent to us, notify handler.
|
||||||
|
@ -61,6 +61,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self):
|
||||||
|
super().connectionMade()
|
||||||
logger.info("Connected to redis instance")
|
logger.info("Connected to redis instance")
|
||||||
self.subscribe(self.stream_name)
|
self.subscribe(self.stream_name)
|
||||||
self.send_command(ReplicateCommand())
|
self.send_command(ReplicateCommand())
|
||||||
@ -119,6 +120,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
|||||||
logger.warning("Unhandled command: %r", cmd)
|
logger.warning("Unhandled command: %r", cmd)
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
|
super().connectionLost(reason)
|
||||||
logger.info("Lost connection to redis instance")
|
logger.info("Lost connection to redis instance")
|
||||||
self.handler.lost_connection(self)
|
self.handler.lost_connection(self)
|
||||||
|
|
||||||
@ -189,5 +191,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
|||||||
p.handler = self.handler
|
p.handler = self.handler
|
||||||
p.outbound_redis_connection = self.outbound_redis_connection
|
p.outbound_redis_connection = self.outbound_redis_connection
|
||||||
p.stream_name = self.stream_name
|
p.stream_name = self.stream_name
|
||||||
|
p.password = self.password
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Awaitable, Callable, Iterable, List, Optional, Tuple
|
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
@ -53,6 +53,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||||||
#
|
#
|
||||||
# The arguments are:
|
# The arguments are:
|
||||||
#
|
#
|
||||||
|
# * instance_name: the writer of the stream
|
||||||
# * from_token: the previous stream token: the starting point for fetching the
|
# * from_token: the previous stream token: the starting point for fetching the
|
||||||
# updates
|
# updates
|
||||||
# * to_token: the new stream token: the point to get updates up to
|
# * to_token: the new stream token: the point to get updates up to
|
||||||
@ -62,7 +63,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||||||
# If there are more updates available, it should set `limited` in the result, and
|
# If there are more updates available, it should set `limited` in the result, and
|
||||||
# it will be called again to get the next batch.
|
# it will be called again to get the next batch.
|
||||||
#
|
#
|
||||||
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
|
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||||
|
|
||||||
|
|
||||||
class Stream(object):
|
class Stream(object):
|
||||||
@ -93,6 +94,7 @@ class Stream(object):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
local_instance_name: str,
|
||||||
current_token_function: Callable[[], Token],
|
current_token_function: Callable[[], Token],
|
||||||
update_function: UpdateFunction,
|
update_function: UpdateFunction,
|
||||||
):
|
):
|
||||||
@ -108,9 +110,11 @@ class Stream(object):
|
|||||||
stream tokens. See the UpdateFunction type definition for more info.
|
stream tokens. See the UpdateFunction type definition for more info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
local_instance_name: The instance name of the current process
|
||||||
current_token_function: callback to get the current token, as above
|
current_token_function: callback to get the current token, as above
|
||||||
update_function: callback go get stream updates, as above
|
update_function: callback go get stream updates, as above
|
||||||
"""
|
"""
|
||||||
|
self.local_instance_name = local_instance_name
|
||||||
self.current_token = current_token_function
|
self.current_token = current_token_function
|
||||||
self.update_function = update_function
|
self.update_function = update_function
|
||||||
|
|
||||||
@ -135,14 +139,14 @@ class Stream(object):
|
|||||||
"""
|
"""
|
||||||
current_token = self.current_token()
|
current_token = self.current_token()
|
||||||
updates, current_token, limited = await self.get_updates_since(
|
updates, current_token, limited = await self.get_updates_since(
|
||||||
self.last_token, current_token
|
self.local_instance_name, self.last_token, current_token
|
||||||
)
|
)
|
||||||
self.last_token = current_token
|
self.last_token = current_token
|
||||||
|
|
||||||
return updates, current_token, limited
|
return updates, current_token, limited
|
||||||
|
|
||||||
async def get_updates_since(
|
async def get_updates_since(
|
||||||
self, from_token: Token, upto_token: Token
|
self, instance_name: str, from_token: Token, upto_token: Token
|
||||||
) -> StreamUpdateResult:
|
) -> StreamUpdateResult:
|
||||||
"""Like get_updates except allows specifying from when we should
|
"""Like get_updates except allows specifying from when we should
|
||||||
stream updates
|
stream updates
|
||||||
@ -160,19 +164,19 @@ class Stream(object):
|
|||||||
return [], upto_token, False
|
return [], upto_token, False
|
||||||
|
|
||||||
updates, upto_token, limited = await self.update_function(
|
updates, upto_token, limited = await self.update_function(
|
||||||
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||||
)
|
)
|
||||||
return updates, upto_token, limited
|
return updates, upto_token, limited
|
||||||
|
|
||||||
|
|
||||||
def db_query_to_update_function(
|
def db_query_to_update_function(
|
||||||
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
|
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||||
) -> UpdateFunction:
|
) -> UpdateFunction:
|
||||||
"""Wraps a db query function which returns a list of rows to make it
|
"""Wraps a db query function which returns a list of rows to make it
|
||||||
suitable for use as an `update_function` for the Stream class
|
suitable for use as an `update_function` for the Stream class
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def update_function(from_token, upto_token, limit):
|
async def update_function(instance_name, from_token, upto_token, limit):
|
||||||
rows = await query_function(from_token, upto_token, limit)
|
rows = await query_function(from_token, upto_token, limit)
|
||||||
updates = [(row[0], row[1:]) for row in rows]
|
updates = [(row[0], row[1:]) for row in rows]
|
||||||
limited = False
|
limited = False
|
||||||
@ -193,10 +197,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
|||||||
client = ReplicationGetStreamUpdates.make_client(hs)
|
client = ReplicationGetStreamUpdates.make_client(hs)
|
||||||
|
|
||||||
async def update_function(
|
async def update_function(
|
||||||
from_token: int, upto_token: int, limit: int
|
instance_name: str, from_token: int, upto_token: int, limit: int
|
||||||
) -> StreamUpdateResult:
|
) -> StreamUpdateResult:
|
||||||
result = await client(
|
result = await client(
|
||||||
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
|
instance_name=instance_name,
|
||||||
|
stream_name=stream_name,
|
||||||
|
from_token=from_token,
|
||||||
|
upto_token=upto_token,
|
||||||
)
|
)
|
||||||
return result["updates"], result["upto_token"], result["limited"]
|
return result["updates"], result["upto_token"], result["limited"]
|
||||||
|
|
||||||
@ -226,6 +233,7 @@ class BackfillStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_current_backfill_token,
|
store.get_current_backfill_token,
|
||||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||||
)
|
)
|
||||||
@ -261,7 +269,9 @@ class PresenceStream(Stream):
|
|||||||
# Query master process
|
# Query master process
|
||||||
update_function = make_http_update_function(hs, self.NAME)
|
update_function = make_http_update_function(hs, self.NAME)
|
||||||
|
|
||||||
super().__init__(store.get_current_presence_token, update_function)
|
super().__init__(
|
||||||
|
hs.get_instance_name(), store.get_current_presence_token, update_function
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TypingStream(Stream):
|
class TypingStream(Stream):
|
||||||
@ -284,7 +294,9 @@ class TypingStream(Stream):
|
|||||||
# Query master process
|
# Query master process
|
||||||
update_function = make_http_update_function(hs, self.NAME)
|
update_function = make_http_update_function(hs, self.NAME)
|
||||||
|
|
||||||
super().__init__(typing_handler.get_current_token, update_function)
|
super().__init__(
|
||||||
|
hs.get_instance_name(), typing_handler.get_current_token, update_function
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReceiptsStream(Stream):
|
class ReceiptsStream(Stream):
|
||||||
@ -305,6 +317,7 @@ class ReceiptsStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_max_receipt_stream_id,
|
store.get_max_receipt_stream_id,
|
||||||
db_query_to_update_function(store.get_all_updated_receipts),
|
db_query_to_update_function(store.get_all_updated_receipts),
|
||||||
)
|
)
|
||||||
@ -322,14 +335,16 @@ class PushRulesStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
super(PushRulesStream, self).__init__(
|
super(PushRulesStream, self).__init__(
|
||||||
self._current_token, self._update_function
|
hs.get_instance_name(), self._current_token, self._update_function
|
||||||
)
|
)
|
||||||
|
|
||||||
def _current_token(self) -> int:
|
def _current_token(self) -> int:
|
||||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||||
return push_rules_token
|
return push_rules_token
|
||||||
|
|
||||||
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
|
async def _update_function(
|
||||||
|
self, instance_name: str, from_token: Token, to_token: Token, limit: int
|
||||||
|
):
|
||||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||||
|
|
||||||
limited = False
|
limited = False
|
||||||
@ -356,6 +371,7 @@ class PushersStream(Stream):
|
|||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_pushers_stream_token,
|
store.get_pushers_stream_token,
|
||||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||||
)
|
)
|
||||||
@ -387,6 +403,7 @@ class CachesStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_cache_stream_token,
|
store.get_cache_stream_token,
|
||||||
db_query_to_update_function(store.get_all_updated_caches),
|
db_query_to_update_function(store.get_all_updated_caches),
|
||||||
)
|
)
|
||||||
@ -412,6 +429,7 @@ class PublicRoomsStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_current_public_room_stream_id,
|
store.get_current_public_room_stream_id,
|
||||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||||
)
|
)
|
||||||
@ -432,6 +450,7 @@ class DeviceListsStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_device_stream_token,
|
store.get_device_stream_token,
|
||||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||||
)
|
)
|
||||||
@ -449,6 +468,7 @@ class ToDeviceStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_to_device_stream_token,
|
store.get_to_device_stream_token,
|
||||||
db_query_to_update_function(store.get_all_new_device_messages),
|
db_query_to_update_function(store.get_all_new_device_messages),
|
||||||
)
|
)
|
||||||
@ -468,6 +488,7 @@ class TagAccountDataStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_max_account_data_stream_id,
|
store.get_max_account_data_stream_id,
|
||||||
db_query_to_update_function(store.get_all_updated_tags),
|
db_query_to_update_function(store.get_all_updated_tags),
|
||||||
)
|
)
|
||||||
@ -487,6 +508,7 @@ class AccountDataStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
self.store.get_max_account_data_stream_id,
|
self.store.get_max_account_data_stream_id,
|
||||||
db_query_to_update_function(self._update_function),
|
db_query_to_update_function(self._update_function),
|
||||||
)
|
)
|
||||||
@ -517,6 +539,7 @@ class GroupServerStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_group_stream_token,
|
store.get_group_stream_token,
|
||||||
db_query_to_update_function(store.get_all_groups_changes),
|
db_query_to_update_function(store.get_all_groups_changes),
|
||||||
)
|
)
|
||||||
@ -534,6 +557,7 @@ class UserSignatureStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
store = hs.get_datastore()
|
store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
store.get_device_stream_token,
|
store.get_device_stream_token,
|
||||||
db_query_to_update_function(
|
db_query_to_update_function(
|
||||||
store.get_all_user_signature_changes_for_remotes
|
store.get_all_user_signature_changes_for_remotes
|
||||||
|
@ -118,11 +118,17 @@ class EventsStream(Stream):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
super().__init__(
|
super().__init__(
|
||||||
self._store.get_current_events_token, self._update_function,
|
hs.get_instance_name(),
|
||||||
|
self._store.get_current_events_token,
|
||||||
|
self._update_function,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _update_function(
|
async def _update_function(
|
||||||
self, from_token: Token, current_token: Token, target_row_count: int
|
self,
|
||||||
|
instance_name: str,
|
||||||
|
from_token: Token,
|
||||||
|
current_token: Token,
|
||||||
|
target_row_count: int,
|
||||||
) -> StreamUpdateResult:
|
) -> StreamUpdateResult:
|
||||||
|
|
||||||
# the events stream merges together three separate sources:
|
# the events stream merges together three separate sources:
|
||||||
|
@ -48,8 +48,8 @@ class FederationStream(Stream):
|
|||||||
current_token = lambda: 0
|
current_token = lambda: 0
|
||||||
update_function = self._stub_update_function
|
update_function = self._stub_update_function
|
||||||
|
|
||||||
super().__init__(current_token, update_function)
|
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _stub_update_function(from_token, upto_token, limit):
|
async def _stub_update_function(instance_name, from_token, upto_token, limit):
|
||||||
return [], upto_token, False
|
return [], upto_token, False
|
||||||
|
@ -16,8 +16,6 @@ import logging
|
|||||||
|
|
||||||
from six import iteritems, string_types
|
from six import iteritems, string_types
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.urls import ConsentURIBuilder
|
from synapse.api.urls import ConsentURIBuilder
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
|
|||||||
|
|
||||||
self._consent_uri_builder = ConsentURIBuilder(hs.config)
|
self._consent_uri_builder = ConsentURIBuilder(hs.config)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def maybe_send_server_notice_to_user(self, user_id):
|
||||||
def maybe_send_server_notice_to_user(self, user_id):
|
|
||||||
"""Check if we need to send a notice to this user, and does so if so
|
"""Check if we need to send a notice to this user, and does so if so
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
|
|||||||
return
|
return
|
||||||
self._users_in_progress.add(user_id)
|
self._users_in_progress.add(user_id)
|
||||||
try:
|
try:
|
||||||
u = yield self._store.get_user_by_id(user_id)
|
u = await self._store.get_user_by_id(user_id)
|
||||||
|
|
||||||
if u["is_guest"] and not self._send_to_guests:
|
if u["is_guest"] and not self._send_to_guests:
|
||||||
# don't send to guests
|
# don't send to guests
|
||||||
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
|
|||||||
content = copy_with_str_subst(
|
content = copy_with_str_subst(
|
||||||
self._server_notice_content, {"consent_uri": consent_uri}
|
self._server_notice_content, {"consent_uri": consent_uri}
|
||||||
)
|
)
|
||||||
yield self._server_notices_manager.send_notice(user_id, content)
|
await self._server_notices_manager.send_notice(user_id, content)
|
||||||
yield self._store.user_set_consent_server_notice_sent(
|
await self._store.user_set_consent_server_notice_sent(
|
||||||
user_id, self._current_consent_version
|
user_id, self._current_consent_version
|
||||||
)
|
)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
|
@ -16,8 +16,6 @@ import logging
|
|||||||
|
|
||||||
from six import iteritems
|
from six import iteritems
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import (
|
from synapse.api.constants import (
|
||||||
EventTypes,
|
EventTypes,
|
||||||
LimitBlockingTypes,
|
LimitBlockingTypes,
|
||||||
@ -50,8 +48,7 @@ class ResourceLimitsServerNotices(object):
|
|||||||
|
|
||||||
self._notifier = hs.get_notifier()
|
self._notifier = hs.get_notifier()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def maybe_send_server_notice_to_user(self, user_id):
|
||||||
def maybe_send_server_notice_to_user(self, user_id):
|
|
||||||
"""Check if we need to send a notice to this user, this will be true in
|
"""Check if we need to send a notice to this user, this will be true in
|
||||||
two cases.
|
two cases.
|
||||||
1. The server has reached its limit does not reflect this
|
1. The server has reached its limit does not reflect this
|
||||||
@ -74,13 +71,13 @@ class ResourceLimitsServerNotices(object):
|
|||||||
# Don't try and send server notices unless they've been enabled
|
# Don't try and send server notices unless they've been enabled
|
||||||
return
|
return
|
||||||
|
|
||||||
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
|
timestamp = await self._store.user_last_seen_monthly_active(user_id)
|
||||||
if timestamp is None:
|
if timestamp is None:
|
||||||
# This user will be blocked from receiving the notice anyway.
|
# This user will be blocked from receiving the notice anyway.
|
||||||
# In practice, not sure we can ever get here
|
# In practice, not sure we can ever get here
|
||||||
return
|
return
|
||||||
|
|
||||||
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
|
room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
|
||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -88,10 +85,10 @@ class ResourceLimitsServerNotices(object):
|
|||||||
logger.warning("Failed to get server notices room")
|
logger.warning("Failed to get server notices room")
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self._check_and_set_tags(user_id, room_id)
|
await self._check_and_set_tags(user_id, room_id)
|
||||||
|
|
||||||
# Determine current state of room
|
# Determine current state of room
|
||||||
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
|
currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
|
||||||
|
|
||||||
limit_msg = None
|
limit_msg = None
|
||||||
limit_type = None
|
limit_type = None
|
||||||
@ -99,7 +96,7 @@ class ResourceLimitsServerNotices(object):
|
|||||||
# Normally should always pass in user_id to check_auth_blocking
|
# Normally should always pass in user_id to check_auth_blocking
|
||||||
# if you have it, but in this case are checking what would happen
|
# if you have it, but in this case are checking what would happen
|
||||||
# to other users if they were to arrive.
|
# to other users if they were to arrive.
|
||||||
yield self._auth.check_auth_blocking()
|
await self._auth.check_auth_blocking()
|
||||||
except ResourceLimitError as e:
|
except ResourceLimitError as e:
|
||||||
limit_msg = e.msg
|
limit_msg = e.msg
|
||||||
limit_type = e.limit_type
|
limit_type = e.limit_type
|
||||||
@ -112,22 +109,21 @@ class ResourceLimitsServerNotices(object):
|
|||||||
# We have hit the MAU limit, but MAU alerting is disabled:
|
# We have hit the MAU limit, but MAU alerting is disabled:
|
||||||
# reset room if necessary and return
|
# reset room if necessary and return
|
||||||
if currently_blocked:
|
if currently_blocked:
|
||||||
self._remove_limit_block_notification(user_id, ref_events)
|
await self._remove_limit_block_notification(user_id, ref_events)
|
||||||
return
|
return
|
||||||
|
|
||||||
if currently_blocked and not limit_msg:
|
if currently_blocked and not limit_msg:
|
||||||
# Room is notifying of a block, when it ought not to be.
|
# Room is notifying of a block, when it ought not to be.
|
||||||
yield self._remove_limit_block_notification(user_id, ref_events)
|
await self._remove_limit_block_notification(user_id, ref_events)
|
||||||
elif not currently_blocked and limit_msg:
|
elif not currently_blocked and limit_msg:
|
||||||
# Room is not notifying of a block, when it ought to be.
|
# Room is not notifying of a block, when it ought to be.
|
||||||
yield self._apply_limit_block_notification(
|
await self._apply_limit_block_notification(
|
||||||
user_id, limit_msg, limit_type
|
user_id, limit_msg, limit_type
|
||||||
)
|
)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
logger.error("Error sending resource limits server notice: %s", e)
|
logger.error("Error sending resource limits server notice: %s", e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _remove_limit_block_notification(self, user_id, ref_events):
|
||||||
def _remove_limit_block_notification(self, user_id, ref_events):
|
|
||||||
"""Utility method to remove limit block notifications from the server
|
"""Utility method to remove limit block notifications from the server
|
||||||
notices room.
|
notices room.
|
||||||
|
|
||||||
@ -137,12 +133,13 @@ class ResourceLimitsServerNotices(object):
|
|||||||
limit blocking and need to be preserved.
|
limit blocking and need to be preserved.
|
||||||
"""
|
"""
|
||||||
content = {"pinned": ref_events}
|
content = {"pinned": ref_events}
|
||||||
yield self._server_notices_manager.send_notice(
|
await self._server_notices_manager.send_notice(
|
||||||
user_id, content, EventTypes.Pinned, ""
|
user_id, content, EventTypes.Pinned, ""
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _apply_limit_block_notification(
|
||||||
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
|
self, user_id, event_body, event_limit_type
|
||||||
|
):
|
||||||
"""Utility method to apply limit block notifications in the server
|
"""Utility method to apply limit block notifications in the server
|
||||||
notices room.
|
notices room.
|
||||||
|
|
||||||
@ -159,17 +156,16 @@ class ResourceLimitsServerNotices(object):
|
|||||||
"admin_contact": self._config.admin_contact,
|
"admin_contact": self._config.admin_contact,
|
||||||
"limit_type": event_limit_type,
|
"limit_type": event_limit_type,
|
||||||
}
|
}
|
||||||
event = yield self._server_notices_manager.send_notice(
|
event = await self._server_notices_manager.send_notice(
|
||||||
user_id, content, EventTypes.Message
|
user_id, content, EventTypes.Message
|
||||||
)
|
)
|
||||||
|
|
||||||
content = {"pinned": [event.event_id]}
|
content = {"pinned": [event.event_id]}
|
||||||
yield self._server_notices_manager.send_notice(
|
await self._server_notices_manager.send_notice(
|
||||||
user_id, content, EventTypes.Pinned, ""
|
user_id, content, EventTypes.Pinned, ""
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _check_and_set_tags(self, user_id, room_id):
|
||||||
def _check_and_set_tags(self, user_id, room_id):
|
|
||||||
"""
|
"""
|
||||||
Since server notices rooms were originally not with tags,
|
Since server notices rooms were originally not with tags,
|
||||||
important to check that tags have been set correctly
|
important to check that tags have been set correctly
|
||||||
@ -177,20 +173,19 @@ class ResourceLimitsServerNotices(object):
|
|||||||
user_id(str): the user in question
|
user_id(str): the user in question
|
||||||
room_id(str): the server notices room for that user
|
room_id(str): the server notices room for that user
|
||||||
"""
|
"""
|
||||||
tags = yield self._store.get_tags_for_room(user_id, room_id)
|
tags = await self._store.get_tags_for_room(user_id, room_id)
|
||||||
need_to_set_tag = True
|
need_to_set_tag = True
|
||||||
if tags:
|
if tags:
|
||||||
if SERVER_NOTICE_ROOM_TAG in tags:
|
if SERVER_NOTICE_ROOM_TAG in tags:
|
||||||
# tag already present, nothing to do here
|
# tag already present, nothing to do here
|
||||||
need_to_set_tag = False
|
need_to_set_tag = False
|
||||||
if need_to_set_tag:
|
if need_to_set_tag:
|
||||||
max_id = yield self._store.add_tag_to_room(
|
max_id = await self._store.add_tag_to_room(
|
||||||
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
||||||
)
|
)
|
||||||
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _is_room_currently_blocked(self, room_id):
|
||||||
def _is_room_currently_blocked(self, room_id):
|
|
||||||
"""
|
"""
|
||||||
Determines if the room is currently blocked
|
Determines if the room is currently blocked
|
||||||
|
|
||||||
@ -198,7 +193,7 @@ class ResourceLimitsServerNotices(object):
|
|||||||
room_id(str): The room id of the server notices room
|
room_id(str): The room id of the server notices room
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
Deferred[Tuple[bool, List]]:
|
||||||
bool: Is the room currently blocked
|
bool: Is the room currently blocked
|
||||||
list: The list of pinned events that are unrelated to limit blocking
|
list: The list of pinned events that are unrelated to limit blocking
|
||||||
This list can be used as a convenience in the case where the block
|
This list can be used as a convenience in the case where the block
|
||||||
@ -208,7 +203,7 @@ class ResourceLimitsServerNotices(object):
|
|||||||
currently_blocked = False
|
currently_blocked = False
|
||||||
pinned_state_event = None
|
pinned_state_event = None
|
||||||
try:
|
try:
|
||||||
pinned_state_event = yield self._state.get_current_state(
|
pinned_state_event = await self._state.get_current_state(
|
||||||
room_id, event_type=EventTypes.Pinned
|
room_id, event_type=EventTypes.Pinned
|
||||||
)
|
)
|
||||||
except AuthError:
|
except AuthError:
|
||||||
@ -219,7 +214,7 @@ class ResourceLimitsServerNotices(object):
|
|||||||
if pinned_state_event is not None:
|
if pinned_state_event is not None:
|
||||||
referenced_events = list(pinned_state_event.content.get("pinned", []))
|
referenced_events = list(pinned_state_event.content.get("pinned", []))
|
||||||
|
|
||||||
events = yield self._store.get_events(referenced_events)
|
events = await self._store.get_events(referenced_events)
|
||||||
for event_id, event in iteritems(events):
|
for event_id, event in iteritems(events):
|
||||||
if event.type != EventTypes.Message:
|
if event.type != EventTypes.Message:
|
||||||
continue
|
continue
|
||||||
|
@ -14,11 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
|
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
|
|||||||
"""
|
"""
|
||||||
return self._config.server_notices_mxid is not None
|
return self._config.server_notices_mxid is not None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def send_notice(
|
||||||
def send_notice(
|
|
||||||
self, user_id, event_content, type=EventTypes.Message, state_key=None
|
self, user_id, event_content, type=EventTypes.Message, state_key=None
|
||||||
):
|
):
|
||||||
"""Send a notice to the given user
|
"""Send a notice to the given user
|
||||||
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred[FrozenEvent]
|
Deferred[FrozenEvent]
|
||||||
"""
|
"""
|
||||||
room_id = yield self.get_or_create_notice_room_for_user(user_id)
|
room_id = await self.get_or_create_notice_room_for_user(user_id)
|
||||||
yield self.maybe_invite_user_to_room(user_id, room_id)
|
await self.maybe_invite_user_to_room(user_id, room_id)
|
||||||
|
|
||||||
system_mxid = self._config.server_notices_mxid
|
system_mxid = self._config.server_notices_mxid
|
||||||
requester = create_requester(system_mxid)
|
requester = create_requester(system_mxid)
|
||||||
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
|
|||||||
if state_key is not None:
|
if state_key is not None:
|
||||||
event_dict["state_key"] = state_key
|
event_dict["state_key"] = state_key
|
||||||
|
|
||||||
res = yield self._event_creation_handler.create_and_send_nonmember_event(
|
res = await self._event_creation_handler.create_and_send_nonmember_event(
|
||||||
requester, event_dict, ratelimit=False
|
requester, event_dict, ratelimit=False
|
||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cached()
|
||||||
def get_or_create_notice_room_for_user(self, user_id):
|
async def get_or_create_notice_room_for_user(self, user_id):
|
||||||
"""Get the room for notices for a given user
|
"""Get the room for notices for a given user
|
||||||
|
|
||||||
If we have not yet created a notice room for this user, create it, but don't
|
If we have not yet created a notice room for this user, create it, but don't
|
||||||
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
|
|||||||
|
|
||||||
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
|
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
|
||||||
|
|
||||||
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
|
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
|
||||||
user_id, [Membership.INVITE, Membership.JOIN]
|
user_id, [Membership.INVITE, Membership.JOIN]
|
||||||
)
|
)
|
||||||
for room in rooms:
|
for room in rooms:
|
||||||
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
|
|||||||
# be joined. This is kinda deliberate, in that if somebody somehow
|
# be joined. This is kinda deliberate, in that if somebody somehow
|
||||||
# manages to invite the system user to a room, that doesn't make it
|
# manages to invite the system user to a room, that doesn't make it
|
||||||
# the server notices room.
|
# the server notices room.
|
||||||
user_ids = yield self._store.get_users_in_room(room.room_id)
|
user_ids = await self._store.get_users_in_room(room.room_id)
|
||||||
if self.server_notices_mxid in user_ids:
|
if self.server_notices_mxid in user_ids:
|
||||||
# we found a room which our user shares with the system notice
|
# we found a room which our user shares with the system notice
|
||||||
# user
|
# user
|
||||||
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
requester = create_requester(self.server_notices_mxid)
|
requester = create_requester(self.server_notices_mxid)
|
||||||
info = yield self._room_creation_handler.create_room(
|
info = await self._room_creation_handler.create_room(
|
||||||
requester,
|
requester,
|
||||||
config={
|
config={
|
||||||
"preset": RoomCreationPreset.PRIVATE_CHAT,
|
"preset": RoomCreationPreset.PRIVATE_CHAT,
|
||||||
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
|
|||||||
)
|
)
|
||||||
room_id = info["room_id"]
|
room_id = info["room_id"]
|
||||||
|
|
||||||
max_id = yield self._store.add_tag_to_room(
|
max_id = await self._store.add_tag_to_room(
|
||||||
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
||||||
)
|
)
|
||||||
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||||
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
|
|||||||
logger.info("Created server notices room %s for %s", room_id, user_id)
|
logger.info("Created server notices room %s for %s", room_id, user_id)
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
|
||||||
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
|
|
||||||
"""Invite the given user to the given server room, unless the user has already
|
"""Invite the given user to the given server room, unless the user has already
|
||||||
joined or been invited to it.
|
joined or been invited to it.
|
||||||
|
|
||||||
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
|
|||||||
|
|
||||||
# Check whether the user has already joined or been invited to this room. If
|
# Check whether the user has already joined or been invited to this room. If
|
||||||
# that's the case, there is no need to re-invite them.
|
# that's the case, there is no need to re-invite them.
|
||||||
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
|
joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
|
||||||
user_id, [Membership.INVITE, Membership.JOIN]
|
user_id, [Membership.INVITE, Membership.JOIN]
|
||||||
)
|
)
|
||||||
for room in joined_rooms:
|
for room in joined_rooms:
|
||||||
if room.room_id == room_id:
|
if room.room_id == room_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self._room_member_handler.update_membership(
|
await self._room_member_handler.update_membership(
|
||||||
requester=requester,
|
requester=requester,
|
||||||
target=UserID.from_string(user_id),
|
target=UserID.from_string(user_id),
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -12,8 +12,6 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.server_notices.consent_server_notices import ConsentServerNotices
|
from synapse.server_notices.consent_server_notices import ConsentServerNotices
|
||||||
from synapse.server_notices.resource_limits_server_notices import (
|
from synapse.server_notices.resource_limits_server_notices import (
|
||||||
ResourceLimitsServerNotices,
|
ResourceLimitsServerNotices,
|
||||||
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
|
|||||||
ResourceLimitsServerNotices(hs),
|
ResourceLimitsServerNotices(hs),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_user_syncing(self, user_id):
|
||||||
def on_user_syncing(self, user_id):
|
|
||||||
"""Called when the user performs a sync operation.
|
"""Called when the user performs a sync operation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): mxid of user who synced
|
user_id (str): mxid of user who synced
|
||||||
"""
|
"""
|
||||||
for sn in self._server_notices:
|
for sn in self._server_notices:
|
||||||
yield sn.maybe_send_server_notice_to_user(user_id)
|
await sn.maybe_send_server_notice_to_user(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_user_ip(self, user_id):
|
||||||
def on_user_ip(self, user_id):
|
|
||||||
"""Called on the master when a worker process saw a client request.
|
"""Called on the master when a worker process saw a client request.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
|
|||||||
# we check for notices to send to the user in on_user_ip as well as
|
# we check for notices to send to the user in on_user_ip as well as
|
||||||
# in on_user_syncing
|
# in on_user_syncing
|
||||||
for sn in self._server_notices:
|
for sn in self._server_notices:
|
||||||
yield sn.maybe_send_server_notice_to_user(user_id)
|
await sn.maybe_send_server_notice_to_user(user_id)
|
||||||
|
@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
desc="delete_account_validity_for_user",
|
desc="delete_account_validity_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def is_server_admin(self, user):
|
||||||
def is_server_admin(self, user):
|
|
||||||
"""Determines if a user is an admin of this homeserver.
|
"""Determines if a user is an admin of this homeserver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
Returns (bool):
|
Returns (bool):
|
||||||
true iff the user is a server admin, false otherwise.
|
true iff the user is a server admin, false otherwise.
|
||||||
"""
|
"""
|
||||||
res = yield self.db.simple_select_one_onecol(
|
res = await self.db.simple_select_one_onecol(
|
||||||
table="users",
|
table="users",
|
||||||
keyvalues={"name": user.to_string()},
|
keyvalues={"name": user.to_string()},
|
||||||
retcol="admin",
|
retcol="admin",
|
||||||
|
@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
|
|||||||
'populate_stats_cleanup'
|
'populate_stats_cleanup'
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- this relies on current_state_events.membership having been populated, so add
|
||||||
|
-- a dependency on current_state_events_membership.
|
||||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||||
('populate_stats_process_rooms', '{}', '');
|
('populate_stats_process_rooms', '{}', 'current_state_events_membership');
|
||||||
|
|
||||||
|
-- this also relies on current_state_events.membership having been populated, but
|
||||||
|
-- we get that as a side-effect of depending on populate_stats_process_rooms.
|
||||||
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
|
||||||
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
|
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
|
||||||
|
|
||||||
|
100
tests/events/test_snapshot.py
Normal file
100
tests/events/test_snapshot.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client.v1 import login, room
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
from tests.test_utils.event_injection import create_event
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventContext(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.storage = hs.get_storage()
|
||||||
|
|
||||||
|
self.user_id = self.register_user("u1", "pass")
|
||||||
|
self.user_tok = self.login("u1", "pass")
|
||||||
|
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||||
|
|
||||||
|
def test_serialize_deserialize_msg(self):
|
||||||
|
"""Test that an EventContext for a message event is the same after
|
||||||
|
serialize/deserialize.
|
||||||
|
"""
|
||||||
|
|
||||||
|
event, context = create_event(
|
||||||
|
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
def test_serialize_deserialize_state_no_prev(self):
|
||||||
|
"""Test that an EventContext for a state event (with not previous entry)
|
||||||
|
is the same after serialize/deserialize.
|
||||||
|
"""
|
||||||
|
event, context = create_event(
|
||||||
|
self.hs,
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.test",
|
||||||
|
sender=self.user_id,
|
||||||
|
state_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
def test_serialize_deserialize_state_prev(self):
|
||||||
|
"""Test that an EventContext for a state event (which replaces a
|
||||||
|
previous entry) is the same after serialize/deserialize.
|
||||||
|
"""
|
||||||
|
event, context = create_event(
|
||||||
|
self.hs,
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.room.member",
|
||||||
|
sender=self.user_id,
|
||||||
|
state_key=self.user_id,
|
||||||
|
content={"membership": "leave"},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._check_serialize_deserialize(event, context)
|
||||||
|
|
||||||
|
def _check_serialize_deserialize(self, event, context):
|
||||||
|
serialized = self.get_success(context.serialize(event, self.store))
|
||||||
|
|
||||||
|
d_context = EventContext.deserialize(self.storage, serialized)
|
||||||
|
|
||||||
|
self.assertEqual(context.state_group, d_context.state_group)
|
||||||
|
self.assertEqual(context.rejected, d_context.rejected)
|
||||||
|
self.assertEqual(
|
||||||
|
context.state_group_before_event, d_context.state_group_before_event
|
||||||
|
)
|
||||||
|
self.assertEqual(context.prev_group, d_context.prev_group)
|
||||||
|
self.assertEqual(context.delta_ids, d_context.delta_ids)
|
||||||
|
self.assertEqual(context.app_service, d_context.app_service)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
self.get_success(context.get_current_state_ids()),
|
||||||
|
self.get_success(d_context.get_current_state_ids()),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
self.get_success(context.get_prev_state_ids()),
|
||||||
|
self.get_success(d_context.get_prev_state_ids()),
|
||||||
|
)
|
@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_name(self):
|
def test_set_my_name(self):
|
||||||
yield self.handler.set_displayname(
|
yield defer.ensureDeferred(
|
||||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
self.handler.set_displayname(
|
||||||
|
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
(yield self.store.get_profile_displayname(self.frank.localpart)),
|
(
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_profile_displayname(self.frank.localpart)
|
||||||
|
)
|
||||||
|
),
|
||||||
"Frank Jr.",
|
"Frank Jr.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set displayname again
|
# Set displayname again
|
||||||
yield self.handler.set_displayname(
|
yield defer.ensureDeferred(
|
||||||
self.frank, synapse.types.create_requester(self.frank), "Frank"
|
self.handler.set_displayname(
|
||||||
|
self.frank, synapse.types.create_requester(self.frank), "Frank"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setting displayname a second time is forbidden
|
# Setting displayname a second time is forbidden
|
||||||
d = self.handler.set_displayname(
|
d = defer.ensureDeferred(
|
||||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
self.handler.set_displayname(
|
||||||
|
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.assertFailure(d, SynapseError)
|
yield self.assertFailure(d, SynapseError)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_name_noauth(self):
|
def test_set_my_name_noauth(self):
|
||||||
d = self.handler.set_displayname(
|
d = defer.ensureDeferred(
|
||||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
self.handler.set_displayname(
|
||||||
|
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.assertFailure(d, AuthError)
|
yield self.assertFailure(d, AuthError)
|
||||||
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_set_my_avatar(self):
|
def test_set_my_avatar(self):
|
||||||
yield self.handler.set_avatar_url(
|
yield defer.ensureDeferred(
|
||||||
self.frank,
|
self.handler.set_avatar_url(
|
||||||
synapse.types.create_requester(self.frank),
|
self.frank,
|
||||||
"http://my.server/pic.gif",
|
synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/pic.gif",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set avatar again
|
# Set avatar again
|
||||||
yield self.handler.set_avatar_url(
|
yield defer.ensureDeferred(
|
||||||
self.frank,
|
self.handler.set_avatar_url(
|
||||||
synapse.types.create_requester(self.frank),
|
self.frank,
|
||||||
"http://my.server/me.png",
|
synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/me.png",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set avatar a second time is forbidden
|
# Set avatar a second time is forbidden
|
||||||
d = self.handler.set_avatar_url(
|
d = defer.ensureDeferred(
|
||||||
self.frank,
|
self.handler.set_avatar_url(
|
||||||
synapse.types.create_requester(self.frank),
|
self.frank,
|
||||||
"http://my.server/pic.gif",
|
synapse.types.create_requester(self.frank),
|
||||||
|
"http://my.server/pic.gif",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.assertFailure(d, SynapseError)
|
yield self.assertFailure(d, SynapseError)
|
||||||
|
@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||||
|
|
||||||
self.store.is_real_user = Mock(return_value=False)
|
self.store.is_real_user = Mock(return_value=defer.succeed(False))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
self.assertEqual(len(rooms), 0)
|
self.assertEqual(len(rooms), 0)
|
||||||
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||||
|
|
||||||
self.store.count_real_users = Mock(return_value=1)
|
self.store.count_real_users = Mock(return_value=defer.succeed(1))
|
||||||
self.store.is_real_user = Mock(return_value=True)
|
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||||
|
|
||||||
self.store.count_real_users = Mock(return_value=2)
|
self.store.count_real_users = Mock(return_value=defer.succeed(2))
|
||||||
self.store.is_real_user = Mock(return_value=True)
|
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
self.assertEqual(len(rooms), 0)
|
self.assertEqual(len(rooms), 0)
|
||||||
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_or_create_user(
|
||||||
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
|
self, requester, localpart, displayname, password_hash=None
|
||||||
|
):
|
||||||
"""Creates a new user if the user does not exist,
|
"""Creates a new user if the user does not exist,
|
||||||
else revokes all previous access tokens and generates a new one.
|
else revokes all previous access tokens and generates a new one.
|
||||||
|
|
||||||
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
if localpart is None:
|
if localpart is None:
|
||||||
raise SynapseError(400, "Request must include user id")
|
raise SynapseError(400, "Request must include user id")
|
||||||
yield self.hs.get_auth().check_auth_blocking()
|
await self.hs.get_auth().check_auth_blocking()
|
||||||
need_register = True
|
need_register = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.handler.check_username(localpart)
|
await self.handler.check_username(localpart)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
if e.errcode == Codes.USER_IN_USE:
|
if e.errcode == Codes.USER_IN_USE:
|
||||||
need_register = False
|
need_register = False
|
||||||
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||||||
token = self.macaroon_generator.generate_access_token(user_id)
|
token = self.macaroon_generator.generate_access_token(user_id)
|
||||||
|
|
||||||
if need_register:
|
if need_register:
|
||||||
yield self.handler.register_with_store(
|
await self.handler.register_with_store(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
password_hash=password_hash,
|
password_hash=password_hash,
|
||||||
create_profile_with_displayname=user.localpart,
|
create_profile_with_displayname=user.localpart,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield defer.ensureDeferred(
|
await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
|
||||||
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield self.store.add_access_token_to_user(
|
await self.store.add_access_token_to_user(
|
||||||
user_id=user_id, token=token, device_id=None, valid_until_ms=None
|
user_id=user_id, token=token, device_id=None, valid_until_ms=None
|
||||||
)
|
)
|
||||||
|
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
# logger.info("setting user display name: %s -> %s", user_id, displayname)
|
# logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
yield self.hs.get_profile_handler().set_displayname(
|
await self.hs.get_profile_handler().set_displayname(
|
||||||
user, requester, displayname, by_admin=True
|
user, requester, displayname, by_admin=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
|||||||
from twisted.internet.task import LoopingCall
|
from twisted.internet.task import LoopingCall
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
|
||||||
from synapse.app.generic_worker import GenericWorkerServer
|
from synapse.app.generic_worker import (
|
||||||
|
GenericWorkerReplicationHandler,
|
||||||
|
GenericWorkerServer,
|
||||||
|
)
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
|
||||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
|
||||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
self._server_transport = None
|
self._server_transport = None
|
||||||
|
|
||||||
def _build_replication_data_handler(self):
|
def _build_replication_data_handler(self):
|
||||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
return TestReplicationDataHandler(self.worker_hs)
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
if self._client_transport:
|
if self._client_transport:
|
||||||
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(request.method, b"GET")
|
self.assertEqual(request.method, b"GET")
|
||||||
|
|
||||||
|
|
||||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||||
|
|
||||||
def __init__(self, store: BaseSlavedStore):
|
def __init__(self, hs: HomeServer):
|
||||||
super().__init__(store)
|
super().__init__(hs)
|
||||||
|
|
||||||
# streams to subscribe to: map from stream id to position
|
|
||||||
self.stream_positions = {} # type: Dict[str, int]
|
|
||||||
|
|
||||||
# list of received (stream_name, token, row) tuples
|
# list of received (stream_name, token, row) tuples
|
||||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||||
|
|
||||||
def get_streams_to_replicate(self):
|
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||||
return self.stream_positions
|
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
async def on_rdata(self, stream_name, token, rows):
|
|
||||||
await super().on_rdata(stream_name, token, rows)
|
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.received_rdata_rows.append((stream_name, token, r))
|
self.received_rdata_rows.append((stream_name, token, r))
|
||||||
|
|
||||||
if (
|
|
||||||
stream_name in self.stream_positions
|
|
||||||
and token > self.stream_positions[stream_name]
|
|
||||||
):
|
|
||||||
self.stream_positions[stream_name] = token
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s()
|
@attr.s()
|
||||||
class OneShotRequestFactory:
|
class OneShotRequestFactory:
|
||||||
|
@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||||||
self.user_tok = self.login("u1", "pass")
|
self.user_tok = self.login("u1", "pass")
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.test_handler.stream_positions["events"] = 0
|
|
||||||
|
|
||||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||||
self.test_handler.received_rdata_rows.clear()
|
self.test_handler.received_rdata_rows.clear()
|
||||||
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# we should have received all the expected rows in the right order
|
# we should have received all the expected rows in the right order (as
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
# well as various cache invalidation updates which we ignore)
|
||||||
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
stream_name, token, row = received_rows.pop(0)
|
stream_name, token, row = received_rows.pop(0)
|
||||||
self.assertEqual("events", stream_name)
|
self.assertEqual("events", stream_name)
|
||||||
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# now we should have received all the expected rows in the right order.
|
# we should have received all the expected rows in the right order (as
|
||||||
|
# well as various cache invalidation updates which we ignore)
|
||||||
#
|
#
|
||||||
# we expect:
|
# we expect:
|
||||||
#
|
#
|
||||||
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||||||
# of the states that got reverted.
|
# of the states that got reverted.
|
||||||
# - two rows for state2
|
# - two rows for state2
|
||||||
|
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
|
|
||||||
# first check the first two rows, which should be state1
|
# first check the first two rows, which should be state1
|
||||||
|
|
||||||
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||||||
self.reconnect()
|
self.reconnect()
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
# we should have received all the expected rows in the right order
|
# we should have received all the expected rows in the right order (as
|
||||||
|
# well as various cache invalidation updates which we ignore)
|
||||||
received_rows = self.test_handler.received_rdata_rows
|
received_rows = [
|
||||||
|
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||||
|
]
|
||||||
self.assertGreaterEqual(len(received_rows), len(events))
|
self.assertGreaterEqual(len(received_rows), len(events))
|
||||||
for i in range(NUM_USERS):
|
for i in range(NUM_USERS):
|
||||||
# for each user, we expect the PL event row, followed by state rows for
|
# for each user, we expect the PL event row, followed by state rows for
|
||||||
|
@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||||||
def test_receipt(self):
|
def test_receipt(self):
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
# make the client subscribe to the receipts stream
|
|
||||||
self.test_handler.stream_positions.update({"receipts": 0})
|
|
||||||
|
|
||||||
# tell the master to send a new receipt
|
# tell the master to send a new receipt
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_datastore().insert_receipt(
|
self.hs.get_datastore().insert_receipt(
|
||||||
@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||||||
|
|
||||||
# there should be one RDATA command
|
# there should be one RDATA command
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.test_handler.on_rdata.assert_called_once()
|
||||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "receipts")
|
self.assertEqual(stream_name, "receipts")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
||||||
@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||||||
|
|
||||||
# We should now have caught up and get the missing data
|
# We should now have caught up and get the missing data
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.test_handler.on_rdata.assert_called_once()
|
||||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "receipts")
|
self.assertEqual(stream_name, "receipts")
|
||||||
self.assertEqual(token, 3)
|
self.assertEqual(token, 3)
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
|
@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
# make the client subscribe to the typing stream
|
|
||||||
self.test_handler.stream_positions.update({"typing": 0})
|
|
||||||
|
|
||||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
||||||
|
|
||||||
self.reactor.advance(0)
|
self.reactor.advance(0)
|
||||||
@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||||||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.test_handler.on_rdata.assert_called_once()
|
||||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
|
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
|
||||||
@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||||||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||||
|
|
||||||
self.test_handler.on_rdata.assert_called_once()
|
self.test_handler.on_rdata.assert_called_once()
|
||||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||||
self.assertEqual(stream_name, "typing")
|
self.assertEqual(stream_name, "typing")
|
||||||
self.assertEqual(1, len(rdata_rows))
|
self.assertEqual(1, len(rdata_rows))
|
||||||
row = rdata_rows[0]
|
row = rdata_rows[0]
|
||||||
|
@ -55,26 +55,19 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(1000)
|
return_value=defer.succeed(1000)
|
||||||
)
|
)
|
||||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
self._rlsn._server_notices_manager.send_notice = Mock(
|
||||||
self._rlsn._server_notices_manager.send_notice = Mock()
|
return_value=defer.succeed(Mock())
|
||||||
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
|
)
|
||||||
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
|
|
||||||
|
|
||||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||||
|
|
||||||
self.hs.config.limit_usage_by_mau = True
|
self.hs.config.limit_usage_by_mau = True
|
||||||
self.user_id = "@user_id:test"
|
self.user_id = "@user_id:test"
|
||||||
|
|
||||||
# self.server_notices_mxid = "@server:test"
|
|
||||||
# self.server_notices_mxid_display_name = None
|
|
||||||
# self.server_notices_mxid_avatar_url = None
|
|
||||||
# self.server_notices_room_name = "Server Notices"
|
|
||||||
|
|
||||||
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
|
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
|
||||||
returnValue=""
|
return_value=defer.succeed("!something:localhost")
|
||||||
)
|
)
|
||||||
self._rlsn._store.add_tag_to_room = Mock()
|
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
|
||||||
self._rlsn._store.get_tags_for_room = Mock(return_value={})
|
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({}))
|
||||||
self.hs.config.admin_contact = "mailto:user@test.com"
|
self.hs.config.admin_contact = "mailto:user@test.com"
|
||||||
|
|
||||||
def test_maybe_send_server_notice_to_user_flag_off(self):
|
def test_maybe_send_server_notice_to_user_flag_off(self):
|
||||||
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
||||||
"""Test when user has blocked notice, but should have it removed"""
|
"""Test when user has blocked notice, but should have it removed"""
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock()
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock(
|
||||||
return_value=defer.succeed({"123": mock_event})
|
return_value=defer.succeed({"123": mock_event})
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
# Would be better to check the content, but once == remove blocking event
|
# Would be better to check the content, but once == remove blocking event
|
||||||
self._send_notice.assert_called_once()
|
self._send_notice.assert_called_once()
|
||||||
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
Test when user has blocked notice, but notice ought to be there (NOOP)
|
Test when user has blocked notice, but notice ought to be there (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
side_effect=ResourceLimitError(403, "foo")
|
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock(
|
||||||
return_value=defer.succeed({"123": mock_event})
|
return_value=defer.succeed({"123": mock_event})
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
self._send_notice.assert_not_called()
|
self._send_notice.assert_not_called()
|
||||||
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, but should have one
|
Test when user does not have blocked notice, but should have one
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
side_effect=ResourceLimitError(403, "foo")
|
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, nor should they (NOOP)
|
Test when user does not have blocked notice, nor should they (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock()
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
Test when user is not part of the MAU cohort - this should not ever
|
Test when user is not part of the MAU cohort - this should not ever
|
||||||
happen - but ...
|
happen - but ...
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth.check_auth_blocking = Mock()
|
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=defer.succeed(None)
|
return_value=defer.succeed(None)
|
||||||
)
|
)
|
||||||
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
an alert message is not sent into the room
|
an alert message is not sent into the room
|
||||||
"""
|
"""
|
||||||
self.hs.config.mau_limit_alerting = False
|
self.hs.config.mau_limit_alerting = False
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
|
return_value=defer.succeed(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
self.assertTrue(self._send_notice.call_count == 0)
|
self.assertEqual(self._send_notice.call_count, 0)
|
||||||
|
|
||||||
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
|
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
|
||||||
"""
|
"""
|
||||||
Test that when a server is disabled, that MAU limit alerting is ignored.
|
Test that when a server is disabled, that MAU limit alerting is ignored.
|
||||||
"""
|
"""
|
||||||
self.hs.config.mau_limit_alerting = False
|
self.hs.config.mau_limit_alerting = False
|
||||||
|
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
|
return_value=defer.succeed(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
self.hs.config.mau_limit_alerting = False
|
self.hs.config.mau_limit_alerting = False
|
||||||
self._rlsn._auth.check_auth_blocking = Mock(
|
self._rlsn._auth.check_auth_blocking = Mock(
|
||||||
|
return_value=defer.succeed(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
||||||
return_value=defer.succeed((True, []))
|
return_value=defer.succeed((True, []))
|
||||||
)
|
)
|
||||||
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
|
|||||||
def test_server_notice_only_sent_once(self):
|
def test_server_notice_only_sent_once(self):
|
||||||
self.store.get_monthly_active_count = Mock(return_value=1000)
|
self.store.get_monthly_active_count = Mock(return_value=1000)
|
||||||
|
|
||||||
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
|
return_value=defer.succeed(1000)
|
||||||
|
)
|
||||||
|
|
||||||
# Call the function multiple times to ensure we only send the notice once
|
# Call the function multiple times to ensure we only send the notice once
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
|
|||||||
user_id = UserID("us", "test")
|
user_id = UserID("us", "test")
|
||||||
our_user = Requester(user_id, None, False, None, None)
|
our_user = Requester(user_id, None, False, None, None)
|
||||||
room_creator = self.homeserver.get_room_creation_handler()
|
room_creator = self.homeserver.get_room_creation_handler()
|
||||||
room = room_creator.create_room(
|
room = ensureDeferred(
|
||||||
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
|
room_creator.create_room(
|
||||||
|
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.reactor.advance(0.1)
|
self.reactor.advance(0.1)
|
||||||
self.room_id = self.successResultOf(room)["room_id"]
|
self.room_id = self.successResultOf(room)["room_id"]
|
||||||
|
@ -14,12 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import synapse.server
|
import synapse.server
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.types import Collection
|
from synapse.types import Collection
|
||||||
|
|
||||||
from tests.test_utils import get_awaitable_result
|
from tests.test_utils import get_awaitable_result
|
||||||
@ -75,6 +76,23 @@ def inject_event(
|
|||||||
"""
|
"""
|
||||||
test_reactor = hs.get_reactor()
|
test_reactor = hs.get_reactor()
|
||||||
|
|
||||||
|
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||||
|
|
||||||
|
d = hs.get_storage().persistence.persist_event(event, context)
|
||||||
|
test_reactor.advance(0)
|
||||||
|
get_awaitable_result(d)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
def create_event(
|
||||||
|
hs: synapse.server.HomeServer,
|
||||||
|
room_version: Optional[str] = None,
|
||||||
|
prev_event_ids: Optional[Collection[str]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[EventBase, EventContext]:
|
||||||
|
test_reactor = hs.get_reactor()
|
||||||
|
|
||||||
if room_version is None:
|
if room_version is None:
|
||||||
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
||||||
test_reactor.advance(0)
|
test_reactor.advance(0)
|
||||||
@ -89,8 +107,4 @@ def inject_event(
|
|||||||
test_reactor.advance(0)
|
test_reactor.advance(0)
|
||||||
event, context = get_awaitable_result(d)
|
event, context = get_awaitable_result(d)
|
||||||
|
|
||||||
d = hs.get_storage().persistence.persist_event(event, context)
|
return event, context
|
||||||
test_reactor.advance(0)
|
|
||||||
get_awaitable_result(d)
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
Loading…
Reference in New Issue
Block a user