Add ability to wait for replication streams (#7542)

The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room).

Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on.

People probably want to look at this commit by commit.
This commit is contained in:
Erik Johnston 2020-05-22 14:21:54 +01:00 committed by GitHub
parent 06a02bc1ce
commit 1531b214fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 304 additions and 112 deletions

1
changelog.d/7542.misc Normal file
View File

@ -0,0 +1 @@
Add ability to wait for replication streams.

View File

@ -126,6 +126,7 @@ class FederationHandler(BaseHandler):
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config self.config = hs.config
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self._replication = hs.get_replication_data_handler()
self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client( self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
hs hs
@ -1221,7 +1222,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join( async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> None: ) -> Tuple[str, int]:
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`. servers contained in `target_hosts`.
@ -1304,15 +1305,23 @@ class FederationHandler(BaseHandler):
room_id=room_id, room_version=room_version_obj, room_id=room_id, room_version=room_version_obj,
) )
await self._persist_auth_tree( max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj origin, auth_chain, state, event, room_version_obj
) )
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
"master", "events", max_stream_id
)
# Check whether this room is the result of an upgrade of a room we already know # Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information # about. If so, migrate over user information
predecessor = await self.store.get_room_predecessor(room_id) predecessor = await self.store.get_room_predecessor(room_id)
if not predecessor or not isinstance(predecessor.get("room_id"), str): if not predecessor or not isinstance(predecessor.get("room_id"), str):
return return event.event_id, max_stream_id
old_room_id = predecessor["room_id"] old_room_id = predecessor["room_id"]
logger.debug( logger.debug(
"Found predecessor for %s during remote join: %s", room_id, old_room_id "Found predecessor for %s during remote join: %s", room_id, old_room_id
@ -1325,6 +1334,7 @@ class FederationHandler(BaseHandler):
) )
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
return event.event_id, max_stream_id
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
del self.room_queues[room_id] del self.room_queues[room_id]
@ -1554,7 +1564,7 @@ class FederationHandler(BaseHandler):
async def do_remotely_reject_invite( async def do_remotely_reject_invite(
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
) -> EventBase: ) -> Tuple[EventBase, int]:
origin, event, room_version = await self._make_and_verify_event( origin, event, room_version = await self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content target_hosts, room_id, user_id, "leave", content=content
) )
@ -1574,9 +1584,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(target_hosts, event) await self.federation_client.send_leave(target_hosts, event)
context = await self.state_handler.compute_event_context(event) context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)]) stream_id = await self.persist_events_and_notify([(event, context)])
return event return event, stream_id
async def _make_and_verify_event( async def _make_and_verify_event(
self, self,
@ -1888,7 +1898,7 @@ class FederationHandler(BaseHandler):
state: List[EventBase], state: List[EventBase],
event: EventBase, event: EventBase,
room_version: RoomVersion, room_version: RoomVersion,
) -> None: ) -> int:
"""Checks the auth chain is valid (and passes auth checks) for the """Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically. state and event. Then persists the auth chain and state atomically.
Persists the event separately. Notifies about the persisted events Persists the event separately. Notifies about the persisted events
@ -1982,7 +1992,7 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
await self.persist_events_and_notify([(event, new_event_context)]) return await self.persist_events_and_notify([(event, new_event_context)])
async def _prep_event( async def _prep_event(
self, self,
@ -2835,7 +2845,7 @@ class FederationHandler(BaseHandler):
self, self,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]], event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> int:
"""Persists events and tells the notifier/pushers about them, if """Persists events and tells the notifier/pushers about them, if
necessary. necessary.
@ -2845,11 +2855,12 @@ class FederationHandler(BaseHandler):
backfilling or not backfilling or not
""" """
if self.config.worker_app: if self.config.worker_app:
await self._send_events_to_master( result = await self._send_events_to_master(
store=self.store, store=self.store,
event_and_contexts=event_and_contexts, event_and_contexts=event_and_contexts,
backfilled=backfilled, backfilled=backfilled,
) )
return result["max_stream_id"]
else: else:
max_stream_id = await self.storage.persistence.persist_events( max_stream_id = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled event_and_contexts, backfilled=backfilled
@ -2864,6 +2875,8 @@ class FederationHandler(BaseHandler):
for event, _ in event_and_contexts: for event, _ in event_and_contexts:
await self._notify_persisted_event(event, max_stream_id) await self._notify_persisted_event(event, max_stream_id)
return max_stream_id
async def _notify_persisted_event( async def _notify_persisted_event(
self, event: EventBase, max_stream_id: int self, event: EventBase, max_stream_id: int
) -> None: ) -> None:

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional from typing import Optional, Tuple
from six import iteritems, itervalues, string_types from six import iteritems, itervalues, string_types
@ -42,6 +42,7 @@ from synapse.api.errors import (
) )
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
from synapse.events import EventBase
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -630,7 +631,9 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
async def send_nonmember_event(self, requester, event, context, ratelimit=True): async def send_nonmember_event(
self, requester, event, context, ratelimit=True
) -> int:
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -639,6 +642,9 @@ class EventCreationHandler(object):
context (Context) the context of the event. context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send. ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest. is_guest (bool): Whether the sender is a guest.
Return:
The stream_id of the persisted event.
""" """
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
raise SynapseError( raise SynapseError(
@ -659,7 +665,7 @@ class EventCreationHandler(object):
) )
return prev_state return prev_state
await self.handle_new_client_event( return await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit requester=requester, event=event, context=context, ratelimit=ratelimit
) )
@ -688,7 +694,7 @@ class EventCreationHandler(object):
async def create_and_send_nonmember_event( async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None self, requester, event_dict, ratelimit=True, txn_id=None
): ) -> Tuple[EventBase, int]:
""" """
Creates an event, then sends it. Creates an event, then sends it.
@ -711,10 +717,10 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here" spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN) raise SynapseError(403, spam_error, Codes.FORBIDDEN)
await self.send_nonmember_event( stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit requester, event, context, ratelimit=ratelimit
) )
return event return event, stream_id
@measure_func("create_new_client_event") @measure_func("create_new_client_event")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -774,7 +780,7 @@ class EventCreationHandler(object):
@measure_func("handle_new_client_event") @measure_func("handle_new_client_event")
async def handle_new_client_event( async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ) -> int:
"""Processes a new event. This includes checking auth, persisting it, """Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc. notifying users, sending to remote servers, etc.
@ -787,6 +793,9 @@ class EventCreationHandler(object):
context (EventContext) context (EventContext)
ratelimit (bool) ratelimit (bool)
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
Return:
The stream_id of the persisted event.
""" """
if event.is_state() and (event.type, event.state_key) == ( if event.is_state() and (event.type, event.state_key) == (
@ -827,7 +836,7 @@ class EventCreationHandler(object):
try: try:
# If we're a worker we need to hit out to the master. # If we're a worker we need to hit out to the master.
if self.config.worker_app: if self.config.worker_app:
await self.send_event_to_master( result = await self.send_event_to_master(
event_id=event.event_id, event_id=event.event_id,
store=self.store, store=self.store,
requester=requester, requester=requester,
@ -836,14 +845,17 @@ class EventCreationHandler(object):
ratelimit=ratelimit, ratelimit=ratelimit,
extra_users=extra_users, extra_users=extra_users,
) )
stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id
success = True success = True
return return stream_id
await self.persist_and_notify_client_event( stream_id = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
success = True success = True
return stream_id
finally: finally:
if not success: if not success:
# Ensure that we actually remove the entries in the push actions # Ensure that we actually remove the entries in the push actions
@ -886,7 +898,7 @@ class EventCreationHandler(object):
async def persist_and_notify_client_event( async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[] self, requester, event, context, ratelimit=True, extra_users=[]
): ) -> int:
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
@ -1076,6 +1088,8 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while. # matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user) run_in_background(self._bump_active_time, requester.user)
return event_stream_id
async def _bump_active_time(self, user): async def _bump_active_time(self, user):
try: try:
presence = self.hs.get_presence_handler() presence = self.hs.get_presence_handler()

View File

@ -22,6 +22,7 @@ import logging
import math import math
import string import string
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple
from six import iteritems, string_types from six import iteritems, string_types
@ -518,7 +519,7 @@ class RoomCreationHandler(BaseHandler):
async def create_room( async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None self, requester, config, ratelimit=True, creator_join_profile=None
): ) -> Tuple[dict, int]:
""" Creates a new room. """ Creates a new room.
Args: Args:
@ -535,9 +536,9 @@ class RoomCreationHandler(BaseHandler):
`avatar_url` and/or `displayname`. `avatar_url` and/or `displayname`.
Returns: Returns:
Deferred[dict]: First, a dict containing the keys `room_id` and, if an alias
a dict containing the keys `room_id` and, if an alias was was, requested, `room_alias`. Secondly, the stream_id of the
requested, `room_alias`. last persisted event.
Raises: Raises:
SynapseError if the room ID couldn't be stored, or something went SynapseError if the room ID couldn't be stored, or something went
horribly wrong. horribly wrong.
@ -669,7 +670,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content # override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier creation_content["room_version"] = room_version.identifier
await self._send_events_for_new_room( last_stream_id = await self._send_events_for_new_room(
requester, requester,
room_id, room_id,
preset_config=preset_config, preset_config=preset_config,
@ -683,7 +684,10 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
await self.event_creation_handler.create_and_send_nonmember_event( (
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -697,7 +701,10 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
await self.event_creation_handler.create_and_send_nonmember_event( (
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -715,7 +722,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct: if is_direct:
content["is_direct"] = is_direct content["is_direct"] = is_direct
await self.room_member_handler.update_membership( _, last_stream_id = await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(invitee), UserID.from_string(invitee),
room_id, room_id,
@ -729,7 +736,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"] address = invite_3pid["address"]
medium = invite_3pid["medium"] medium = invite_3pid["medium"]
await self.hs.get_room_member_handler().do_3pid_invite( last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id, room_id,
requester.user, requester.user,
medium, medium,
@ -745,7 +752,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
return result return result, last_stream_id
async def _send_events_for_new_room( async def _send_events_for_new_room(
self, self,
@ -758,7 +765,13 @@ class RoomCreationHandler(BaseHandler):
room_alias=None, room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None, creator_join_profile=None,
): ) -> int:
"""Sends the initial events into a new room.
Returns:
The stream_id of the last event persisted.
"""
def create(etype, content, **kwargs): def create(etype, content, **kwargs):
e = {"type": etype, "content": content} e = {"type": etype, "content": content}
@ -767,12 +780,16 @@ class RoomCreationHandler(BaseHandler):
return e return e
async def send(etype, content, **kwargs): async def send(etype, content, **kwargs) -> int:
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype) logger.debug("Sending %s in new room", etype)
await self.event_creation_handler.create_and_send_nonmember_event( (
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False creator, event, ratelimit=False
) )
return last_stream_id
config = RoomCreationHandler.PRESETS_DICT[preset_config] config = RoomCreationHandler.PRESETS_DICT[preset_config]
@ -797,7 +814,9 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room. # of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None: if pl_content is not None:
await send(etype=EventTypes.PowerLevels, content=pl_content) last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=pl_content
)
else: else:
power_level_content = { power_level_content = {
"users": {creator_id: 100}, "users": {creator_id: 100},
@ -830,33 +849,39 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override: if power_level_content_override:
power_level_content.update(power_level_content_override) power_level_content.update(power_level_content_override)
await send(etype=EventTypes.PowerLevels, content=power_level_content) last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=power_level_content
)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
await send( last_sent_stream_id = await send(
etype=EventTypes.CanonicalAlias, etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()}, content={"alias": room_alias.to_string()},
) )
if (EventTypes.JoinRules, "") not in initial_state: if (EventTypes.JoinRules, "") not in initial_state:
await send( last_sent_stream_id = await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]} etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
) )
if (EventTypes.RoomHistoryVisibility, "") not in initial_state: if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
await send( last_sent_stream_id = await send(
etype=EventTypes.RoomHistoryVisibility, etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}, content={"history_visibility": config["history_visibility"]},
) )
if config["guest_can_join"]: if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state: if (EventTypes.GuestAccess, "") not in initial_state:
await send( last_sent_stream_id = await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"} etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
) )
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
await send(etype=etype, state_key=state_key, content=content) last_sent_stream_id = await send(
etype=etype, state_key=state_key, content=content
)
return last_sent_stream_id
async def _generate_room_id( async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion, self, creator_id: str, is_public: str, room_version: RoomVersion,

View File

@ -17,7 +17,7 @@
import abc import abc
import logging import logging
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Dict, Iterable, List, Optional, Tuple
from six.moves import http_client from six.moves import http_client
@ -84,7 +84,7 @@ class RoomMemberHandler(object):
room_id: str, room_id: str,
user: UserID, user: UserID,
content: dict, content: dict,
) -> Optional[dict]: ) -> Tuple[str, int]:
"""Try and join a room that this server is not in """Try and join a room that this server is not in
Args: Args:
@ -104,7 +104,7 @@ class RoomMemberHandler(object):
room_id: str, room_id: str,
target: UserID, target: UserID,
content: dict, content: dict,
) -> dict: ) -> Tuple[Optional[str], int]:
"""Attempt to reject an invite for a room this server is not in. If we """Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected. fail to do so we locally mark the invite as rejected.
@ -154,7 +154,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True, ratelimit: bool = True,
content: Optional[dict] = None, content: Optional[dict] = None,
require_consent: bool = True, require_consent: bool = True,
) -> EventBase: ) -> Tuple[str, int]:
user_id = target.to_string() user_id = target.to_string()
if content is None: if content is None:
@ -187,9 +187,10 @@ class RoomMemberHandler(object):
) )
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
return duplicate _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
await self.event_creation_handler.handle_new_client_event( stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit requester, event, context, extra_users=[target], ratelimit=ratelimit
) )
@ -213,7 +214,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target, room_id) await self._user_left_room(target, room_id)
return event return event.event_id, stream_id
async def copy_room_tags_and_direct_to_room( async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id self, old_room_id, new_room_id, user_id
@ -263,7 +264,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True, ratelimit: bool = True,
content: Optional[dict] = None, content: Optional[dict] = None,
require_consent: bool = True, require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]: ) -> Tuple[Optional[str], int]:
key = (room_id,) key = (room_id,)
with (await self.member_linearizer.queue(key)): with (await self.member_linearizer.queue(key)):
@ -294,7 +295,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True, ratelimit: bool = True,
content: Optional[dict] = None, content: Optional[dict] = None,
require_consent: bool = True, require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]: ) -> Tuple[Optional[str], int]:
content_specified = bool(content) content_specified = bool(content)
if content is None: if content is None:
content = {} content = {}
@ -398,7 +399,13 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content: if same_sender and same_membership and same_content:
return old_state _, stream_id = await self.store.get_event_ordering(
old_state.event_id
)
return (
old_state.event_id,
stream_id,
)
if old_membership in ["ban", "leave"] and action == "kick": if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room") raise AuthError(403, "The target user is not in the room")
@ -705,7 +712,7 @@ class RoomMemberHandler(object):
requester: Requester, requester: Requester,
txn_id: Optional[str], txn_id: Optional[str],
id_access_token: Optional[str] = None, id_access_token: Optional[str] = None,
) -> None: ) -> int:
if self.config.block_non_admin_invites: if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin: if not is_requester_admin:
@ -737,11 +744,11 @@ class RoomMemberHandler(object):
) )
if invitee: if invitee:
await self.update_membership( _, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
) )
else: else:
await self._make_and_store_3pid_invite( stream_id = await self._make_and_store_3pid_invite(
requester, requester,
id_server, id_server,
medium, medium,
@ -752,6 +759,8 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
return stream_id
async def _make_and_store_3pid_invite( async def _make_and_store_3pid_invite(
self, self,
requester: Requester, requester: Requester,
@ -762,7 +771,7 @@ class RoomMemberHandler(object):
user: UserID, user: UserID,
txn_id: Optional[str], txn_id: Optional[str],
id_access_token: Optional[str] = None, id_access_token: Optional[str] = None,
) -> None: ) -> int:
room_state = await self.state_handler.get_current_state(room_id) room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = "" inviter_display_name = ""
@ -817,7 +826,10 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
await self.event_creation_handler.create_and_send_nonmember_event( (
event,
stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
@ -835,6 +847,7 @@ class RoomMemberHandler(object):
ratelimit=False, ratelimit=False,
txn_id=txn_id, txn_id=txn_id,
) )
return stream_id
async def _is_host_in_room( async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str] self, current_state_ids: Dict[Tuple[str, str], str]
@ -916,7 +929,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str, room_id: str,
user: UserID, user: UserID,
content: dict, content: dict,
) -> None: ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join
""" """
# filter ourselves out of remote_room_hosts: do_invite_join ignores it # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -945,7 +958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking # join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
await self.federation_handler.do_invite_join( event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content remote_room_hosts, room_id, user.to_string(), content
) )
await self._user_joined_room(user, room_id) await self._user_joined_room(user, room_id)
@ -955,14 +968,14 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled: if self.hs.config.limit_remote_rooms.enabled:
if too_complex is False: if too_complex is False:
# We checked, and we're under the limit. # We checked, and we're under the limit.
return return event_id, stream_id
# Check again, but with the local state events # Check again, but with the local state events
too_complex = await self._is_local_room_too_complex(room_id) too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False: if too_complex is False:
# We're under the limit. # We're under the limit.
return return event_id, stream_id
# The room is too large. Leave. # The room is too large. Leave.
requester = types.create_requester(user, None, False, None) requester = types.create_requester(user, None, False, None)
@ -975,6 +988,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
errcode=Codes.RESOURCE_LIMIT_EXCEEDED, errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
) )
return event_id, stream_id
async def _remote_reject_invite( async def _remote_reject_invite(
self, self,
requester: Requester, requester: Requester,
@ -982,15 +997,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id: str, room_id: str,
target: UserID, target: UserID,
content: dict, content: dict,
) -> dict: ) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite """Implements RoomMemberHandler._remote_reject_invite
""" """
fed_handler = self.federation_handler fed_handler = self.federation_handler
try: try:
ret = await fed_handler.do_remotely_reject_invite( event, stream_id = await fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content, remote_room_hosts, room_id, target.to_string(), content=content,
) )
return ret return event.event_id, stream_id
except Exception as e: except Exception as e:
# if we were unable to reject the exception, just mark # if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead. # it as rejected on our end and plough ahead.
@ -1000,8 +1015,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(target.to_string(), room_id) stream_id = await self.store.locally_reject_invite(
return {} target.to_string(), room_id
)
return None, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None: async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room """Implements RoomMemberHandler._user_joined_room

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional from typing import List, Optional, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler from synapse.handlers.room_member import RoomMemberHandler
@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str, room_id: str,
user: UserID, user: UserID,
content: dict, content: dict,
) -> Optional[dict]: ) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join """Implements RoomMemberHandler._remote_join
""" """
if len(remote_room_hosts) == 0: if len(remote_room_hosts) == 0:
@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
await self._user_joined_room(user, room_id) await self._user_joined_room(user, room_id)
return ret return ret["event_id"], ret["stream_id"]
async def _remote_reject_invite( async def _remote_reject_invite(
self, self,
@ -68,16 +68,17 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
room_id: str, room_id: str,
target: UserID, target: UserID,
content: dict, content: dict,
) -> dict: ) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite """Implements RoomMemberHandler._remote_reject_invite
""" """
return await self._remote_reject_client( ret = await self._remote_reject_client(
requester=requester, requester=requester,
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
room_id=room_id, room_id=room_id,
user_id=target.to_string(), user_id=target.to_string(),
content=content, content=content,
) )
return ret["event_id"], ret["stream_id"]
async def _user_joined_room(self, target: UserID, room_id: str) -> None: async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room """Implements RoomMemberHandler._user_joined_room

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""Handles events newly received from federation, including persisting and """Handles events newly received from federation, including persisting and
notifying. notifying. Returns the maximum stream ID of the persisted events.
The API looks like: The API looks like:
@ -46,6 +46,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"context": { .. serialized event context .. }, "context": { .. serialized event context .. },
}], }],
"backfilled": false "backfilled": false
}
200 OK
{
"max_stream_id": 32443,
}
""" """
NAME = "fed_send_events" NAME = "fed_send_events"
@ -115,11 +122,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts)) logger.info("Got %d events from federation", len(event_and_contexts))
await self.federation_handler.persist_events_and_notify( max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled event_and_contexts, backfilled
) )
return 200, {} return 200, {"max_stream_id": max_stream_id}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):

View File

@ -76,11 +76,11 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
logger.info("remote_join: %s into room: %s", user_id, room_id) logger.info("remote_join: %s into room: %s", user_id, room_id)
await self.federation_handler.do_invite_join( event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user_id, event_content remote_room_hosts, room_id, user_id, event_content
) )
return 200, {} return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@ -136,10 +136,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id) logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try: try:
event = await self.federation_handler.do_remotely_reject_invite( event, stream_id = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content, remote_room_hosts, room_id, user_id, event_content,
) )
ret = event.get_pdu_json() event_id = event.event_id
except Exception as e: except Exception as e:
# if we were unable to reject the exception, just mark # if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead. # it as rejected on our end and plough ahead.
@ -149,10 +149,10 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# #
logger.warning("Failed to reject invite: %s", e) logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(user_id, room_id) stream_id = await self.store.locally_reject_invite(user_id, room_id)
ret = {} event_id = None
return 200, ret return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):

View File

@ -119,11 +119,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
) )
await self.event_creation_handler.persist_and_notify_client_event( stream_id = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
return 200, {} return 200, {"stream_id": stream_id}
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -51,10 +51,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
super().__init__(hs) super().__init__(hs)
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self.streams = hs.get_replication_streams()
# We pull the streams from the replication handler (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_tcp_replication().get_streams()
@staticmethod @staticmethod
def _serialize_payload(stream_name, from_token, upto_token): def _serialize_payload(stream_name, from_token, upto_token):

View File

@ -14,19 +14,23 @@
# limitations under the License. # limitations under the License.
"""A replication client for use by synapse workers. """A replication client for use by synapse workers.
""" """
import heapq
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Dict, List, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import ( from synapse.replication.tcp.streams.events import (
EventsStream, EventsStream,
EventsStreamEventRow, EventsStreamEventRow,
EventsStreamRow, EventsStreamRow,
) )
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -35,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
class DirectTcpReplicationClientFactory(ReconnectingClientFactory): class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the """Factory for building connections to the master. Will reconnect if the
connection is lost. connection is lost.
@ -92,6 +100,16 @@ class ReplicationDataHandler:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata( async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list self, stream_name: str, instance_name: str, token: int, rows: list
@ -131,8 +149,76 @@ class ReplicationDataHandler:
await self.pusher_pool.on_new_notifications(token, token) await self.pusher_pool.on_new_notifications(token, token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
waiting_list = self._streams_to_waiters.get(stream_name, [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
# loop below means either a) no items in list so no-op or b) all items
# in list were called and so the list should be cleared. Setting it to
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
try:
with PreserveLoggingContext():
deferred.callback(None)
except Exception:
# The deferred has been cancelled or timed out.
pass
else:
# The list is sorted by position so we don't need to continue
# checking any futher entries in the list.
index_of_first_deferred_not_called = idx
break
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int): async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, []) self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str): def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command.""" """Called when get a new REMOTE_SERVER_UP command."""
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
"""Wait until this instance has received updates up to and including
the given stream position.
"""
if instance_name == self._instance_name:
# We don't get told about updates written by this process, and
# anyway in that case we don't need to wait.
return
current_position = self._streams[stream_name].current_token(self._instance_name)
if position <= current_position:
# We're already past the position
return
# Create a new deferred that times out after N seconds, as we don't want
# to wedge here forever.
deferred = Deferred()
deferred = timeout_deferred(
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
# We insert into the list using heapq as it is more efficient than
# pushing then resorting each time.
heapq.heappush(waiting_list, (position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
await make_deferred_yieldable(deferred)
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)

View File

@ -59,6 +59,7 @@ class ShutdownRoomRestServlet(RestServlet):
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._replication = hs.get_replication_data_handler()
async def on_POST(self, request, room_id): async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -73,7 +74,7 @@ class ShutdownRoomRestServlet(RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE) message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification") room_name = content.get("room_name", "Content Violation Notification")
info = await self._room_creation_handler.create_room( info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester, room_creator_requester,
config={ config={
"preset": "public_chat", "preset": "public_chat",
@ -94,6 +95,13 @@ class ShutdownRoomRestServlet(RestServlet):
# desirable in case the first attempt at blocking the room failed below. # desirable in case the first attempt at blocking the room failed below.
await self.store.block_room(room_id, requester_user_id) await self.store.block_room(room_id, requester_user_id)
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position("master", "events", stream_id)
users = await self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
kicked_users = [] kicked_users = []
failed_to_kick_users = [] failed_to_kick_users = []

View File

@ -93,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request): async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
info = await self._room_creation_handler.create_room( info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request) requester, self.get_room_config(request)
) )
@ -202,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
membership = content.get("membership", None) membership = content.get("membership", None)
event = await self.room_member_handler.update_membership( event_id, _ = await self.room_member_handler.update_membership(
requester, requester,
target=UserID.from_string(state_key), target=UserID.from_string(state_key),
room_id=room_id, room_id=room_id,
@ -210,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content, content=content,
) )
else: else:
event = await self.event_creation_handler.create_and_send_nonmember_event( (
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
event_id = event.event_id
ret = {} # type: dict ret = {} # type: dict
if event: if event_id:
set_tag("event_id", event.event_id) set_tag("event_id", event_id)
ret = {"event_id": event.event_id} ret = {"event_id": event_id}
return 200, ret return 200, ret
@ -247,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service: if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event = await self.event_creation_handler.create_and_send_nonmember_event( event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
@ -781,7 +785,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event = await self.event_creation_handler.create_and_send_nonmember_event( event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,

View File

@ -111,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
} }
event = await self.event_creation_handler.create_and_send_nonmember_event( event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id requester, event_dict=event_dict, txn_id=txn_id
) )

View File

@ -90,6 +90,7 @@ from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.rest.media.v1.media_repository import ( from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
@ -210,6 +211,7 @@ class HomeServer(object):
"storage", "storage",
"replication_streamer", "replication_streamer",
"replication_data_handler", "replication_data_handler",
"replication_streams",
] ]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@ -583,6 +585,9 @@ class HomeServer(object):
def build_replication_data_handler(self): def build_replication_data_handler(self):
return ReplicationDataHandler(self) return ReplicationDataHandler(self)
def build_replication_streams(self):
return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -1,3 +1,5 @@
from typing import Dict
import twisted.internet import twisted.internet
import synapse.api.auth import synapse.api.auth
@ -28,6 +30,7 @@ import synapse.server_notices.server_notices_sender
import synapse.state import synapse.state
import synapse.storage import synapse.storage
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.replication.tcp.streams import Stream
class HomeServer(object): class HomeServer(object):
@property @property
@ -136,3 +139,5 @@ class HomeServer(object):
pass pass
def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool: def get_pusherpool(self) -> synapse.push.pusherpool.PusherPool:
pass pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass

View File

@ -83,10 +83,10 @@ class ServerNoticesManager(object):
if state_key is not None: if state_key is not None:
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
res = await self._event_creation_handler.create_and_send_nonmember_event( event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False requester, event_dict, ratelimit=False
) )
return res return event
@cached() @cached()
async def get_or_create_notice_room_for_user(self, user_id): async def get_or_create_notice_room_for_user(self, user_id):
@ -143,7 +143,7 @@ class ServerNoticesManager(object):
} }
requester = create_requester(self.server_notices_mxid) requester = create_requester(self.server_notices_mxid)
info = await self._room_creation_handler.create_room( info, _ = await self._room_creation_handler.create_room(
requester, requester,
config={ config={
"preset": RoomCreationPreset.PRIVATE_CHAT, "preset": RoomCreationPreset.PRIVATE_CHAT,

View File

@ -1289,12 +1289,12 @@ class EventsWorkerStore(SQLBaseStore):
async def is_event_after(self, event_id1, event_id2): async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream """Returns True if event_id1 is after event_id2 in the stream
""" """
to_1, so_1 = await self._get_event_ordering(event_id1) to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self._get_event_ordering(event_id2) to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2) return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000) @cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id): def get_event_ordering(self, event_id):
res = yield self.db.simple_select_one( res = yield self.db.simple_select_one(
table="events", table="events",
retcols=["topological_ordering", "stream_ordering"], retcols=["topological_ordering", "stream_ordering"],

View File

@ -1069,6 +1069,8 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering) yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering
def forget(self, user_id, room_id): def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""

View File

@ -79,7 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join( d = handler._remote_join(
None, None,
@ -115,7 +117,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join # Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) fed_transport.client.get_json = Mock(return_value=defer.succeed(None))
handler.federation_handler.do_invite_join = Mock(return_value=defer.succeed(1)) handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
# Artificially raise the complexity # Artificially raise the complexity
self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed( self.hs.get_datastore().get_current_state_event_counts = lambda x: defer.succeed(

View File

@ -86,7 +86,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
reactor.pump((1000,)) reactor.pump((1000,))
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
) )
hs.datastores = datastores hs.datastores = datastores

View File

@ -39,7 +39,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room # Create a test user and room
self.user = UserID("alice", "test") self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, None, None) self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {})) info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"] self.room_id = info["room_id"]
def run_background_update(self): def run_background_update(self):
@ -261,7 +261,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.user = UserID.from_string(self.register_user("user1", "password")) self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password") self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, None, None) self.requester = Requester(self.user, None, False, None, None)
info = self.get_success(self.room_creator.create_room(self.requester, {})) info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"] self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler() self.event_creator = homeserver.get_event_creation_handler()
homeserver.config.user_consent_version = self.CONSENT_VERSION homeserver.config.user_consent_version = self.CONSENT_VERSION

View File

@ -33,7 +33,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
events = [(3, 2), (6, 2), (4, 6)] events = [(3, 2), (6, 2), (4, 6)]
for event_count, extrems in events: for event_count, extrems in events:
info = self.get_success(room_creator.create_room(requester, {})) info, _ = self.get_success(room_creator.create_room(requester, {}))
room_id = info["room_id"] room_id = info["room_id"]
last_event = None last_event = None

View File

@ -28,13 +28,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None) our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.homeserver.get_room_creation_handler()
room = ensureDeferred( room_deferred = ensureDeferred(
room_creator.create_room( room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
) )
) )
self.reactor.advance(0.1) self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"] self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore() self.store = self.homeserver.get_datastore()