mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-12-15 17:58:45 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/initial_sync_asnyc
This commit is contained in:
commit
6828b47c45
156 changed files with 1623 additions and 846 deletions
|
|
@ -13,8 +13,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
class AccountDataEventSource(object):
|
||||
def __init__(self, hs):
|
||||
|
|
@ -23,15 +21,14 @@ class AccountDataEventSource(object):
|
|||
def get_current_key(self, direction="f"):
|
||||
return self.store.get_max_account_data_stream_id()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events(self, user, from_key, **kwargs):
|
||||
async def get_new_events(self, user, from_key, **kwargs):
|
||||
user_id = user.to_string()
|
||||
last_stream_id = from_key
|
||||
|
||||
current_stream_id = yield self.store.get_max_account_data_stream_id()
|
||||
current_stream_id = self.store.get_max_account_data_stream_id()
|
||||
|
||||
results = []
|
||||
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
|
||||
tags = await self.store.get_updated_tags(user_id, last_stream_id)
|
||||
|
||||
for room_id, room_tags in tags.items():
|
||||
results.append(
|
||||
|
|
@ -41,7 +38,7 @@ class AccountDataEventSource(object):
|
|||
(
|
||||
account_data,
|
||||
room_account_data,
|
||||
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
|
||||
|
||||
for account_data_type, content in account_data.items():
|
||||
results.append({"type": account_data_type, "content": content})
|
||||
|
|
@ -54,6 +51,5 @@ class AccountDataEventSource(object):
|
|||
|
||||
return results, current_stream_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pagination_rows(self, user, config, key):
|
||||
async def get_pagination_rows(self, user, config, key):
|
||||
return [], config.to_id
|
||||
|
|
|
|||
|
|
@ -18,8 +18,7 @@ import email.utils
|
|||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import List
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
|
|
@ -78,42 +77,39 @@ class AccountValidityHandler(object):
|
|||
# run as a background process to make sure that the database transactions
|
||||
# have a logcontext to report to
|
||||
return run_as_background_process(
|
||||
"send_renewals", self.send_renewal_emails
|
||||
"send_renewals", self._send_renewal_emails
|
||||
)
|
||||
|
||||
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_renewal_emails(self):
|
||||
async def _send_renewal_emails(self):
|
||||
"""Gets the list of users whose account is expiring in the amount of time
|
||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||
configuration, and sends renewal emails to all of these users as long as they
|
||||
have an email 3PID attached to their account.
|
||||
"""
|
||||
expiring_users = yield self.store.get_users_expiring_soon()
|
||||
expiring_users = await self.store.get_users_expiring_soon()
|
||||
|
||||
if expiring_users:
|
||||
for user in expiring_users:
|
||||
yield self._send_renewal_email(
|
||||
await self._send_renewal_email(
|
||||
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_renewal_email_to_user(self, user_id):
|
||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
||||
yield self._send_renewal_email(user_id, expiration_ts)
|
||||
async def send_renewal_email_to_user(self, user_id: str):
|
||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||
await self._send_renewal_email(user_id, expiration_ts)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_renewal_email(self, user_id, expiration_ts):
|
||||
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
|
||||
"""Sends out a renewal email to every email address attached to the given user
|
||||
with a unique link allowing them to renew their account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to send email(s) to.
|
||||
expiration_ts (int): Timestamp in milliseconds for the expiration date of
|
||||
user_id: ID of the user to send email(s) to.
|
||||
expiration_ts: Timestamp in milliseconds for the expiration date of
|
||||
this user's account (used in the email templates).
|
||||
"""
|
||||
addresses = yield self._get_email_addresses_for_user(user_id)
|
||||
addresses = await self._get_email_addresses_for_user(user_id)
|
||||
|
||||
# Stop right here if the user doesn't have at least one email address.
|
||||
# In this case, they will have to ask their server admin to renew their
|
||||
|
|
@ -125,7 +121,7 @@ class AccountValidityHandler(object):
|
|||
return
|
||||
|
||||
try:
|
||||
user_display_name = yield self.store.get_profile_displayname(
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
if user_display_name is None:
|
||||
|
|
@ -133,7 +129,7 @@ class AccountValidityHandler(object):
|
|||
except StoreError:
|
||||
user_display_name = user_id
|
||||
|
||||
renewal_token = yield self._get_renewal_token(user_id)
|
||||
renewal_token = await self._get_renewal_token(user_id)
|
||||
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
renewal_token,
|
||||
|
|
@ -165,7 +161,7 @@ class AccountValidityHandler(object):
|
|||
|
||||
logger.info("Sending renewal email to %s", address)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
self.sendmail(
|
||||
self.hs.config.email_smtp_host,
|
||||
self._raw_from,
|
||||
|
|
@ -180,19 +176,18 @@ class AccountValidityHandler(object):
|
|||
)
|
||||
)
|
||||
|
||||
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
||||
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_email_addresses_for_user(self, user_id):
|
||||
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
|
||||
"""Retrieve the list of email addresses attached to a user's account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to lookup email addresses for.
|
||||
user_id: ID of the user to lookup email addresses for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[list[str]]: Email addresses for this account.
|
||||
Email addresses for this account.
|
||||
"""
|
||||
threepids = yield self.store.user_get_threepids(user_id)
|
||||
threepids = await self.store.user_get_threepids(user_id)
|
||||
|
||||
addresses = []
|
||||
for threepid in threepids:
|
||||
|
|
@ -201,16 +196,15 @@ class AccountValidityHandler(object):
|
|||
|
||||
return addresses
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_renewal_token(self, user_id):
|
||||
async def _get_renewal_token(self, user_id: str) -> str:
|
||||
"""Generates a 32-byte long random string that will be inserted into the
|
||||
user's renewal email's unique link, then saves it into the database.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to generate a string for.
|
||||
user_id: ID of the user to generate a string for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The generated string.
|
||||
The generated string.
|
||||
|
||||
Raises:
|
||||
StoreError(500): Couldn't generate a unique string after 5 attempts.
|
||||
|
|
@ -219,52 +213,52 @@ class AccountValidityHandler(object):
|
|||
while attempts < 5:
|
||||
try:
|
||||
renewal_token = stringutils.random_string(32)
|
||||
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
|
||||
await self.store.set_renewal_token_for_user(user_id, renewal_token)
|
||||
return renewal_token
|
||||
except StoreError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def renew_account(self, renewal_token):
|
||||
async def renew_account(self, renewal_token: str) -> bool:
|
||||
"""Renews the account attached to a given renewal token by pushing back the
|
||||
expiration date by the current validity period in the server's configuration.
|
||||
|
||||
Args:
|
||||
renewal_token (str): Token sent with the renewal request.
|
||||
renewal_token: Token sent with the renewal request.
|
||||
Returns:
|
||||
bool: Whether the provided token is valid.
|
||||
Whether the provided token is valid.
|
||||
"""
|
||||
try:
|
||||
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
|
||||
user_id = await self.store.get_user_from_renewal_token(renewal_token)
|
||||
except StoreError:
|
||||
defer.returnValue(False)
|
||||
return False
|
||||
|
||||
logger.debug("Renewing an account for user %s", user_id)
|
||||
yield self.renew_account_for_user(user_id)
|
||||
await self.renew_account_for_user(user_id)
|
||||
|
||||
defer.returnValue(True)
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
|
||||
async def renew_account_for_user(
|
||||
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
|
||||
) -> int:
|
||||
"""Renews the account attached to a given user by pushing back the
|
||||
expiration date by the current validity period in the server's
|
||||
configuration.
|
||||
|
||||
Args:
|
||||
renewal_token (str): Token sent with the renewal request.
|
||||
expiration_ts (int): New expiration date. Defaults to now + validity period.
|
||||
email_sent (bool): Whether an email has been sent for this validity period.
|
||||
renewal_token: Token sent with the renewal request.
|
||||
expiration_ts: New expiration date. Defaults to now + validity period.
|
||||
email_sen: Whether an email has been sent for this validity period.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[int]: New expiration date for this account, as a timestamp
|
||||
in milliseconds since epoch.
|
||||
New expiration date for this account, as a timestamp in
|
||||
milliseconds since epoch.
|
||||
"""
|
||||
if expiration_ts is None:
|
||||
expiration_ts = self.clock.time_msec() + self._account_validity.period
|
||||
|
||||
yield self.store.set_account_validity_for_user(
|
||||
await self.store.set_account_validity_for_user(
|
||||
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -264,7 +264,6 @@ class E2eKeysHandler(object):
|
|||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
||||
"""Get cross-signing keys for users from the database
|
||||
|
||||
|
|
@ -284,35 +283,14 @@ class E2eKeysHandler(object):
|
|||
self_signing_keys = {}
|
||||
user_signing_keys = {}
|
||||
|
||||
for user_id in query:
|
||||
# XXX: consider changing the store functions to allow querying
|
||||
# multiple users simultaneously.
|
||||
key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, "master", from_user_id
|
||||
)
|
||||
if key:
|
||||
master_keys[user_id] = key
|
||||
|
||||
key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, "self_signing", from_user_id
|
||||
)
|
||||
if key:
|
||||
self_signing_keys[user_id] = key
|
||||
|
||||
# users can see other users' master and self-signing keys, but can
|
||||
# only see their own user-signing keys
|
||||
if from_user_id == user_id:
|
||||
key = yield self.store.get_e2e_cross_signing_key(
|
||||
user_id, "user_signing", from_user_id
|
||||
)
|
||||
if key:
|
||||
user_signing_keys[user_id] = key
|
||||
|
||||
return {
|
||||
"master_keys": master_keys,
|
||||
"self_signing_keys": self_signing_keys,
|
||||
"user_signing_keys": user_signing_keys,
|
||||
}
|
||||
# Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
|
||||
return defer.succeed(
|
||||
{
|
||||
"master_keys": master_keys,
|
||||
"self_signing_keys": self_signing_keys,
|
||||
"user_signing_keys": user_signing_keys,
|
||||
}
|
||||
)
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Dict, Iterable, Optional, Sequence, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
|
||||
import six
|
||||
from six import iteritems, itervalues
|
||||
|
|
@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
|
|||
)
|
||||
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
|
||||
from synapse.state import StateResolutionStore, resolve_events_with_store
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import batch_iter, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
|
@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
|
||||
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
|
||||
""" Process a PDU received via a federation /send/ transaction, or
|
||||
via backfill of missing prev_events
|
||||
|
||||
|
|
@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
|
|||
pdu (FrozenEvent): received PDU
|
||||
sent_to_us_directly (bool): True if this event was pushed to us; False if
|
||||
we pulled it as the result of a missing prev_event.
|
||||
|
||||
Returns (Deferred): completes with None
|
||||
"""
|
||||
|
||||
room_id = pdu.room_id
|
||||
event_id = pdu.event_id
|
||||
|
||||
logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
|
||||
logger.info("handling received PDU: %s", pdu)
|
||||
|
||||
# We reprocess pdus when we have seen them only as outliers
|
||||
existing = yield self.store.get_event(
|
||||
existing = await self.store.get_event(
|
||||
event_id, allow_none=True, allow_rejected=True
|
||||
)
|
||||
|
||||
|
|
@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
|
|||
#
|
||||
# Note that if we were never in the room then we would have already
|
||||
# dropped the event, since we wouldn't know the room version.
|
||||
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
|
||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
||||
if not is_in_room:
|
||||
logger.info(
|
||||
"[%s %s] Ignoring PDU from %s as we're not in the room",
|
||||
|
|
@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
|
|||
# Get missing pdus if necessary.
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.get_min_depth_for_context(pdu.room_id)
|
||||
min_depth = await self.get_min_depth_for_context(pdu.room_id)
|
||||
|
||||
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
|
||||
|
||||
prevs = set(pdu.prev_event_ids())
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if min_depth and pdu.depth < min_depth:
|
||||
# This is so that we don't notify the user about this
|
||||
|
|
@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
|
|||
len(missing_prevs),
|
||||
shortstr(missing_prevs),
|
||||
)
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"[%s %s] Acquired room lock to fetch %d missing prev_events",
|
||||
room_id,
|
||||
|
|
@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
|
|||
len(missing_prevs),
|
||||
)
|
||||
|
||||
yield self._get_missing_events_for_pdu(
|
||||
origin, pdu, prevs, min_depth
|
||||
)
|
||||
try:
|
||||
await self._get_missing_events_for_pdu(
|
||||
origin, pdu, prevs, min_depth
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
"Error fetching missing prev_events for %s: %s"
|
||||
% (event_id, e)
|
||||
)
|
||||
|
||||
# Update the set of things we've seen after trying to
|
||||
# fetch the missing stuff
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
logger.info(
|
||||
|
|
@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
|
|||
room_id,
|
||||
event_id,
|
||||
)
|
||||
elif missing_prevs:
|
||||
logger.info(
|
||||
"[%s %s] Not recursively fetching %d missing prev_events: %s",
|
||||
room_id,
|
||||
event_id,
|
||||
len(missing_prevs),
|
||||
shortstr(missing_prevs),
|
||||
)
|
||||
|
||||
if prevs - seen:
|
||||
# We've still not been able to get all of the prev_events for this event.
|
||||
|
|
@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
|
|||
affected=pdu.event_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Event %s is missing prev_events: calculating state for a "
|
||||
"backwards extremity",
|
||||
event_id,
|
||||
)
|
||||
|
||||
# Calculate the state after each of the previous events, and
|
||||
# resolve them to find the correct state at the current event.
|
||||
auth_chains = set()
|
||||
event_map = {event_id: pdu}
|
||||
try:
|
||||
# Get the state of the events we know about
|
||||
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
|
||||
ours = await self.state_store.get_state_groups_ids(room_id, seen)
|
||||
|
||||
# state_maps is a list of mappings from (type, state_key) to event_id
|
||||
state_maps = list(
|
||||
|
|
@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
|
|||
# know about
|
||||
for p in prevs - seen:
|
||||
logger.info(
|
||||
"[%s %s] Requesting state at missing prev_event %s",
|
||||
room_id,
|
||||
event_id,
|
||||
p,
|
||||
"Requesting state at missing prev_event %s", event_id,
|
||||
)
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
with nested_logging_context(p):
|
||||
# note that if any of the missing prevs share missing state or
|
||||
|
|
@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
|
|||
(
|
||||
remote_state,
|
||||
got_auth_chain,
|
||||
) = yield self.federation_client.get_state_for_room(
|
||||
origin, room_id, p
|
||||
) = await self._get_state_for_room(
|
||||
origin, room_id, p, include_event_in_state=True
|
||||
)
|
||||
|
||||
# we want the state *after* p; get_state_for_room returns the
|
||||
# state *before* p.
|
||||
remote_event = yield self.federation_client.get_pdu(
|
||||
[origin], p, room_version, outlier=True
|
||||
)
|
||||
|
||||
if remote_event is None:
|
||||
raise Exception(
|
||||
"Unable to get missing prev_event %s" % (p,)
|
||||
)
|
||||
|
||||
if remote_event.is_state():
|
||||
remote_state.append(remote_event)
|
||||
|
||||
# XXX hrm I'm not convinced that duplicate events will compare
|
||||
# for equality, so I'm not sure this does what the author
|
||||
# hoped.
|
||||
|
|
@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
|
|||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
|
||||
state_map = yield resolve_events_with_store(
|
||||
state_map = await resolve_events_with_store(
|
||||
room_version,
|
||||
state_maps,
|
||||
event_map,
|
||||
|
|
@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = yield self.store.get_events(
|
||||
evs = await self.store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
check_redacted=False,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
)
|
||||
event_map.update(evs)
|
||||
|
||||
|
|
@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
|
|||
affected=event_id,
|
||||
)
|
||||
|
||||
yield self._process_received_pdu(
|
||||
await self._process_received_pdu(
|
||||
origin, pdu, state=state, auth_chain=auth_chain
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
|
||||
"""
|
||||
Args:
|
||||
origin (str): Origin of the pdu. Will be called to get the missing events
|
||||
|
|
@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
|
|||
room_id = pdu.room_id
|
||||
event_id = pdu.event_id
|
||||
|
||||
seen = yield self.store.have_seen_events(prevs)
|
||||
seen = await self.store.have_seen_events(prevs)
|
||||
|
||||
if not prevs - seen:
|
||||
return
|
||||
|
||||
latest = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
|
|
@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
|
|||
# All that said: Let's try increasing the timout to 60s and see what happens.
|
||||
|
||||
try:
|
||||
missing_events = yield self.federation_client.get_missing_events(
|
||||
missing_events = await self.federation_client.get_missing_events(
|
||||
origin,
|
||||
room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
|
|
@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
with nested_logging_context(ev.event_id):
|
||||
try:
|
||||
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning(
|
||||
|
|
@ -583,8 +567,116 @@ class FederationHandler(BaseHandler):
|
|||
else:
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_received_pdu(self, origin, event, state, auth_chain):
|
||||
async def _get_state_for_room(
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
include_event_in_state: bool = False,
|
||||
) -> Tuple[List[EventBase], List[EventBase]]:
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
destination: The remote homeserver to query for the state.
|
||||
room_id: The id of the room we're interested in.
|
||||
event_id: The id of the event we want the state at.
|
||||
include_event_in_state: if true, the event itself will be included in the
|
||||
returned state event list.
|
||||
|
||||
Returns:
|
||||
A list of events in the state, possibly including the event itself, and
|
||||
a list of events in the auth chain for the given event.
|
||||
"""
|
||||
(
|
||||
state_event_ids,
|
||||
auth_event_ids,
|
||||
) = await self.federation_client.get_room_state_ids(
|
||||
destination, room_id, event_id=event_id
|
||||
)
|
||||
|
||||
desired_events = set(state_event_ids + auth_event_ids)
|
||||
|
||||
if include_event_in_state:
|
||||
desired_events.add(event_id)
|
||||
|
||||
event_map = await self._get_events_from_store_or_dest(
|
||||
destination, room_id, desired_events
|
||||
)
|
||||
|
||||
failed_to_fetch = desired_events - event_map.keys()
|
||||
if failed_to_fetch:
|
||||
logger.warning(
|
||||
"Failed to fetch missing state/auth events for %s %s",
|
||||
event_id,
|
||||
failed_to_fetch,
|
||||
)
|
||||
|
||||
remote_state = [
|
||||
event_map[e_id] for e_id in state_event_ids if e_id in event_map
|
||||
]
|
||||
|
||||
if include_event_in_state:
|
||||
remote_event = event_map.get(event_id)
|
||||
if not remote_event:
|
||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||
if remote_event.is_state():
|
||||
remote_state.append(remote_event)
|
||||
|
||||
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
return remote_state, auth_chain
|
||||
|
||||
async def _get_events_from_store_or_dest(
|
||||
self, destination: str, room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[str, EventBase]:
|
||||
"""Fetch events from a remote destination, checking if we already have them.
|
||||
|
||||
Args:
|
||||
destination
|
||||
room_id
|
||||
event_ids
|
||||
|
||||
Returns:
|
||||
map from event_id to event
|
||||
"""
|
||||
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
|
||||
|
||||
missing_events = set(event_ids) - fetched_events.keys()
|
||||
|
||||
if not missing_events:
|
||||
return fetched_events
|
||||
|
||||
logger.debug(
|
||||
"Fetching unknown state/auth events %s for room %s",
|
||||
missing_events,
|
||||
event_ids,
|
||||
)
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
# XXX 20 requests at once? really?
|
||||
for batch in batch_iter(missing_events, 20):
|
||||
deferreds = [
|
||||
run_in_background(
|
||||
self.federation_client.get_pdu,
|
||||
destinations=[destination],
|
||||
event_id=e_id,
|
||||
room_version=room_version,
|
||||
)
|
||||
for e_id in batch
|
||||
]
|
||||
|
||||
res = await make_deferred_yieldable(
|
||||
defer.DeferredList(deferreds, consumeErrors=True)
|
||||
)
|
||||
for success, result in res:
|
||||
if success and result:
|
||||
fetched_events[result.event_id] = result
|
||||
|
||||
return fetched_events
|
||||
|
||||
async def _process_received_pdu(self, origin, event, state, auth_chain):
|
||||
""" Called when we have a new pdu. We need to do auth checks and put it
|
||||
through the StateHandler.
|
||||
"""
|
||||
|
|
@ -599,7 +691,7 @@ class FederationHandler(BaseHandler):
|
|||
if auth_chain:
|
||||
event_ids |= {e.event_id for e in auth_chain}
|
||||
|
||||
seen_ids = yield self.store.have_seen_events(event_ids)
|
||||
seen_ids = await self.store.have_seen_events(event_ids)
|
||||
|
||||
if state and auth_chain is not None:
|
||||
# If we have any state or auth_chain given to us by the replication
|
||||
|
|
@ -626,18 +718,18 @@ class FederationHandler(BaseHandler):
|
|||
event_id,
|
||||
[e.event.event_id for e in event_infos],
|
||||
)
|
||||
yield self._handle_new_events(origin, event_infos)
|
||||
await self._handle_new_events(origin, event_infos)
|
||||
|
||||
try:
|
||||
context = yield self._handle_new_event(origin, event, state=state)
|
||||
context = await self._handle_new_event(origin, event, state=state)
|
||||
except AuthError as e:
|
||||
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
|
||||
|
||||
room = yield self.store.get_room(room_id)
|
||||
room = await self.store.get_room(room_id)
|
||||
|
||||
if not room:
|
||||
try:
|
||||
yield self.store.store_room(
|
||||
await self.store.store_room(
|
||||
room_id=room_id, room_creator_user_id="", is_public=False
|
||||
)
|
||||
except StoreError:
|
||||
|
|
@ -650,11 +742,11 @@ class FederationHandler(BaseHandler):
|
|||
# changing their profile info.
|
||||
newly_joined = True
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids(self.store)
|
||||
prev_state_ids = await context.get_prev_state_ids(self.store)
|
||||
|
||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||
if prev_state_id:
|
||||
prev_state = yield self.store.get_event(
|
||||
prev_state = await self.store.get_event(
|
||||
prev_state_id, allow_none=True
|
||||
)
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
|
|
@ -662,11 +754,10 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if newly_joined:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield self.user_joined_room(user, room_id)
|
||||
await self.user_joined_room(user, room_id)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities):
|
||||
async def backfill(self, dest, room_id, limit, extremities):
|
||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||
|
||||
This will attempt to get more events from the remote. If the other side
|
||||
|
|
@ -683,9 +774,9 @@ class FederationHandler(BaseHandler):
|
|||
if dest == self.server_name:
|
||||
raise SynapseError(400, "Can't backfill from self.")
|
||||
|
||||
room_version = yield self.store.get_room_version(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
events = yield self.federation_client.backfill(
|
||||
events = await self.federation_client.backfill(
|
||||
dest, room_id, limit=limit, extremities=extremities
|
||||
)
|
||||
|
||||
|
|
@ -700,7 +791,7 @@ class FederationHandler(BaseHandler):
|
|||
# self._sanity_check_event(ev)
|
||||
|
||||
# Don't bother processing events we already have.
|
||||
seen_events = yield self.store.have_events_in_timeline(
|
||||
seen_events = await self.store.have_events_in_timeline(
|
||||
set(e.event_id for e in events)
|
||||
)
|
||||
|
||||
|
|
@ -723,7 +814,7 @@ class FederationHandler(BaseHandler):
|
|||
state_events = {}
|
||||
events_to_state = {}
|
||||
for e_id in edges:
|
||||
state, auth = yield self.federation_client.get_state_for_room(
|
||||
state, auth = await self._get_state_for_room(
|
||||
destination=dest, room_id=room_id, event_id=e_id
|
||||
)
|
||||
auth_events.update({a.event_id: a for a in auth})
|
||||
|
|
@ -748,7 +839,7 @@ class FederationHandler(BaseHandler):
|
|||
# We repeatedly do this until we stop finding new auth events.
|
||||
while missing_auth - failed_to_fetch:
|
||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
|
||||
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
|
||||
auth_events.update(ret_events)
|
||||
|
||||
required_auth.update(
|
||||
|
|
@ -762,7 +853,7 @@ class FederationHandler(BaseHandler):
|
|||
missing_auth - failed_to_fetch,
|
||||
)
|
||||
|
||||
results = yield make_deferred_yieldable(
|
||||
results = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(
|
||||
|
|
@ -789,7 +880,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
failed_to_fetch = missing_auth - set(auth_events)
|
||||
|
||||
seen_events = yield self.store.have_seen_events(
|
||||
seen_events = await self.store.have_seen_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
|
|
@ -851,7 +942,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
yield self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
|
||||
# Step 2: Persist the rest of the events in the chunk one by one
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
|
@ -867,16 +958,15 @@ class FederationHandler(BaseHandler):
|
|||
# We store these one at a time since each event depends on the
|
||||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
yield self._handle_new_event(dest, event, backfilled=True)
|
||||
await self._handle_new_event(dest, event, backfilled=True)
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_backfill(self, room_id, current_depth):
|
||||
async def maybe_backfill(self, room_id, current_depth):
|
||||
"""Checks the database to see if we should backfill before paginating,
|
||||
and if so do.
|
||||
"""
|
||||
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
|
||||
if not extremities:
|
||||
logger.debug("Not backfilling as no extremeties found.")
|
||||
|
|
@ -908,15 +998,17 @@ class FederationHandler(BaseHandler):
|
|||
# state *before* the event, ignoring the special casing certain event
|
||||
# types have.
|
||||
|
||||
forward_events = yield self.store.get_successor_events(list(extremities))
|
||||
forward_events = await self.store.get_successor_events(list(extremities))
|
||||
|
||||
extremities_events = yield self.store.get_events(
|
||||
forward_events, check_redacted=False, get_prev_content=False
|
||||
extremities_events = await self.store.get_events(
|
||||
forward_events,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
# We set `check_history_visibility_only` as we might otherwise get false
|
||||
# positives from users having been erased.
|
||||
filtered_extremities = yield filter_events_for_server(
|
||||
filtered_extremities = await filter_events_for_server(
|
||||
self.storage,
|
||||
self.server_name,
|
||||
list(extremities_events.values()),
|
||||
|
|
@ -946,7 +1038,7 @@ class FederationHandler(BaseHandler):
|
|||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = yield self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
|
||||
def get_domains_from_state(state):
|
||||
"""Get joined domains from state
|
||||
|
|
@ -985,12 +1077,11 @@ class FederationHandler(BaseHandler):
|
|||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def try_backfill(domains):
|
||||
async def try_backfill(domains):
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
yield self.backfill(
|
||||
await self.backfill(
|
||||
dom, room_id, limit=100, extremities=extremities
|
||||
)
|
||||
# If this succeeded then we probably already have the
|
||||
|
|
@ -1021,7 +1112,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return False
|
||||
|
||||
success = yield try_backfill(likely_domains)
|
||||
success = await try_backfill(likely_domains)
|
||||
if success:
|
||||
return True
|
||||
|
||||
|
|
@ -1035,7 +1126,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
|
||||
states = yield make_deferred_yieldable(
|
||||
states = await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
|
||||
)
|
||||
|
|
@ -1045,7 +1136,7 @@ class FederationHandler(BaseHandler):
|
|||
# event_ids.
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
state_map = await self.store.get_events(
|
||||
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
|
@ -1061,7 +1152,7 @@ class FederationHandler(BaseHandler):
|
|||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = yield try_backfill(
|
||||
success = await try_backfill(
|
||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
||||
)
|
||||
if success:
|
||||
|
|
@ -1210,7 +1301,7 @@ class FederationHandler(BaseHandler):
|
|||
# Check whether this room is the result of an upgrade of a room we already know
|
||||
# about. If so, migrate over user information
|
||||
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||
if not predecessor:
|
||||
if not predecessor or not isinstance(predecessor.get("room_id"), str):
|
||||
return
|
||||
old_room_id = predecessor["room_id"]
|
||||
logger.debug(
|
||||
|
|
@ -1238,8 +1329,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_queued_pdus(self, room_queue):
|
||||
async def _handle_queued_pdus(self, room_queue):
|
||||
"""Process PDUs which got queued up while we were busy send_joining.
|
||||
|
||||
Args:
|
||||
|
|
@ -1255,7 +1345,7 @@ class FederationHandler(BaseHandler):
|
|||
p.room_id,
|
||||
)
|
||||
with nested_logging_context(p.event_id):
|
||||
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
|
||||
|
|
@ -1453,7 +1543,7 @@ class FederationHandler(BaseHandler):
|
|||
@defer.inlineCallbacks
|
||||
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
|
||||
origin, event, event_format_version = yield self._make_and_verify_event(
|
||||
target_hosts, room_id, user_id, "leave", content=content,
|
||||
target_hosts, room_id, user_id, "leave", content=content
|
||||
)
|
||||
# Mark as outlier as we don't have any state for this event; we're not
|
||||
# even in the room.
|
||||
|
|
@ -2814,7 +2904,7 @@ class FederationHandler(BaseHandler):
|
|||
room_id=room_id, user_id=user.to_string(), change="joined"
|
||||
)
|
||||
else:
|
||||
return user_joined_room(self.distributor, user, room_id)
|
||||
return defer.succeed(user_joined_room(self.distributor, user, room_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_complexity(self, remote_room_hosts, room_id):
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
|
|||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RoomAlias, UserID, create_requester
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
|
@ -875,7 +876,7 @@ class EventCreationHandler(object):
|
|||
if event.type == EventTypes.Redaction:
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
allow_none=True,
|
||||
|
|
@ -952,7 +953,7 @@ class EventCreationHandler(object):
|
|||
if event.type == EventTypes.Redaction:
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
allow_rejected=False,
|
||||
allow_none=True,
|
||||
|
|
|
|||
|
|
@ -280,8 +280,7 @@ class PaginationHandler(object):
|
|||
|
||||
await self.storage.purge_events.purge_room(room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(
|
||||
async def get_messages(
|
||||
self,
|
||||
requester,
|
||||
room_id=None,
|
||||
|
|
@ -307,7 +306,7 @@ class PaginationHandler(object):
|
|||
room_token = pagin_config.from_token.room_key
|
||||
else:
|
||||
pagin_config.from_token = (
|
||||
yield self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
await self.hs.get_event_sources().get_current_token_for_pagination()
|
||||
)
|
||||
room_token = pagin_config.from_token.room_key
|
||||
|
||||
|
|
@ -319,11 +318,11 @@ class PaginationHandler(object):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
with (await self.pagination_lock.read(room_id)):
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
|
||||
|
||||
if source_config.direction == "b":
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
|
|
@ -331,7 +330,7 @@ class PaginationHandler(object):
|
|||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
max_topo = await self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
|
|
@ -339,18 +338,18 @@ class PaginationHandler(object):
|
|||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
leave_token = await self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
await self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield self.store.paginate_room_events(
|
||||
events, next_key = await self.store.paginate_room_events(
|
||||
room_id=room_id,
|
||||
from_key=source_config.from_key,
|
||||
to_key=source_config.to_key,
|
||||
|
|
@ -365,7 +364,7 @@ class PaginationHandler(object):
|
|||
if event_filter:
|
||||
events = event_filter.filter(events)
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
||||
)
|
||||
|
||||
|
|
@ -385,19 +384,19 @@ class PaginationHandler(object):
|
|||
(EventTypes.Member, event.sender) for event in events
|
||||
)
|
||||
|
||||
state_ids = yield self.state_store.get_state_ids_for_event(
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
events[0].event_id, state_filter=state_filter
|
||||
)
|
||||
|
||||
if state_ids:
|
||||
state = yield self.store.get_events(list(state_ids.values()))
|
||||
state = await self.store.get_events(list(state_ids.values()))
|
||||
state = state.values()
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunk = {
|
||||
"chunk": (
|
||||
yield self._event_serializer.serialize_events(
|
||||
await self._event_serializer.serialize_events(
|
||||
events, time_now, as_client_event=as_client_event
|
||||
)
|
||||
),
|
||||
|
|
@ -406,7 +405,7 @@ class PaginationHandler(object):
|
|||
}
|
||||
|
||||
if state:
|
||||
chunk["state"] = yield self._event_serializer.serialize_events(
|
||||
chunk["state"] = await self._event_serializer.serialize_events(
|
||||
state, time_now, as_client_event=as_client_event
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,20 +13,36 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
import attr
|
||||
import saml2
|
||||
import saml2.response
|
||||
from saml2.client import Saml2Client
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.rest.client.v1.login import SSOAuthHandler
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.types import (
|
||||
UserID,
|
||||
map_username_to_mxid_localpart,
|
||||
mxid_localpart_allowed_characters,
|
||||
)
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
class Saml2SessionData:
|
||||
"""Data we track about SAML2 sessions"""
|
||||
|
||||
# time the session was created, in milliseconds
|
||||
creation_time = attr.ib()
|
||||
|
||||
|
||||
class SamlHandler:
|
||||
def __init__(self, hs):
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
|
|
@ -37,11 +53,14 @@ class SamlHandler:
|
|||
self._datastore = hs.get_datastore()
|
||||
self._hostname = hs.hostname
|
||||
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
|
||||
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
|
||||
self._grandfathered_mxid_source_attribute = (
|
||||
hs.config.saml2_grandfathered_mxid_source_attribute
|
||||
)
|
||||
self._mxid_mapper = hs.config.saml2_mxid_mapper
|
||||
|
||||
# plugin to do custom mapping from saml response to mxid
|
||||
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
|
||||
hs.config.saml2_user_mapping_provider_config
|
||||
)
|
||||
|
||||
# identifier for the external_ids table
|
||||
self._auth_provider_id = "saml"
|
||||
|
|
@ -118,22 +137,10 @@ class SamlHandler:
|
|||
remote_user_id = saml2_auth.ava["uid"][0]
|
||||
except KeyError:
|
||||
logger.warning("SAML2 response lacks a 'uid' attestation")
|
||||
raise SynapseError(400, "uid not in SAML2 response")
|
||||
|
||||
try:
|
||||
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
||||
)
|
||||
raise SynapseError(400, "'uid' not in SAML2 response")
|
||||
|
||||
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
|
||||
|
||||
displayName = saml2_auth.ava.get("displayName", [None])[0]
|
||||
|
||||
with (await self._mapping_lock.queue(self._auth_provider_id)):
|
||||
# first of all, check if we already have a mapping for this user
|
||||
logger.info(
|
||||
|
|
@ -173,22 +180,46 @@ class SamlHandler:
|
|||
)
|
||||
return registered_user_id
|
||||
|
||||
# figure out a new mxid for this user
|
||||
base_mxid_localpart = self._mxid_mapper(mxid_source)
|
||||
# Map saml response to user attributes using the configured mapping provider
|
||||
for i in range(1000):
|
||||
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
|
||||
saml2_auth, i
|
||||
)
|
||||
|
||||
suffix = 0
|
||||
while True:
|
||||
localpart = base_mxid_localpart + (str(suffix) if suffix else "")
|
||||
logger.debug(
|
||||
"Retrieved SAML attributes from user mapping provider: %s "
|
||||
"(attempt %d)",
|
||||
attribute_dict,
|
||||
i,
|
||||
)
|
||||
|
||||
localpart = attribute_dict.get("mxid_localpart")
|
||||
if not localpart:
|
||||
logger.error(
|
||||
"SAML mapping provider plugin did not return a "
|
||||
"mxid_localpart object"
|
||||
)
|
||||
raise SynapseError(500, "Error parsing SAML2 response")
|
||||
|
||||
displayname = attribute_dict.get("displayname")
|
||||
|
||||
# Check if this mxid already exists
|
||||
if not await self._datastore.get_users_by_id_case_insensitive(
|
||||
UserID(localpart, self._hostname).to_string()
|
||||
):
|
||||
# This mxid is free
|
||||
break
|
||||
suffix += 1
|
||||
logger.info("Allocating mxid for new user with localpart %s", localpart)
|
||||
else:
|
||||
# Unable to generate a username in 1000 iterations
|
||||
# Break and return error to the user
|
||||
raise SynapseError(
|
||||
500, "Unable to generate a Matrix ID from the SAML response"
|
||||
)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=displayName
|
||||
localpart=localpart, default_display_name=displayname
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
self._auth_provider_id, remote_user_id, registered_user_id
|
||||
)
|
||||
|
|
@ -205,9 +236,120 @@ class SamlHandler:
|
|||
del self._outstanding_requests_dict[reqid]
|
||||
|
||||
|
||||
@attr.s
|
||||
class Saml2SessionData:
|
||||
"""Data we track about SAML2 sessions"""
|
||||
DOT_REPLACE_PATTERN = re.compile(
|
||||
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
|
||||
)
|
||||
|
||||
# time the session was created, in milliseconds
|
||||
creation_time = attr.ib()
|
||||
|
||||
def dot_replace_for_mxid(username: str) -> str:
|
||||
username = username.lower()
|
||||
username = DOT_REPLACE_PATTERN.sub(".", username)
|
||||
|
||||
# regular mxids aren't allowed to start with an underscore either
|
||||
username = re.sub("^_", "", username)
|
||||
return username
|
||||
|
||||
|
||||
MXID_MAPPER_MAP = {
|
||||
"hexencode": map_username_to_mxid_localpart,
|
||||
"dotreplace": dot_replace_for_mxid,
|
||||
}
|
||||
|
||||
|
||||
@attr.s
|
||||
class SamlConfig(object):
|
||||
mxid_source_attribute = attr.ib()
|
||||
mxid_mapper = attr.ib()
|
||||
|
||||
|
||||
class DefaultSamlMappingProvider(object):
|
||||
__version__ = "0.0.1"
|
||||
|
||||
def __init__(self, parsed_config: SamlConfig):
|
||||
"""The default SAML user mapping provider
|
||||
|
||||
Args:
|
||||
parsed_config: Module configuration
|
||||
"""
|
||||
self._mxid_source_attribute = parsed_config.mxid_source_attribute
|
||||
self._mxid_mapper = parsed_config.mxid_mapper
|
||||
|
||||
def saml_response_to_user_attributes(
|
||||
self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
|
||||
) -> dict:
|
||||
"""Maps some text from a SAML response to attributes of a new user
|
||||
|
||||
Args:
|
||||
saml_response: A SAML auth response object
|
||||
|
||||
failures: How many times a call to this function with this
|
||||
saml_response has resulted in a failure
|
||||
|
||||
Returns:
|
||||
dict: A dict containing new user attributes. Possible keys:
|
||||
* mxid_localpart (str): Required. The localpart of the user's mxid
|
||||
* displayname (str): The displayname of the user
|
||||
"""
|
||||
try:
|
||||
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
|
||||
)
|
||||
|
||||
# Use the configured mapper for this mxid_source
|
||||
base_mxid_localpart = self._mxid_mapper(mxid_source)
|
||||
|
||||
# Append suffix integer if last call to this function failed to produce
|
||||
# a usable mxid
|
||||
localpart = base_mxid_localpart + (str(failures) if failures else "")
|
||||
|
||||
# Retrieve the display name from the saml response
|
||||
# If displayname is None, the mxid_localpart will be used instead
|
||||
displayname = saml_response.ava.get("displayName", [None])[0]
|
||||
|
||||
return {
|
||||
"mxid_localpart": localpart,
|
||||
"displayname": displayname,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config: dict) -> SamlConfig:
|
||||
"""Parse the dict provided by the homeserver's config
|
||||
Args:
|
||||
config: A dictionary containing configuration options for this provider
|
||||
Returns:
|
||||
SamlConfig: A custom config object for this module
|
||||
"""
|
||||
# Parse config options and use defaults where necessary
|
||||
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
|
||||
mapping_type = config.get("mxid_mapping", "hexencode")
|
||||
|
||||
# Retrieve the associating mapping function
|
||||
try:
|
||||
mxid_mapper = MXID_MAPPER_MAP[mapping_type]
|
||||
except KeyError:
|
||||
raise ConfigError(
|
||||
"saml2_config.user_mapping_provider.config: '%s' is not a valid "
|
||||
"mxid_mapping value" % (mapping_type,)
|
||||
)
|
||||
|
||||
return SamlConfig(mxid_source_attribute, mxid_mapper)
|
||||
|
||||
@staticmethod
|
||||
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
|
||||
"""Returns the required attributes of a SAML
|
||||
|
||||
Args:
|
||||
config: A SamlConfig object containing configuration params for this provider
|
||||
|
||||
Returns:
|
||||
tuple[set,set]: The first set equates to the saml auth response
|
||||
attributes that are required for the module to function, whereas the
|
||||
second set consists of those attributes which can be used if
|
||||
available, but are not necessary
|
||||
"""
|
||||
return {"uid", config.mxid_source_attribute}, {"displayName"}
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
|
@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
|
|||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_old_rooms_from_upgraded_room(self, room_id):
|
||||
|
|
@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
|
|||
room_id (str): id of the room to search through.
|
||||
|
||||
Returns:
|
||||
Deferred[iterable[unicode]]: predecessor room ids
|
||||
Deferred[iterable[str]]: predecessor room ids
|
||||
"""
|
||||
|
||||
historical_room_ids = []
|
||||
|
||||
while True:
|
||||
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||
# The initial room must have been known for us to get this far
|
||||
predecessor = yield self.store.get_room_predecessor(room_id)
|
||||
|
||||
# If no predecessor, assume we've hit a dead end
|
||||
while True:
|
||||
if not predecessor:
|
||||
# We have reached the end of the chain of predecessors
|
||||
break
|
||||
|
||||
# Add predecessor's room ID
|
||||
historical_room_ids.append(predecessor["room_id"])
|
||||
if not isinstance(predecessor.get("room_id"), str):
|
||||
# This predecessor object is malformed. Exit here
|
||||
break
|
||||
|
||||
# Scan through the old room for further predecessors
|
||||
room_id = predecessor["room_id"]
|
||||
predecessor_room_id = predecessor["room_id"]
|
||||
|
||||
# Don't add it to the list until we have checked that we are in the room
|
||||
try:
|
||||
next_predecessor_room = yield self.store.get_room_predecessor(
|
||||
predecessor_room_id
|
||||
)
|
||||
except NotFoundError:
|
||||
# The predecessor is not a known room, so we are done here
|
||||
break
|
||||
|
||||
historical_room_ids.append(predecessor_room_id)
|
||||
|
||||
# And repeat
|
||||
predecessor = next_predecessor_room
|
||||
|
||||
return historical_room_ids
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue