Merge remote-tracking branch 'upstream/release-v1.27.0'

This commit is contained in:
Tulir Asokan 2021-02-02 16:06:33 +02:00
commit 7f7fb9b566
146 changed files with 4893 additions and 1484 deletions

View file

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
async def start_listening(self):
async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@ -85,7 +89,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
async def provision_certificate(self):
async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@ -110,5 +114,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
return True

View file

@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
imported conditionally.
"""
import logging
from typing import Dict, Iterable, List
import attr
import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
def create_issuing_service(
reactor: IReactorTCP,
acme_url: str,
account_key_file: str,
well_known_resource: IResource,
) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
acme_url (str): URL to use to request certificates
account_key_file (str): where to store the account key
well_known_resource (twisted.web.IResource): web resource for .well-known.
acme_url: URL to use to request certificates
account_key_file: where to store the account key
well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
def store(self, server_name, pem_objects):
def store(
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
def load_or_create_client_key(key_file):
def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
key_file (str): name of the file to use as the account key
key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename

View file

@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
@ -567,16 +568,6 @@ class AuthHandler(BaseHandler):
session.session_id, login_type, result
)
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it
# got a 401 with a 'flows' field.
# (https://github.com/vector-im/vector-web/issues/2447).
#
# Grandfather in the old behaviour for now to avoid
# breaking old riot deployments.
raise
# this step failed. Merge the error dict into the response
# so that the client can have another go.
errordict = e.error_dict()
@ -1387,7 +1378,9 @@ class AuthHandler(BaseHandler):
)
return self._sso_auth_confirm_template.render(
description=session.description, redirect_url=redirect_url,
description=session.description,
redirect_url=redirect_url,
idp=sso_auth_provider,
)
async def complete_sso_login(
@ -1396,6 +1389,7 @@ class AuthHandler(BaseHandler):
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
):
"""Having figured out a mxid for this user, complete the HTTP request
@ -1406,6 +1400,8 @@ class AuthHandler(BaseHandler):
process.
extra_attributes: Extra attributes which will be passed to the client
during successful login. Must be JSON serializable.
new_user: True if we should use wording appropriate to a user who has just
registered.
"""
# If the account has been deactivated, do not proceed with the login
# flow.
@ -1414,8 +1410,17 @@ class AuthHandler(BaseHandler):
respond_with_html(request, 403, self._sso_account_deactivated_template)
return
profile = await self.store.get_profileinfo(
UserID.from_string(registered_user_id).localpart
)
self._complete_sso_login(
registered_user_id, request, client_redirect_url, extra_attributes
registered_user_id,
request,
client_redirect_url,
extra_attributes,
new_user=new_user,
user_profile_data=profile,
)
def _complete_sso_login(
@ -1424,12 +1429,18 @@ class AuthHandler(BaseHandler):
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
new_user: bool = False,
user_profile_data: Optional[ProfileInfo] = None,
):
"""
The synchronous portion of complete_sso_login.
This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
"""
if user_profile_data is None:
user_profile_data = ProfileInfo(None, None)
# Store any extra attributes which will be passed in the login response.
# Note that this is per-user so it may overwrite a previous value, this
# is considered OK since the newest SSO attributes should be most valid.
@ -1467,6 +1478,9 @@ class AuthHandler(BaseHandler):
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
new_user=new_user,
user_id=registered_user_id,
user_profile=user_profile_data,
)
respond_with_html(request, 200, html)

View file

@ -80,9 +80,10 @@ class CasHandler:
# user-facing name of this auth provider
self.idp_name = "CAS"
# we do not currently support icons for CAS auth, but this is required by
# we do not currently support brands/icons for CAS auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
self._sso_handler = hs.get_sso_handler()
@ -99,11 +100,7 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
return "%s%s?%s" % (
self._cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode(args),
)
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]

View file

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api import errors
from synapse.api.constants import EventTypes
@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
"""
Retrieve the given user's devices
@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device
Args:
@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(
device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@ -946,8 +946,8 @@ class DeviceListUpdater:
async def process_cross_signing_key_update(
self,
user_id: str,
master_key: Optional[Dict[str, Any]],
self_signing_key: Optional[Dict[str, Any]],
master_key: Optional[JsonDict],
self_signing_key: Optional[JsonDict],
) -> List[str]:
"""Process the given new master and self-signing key for the given remote user.

View file

@ -16,7 +16,7 @@
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
JsonDict,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class E2eKeysHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
@ -78,7 +82,9 @@ class E2eKeysHandler:
)
@trace
async def query_devices(self, query_body, timeout, from_user_id):
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
) -> JsonDict:
""" Handle a device key query from a client
{
@ -98,12 +104,14 @@ class E2eKeysHandler:
}
Args:
from_user_id (str): the user making the query. This is used when
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
"""
device_keys_query = query_body.get("device_keys", {})
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Iterable[str]]
# separate users by domain.
# make a map from domain to user_id to device_ids
@ -121,7 +129,8 @@ class E2eKeysHandler:
set_tag("remote_key_query", remote_queries)
# First get local devices.
failures = {}
# A map of destination -> failure response.
failures = {} # type: Dict[str, JsonDict]
results = {}
if local_query:
local_result = await self.query_local_devices(local_query)
@ -135,9 +144,10 @@ class E2eKeysHandler:
)
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
if remote_queries:
query_list = []
query_list = [] # type: List[Tuple[str, Optional[str]]]
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
@ -284,15 +294,15 @@ class E2eKeysHandler:
return ret
async def get_cross_signing_keys_from_cache(
self, query, from_user_id
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
"""Get cross-signing keys for users from the database
Args:
query (Iterable[string]) an iterable of user IDs. A dict whose keys
query: an iterable of user IDs. A dict whose keys
are user IDs satisfies this, so the query format used for
query_devices can be used here.
from_user_id (str): the user making the query. This is used when
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
@ -315,14 +325,12 @@ class E2eKeysHandler:
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
if (
from_user_id in keys
and keys[from_user_id] is not None
and "user_signing" in keys[from_user_id]
):
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
if from_user_id:
from_user_key = keys.get(from_user_id)
if from_user_key and "user_signing" in from_user_key:
user_signing_keys[from_user_id] = from_user_key["user_signing"]
return {
"master_keys": master_keys,
@ -344,9 +352,9 @@ class E2eKeysHandler:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = []
local_query = [] # type: List[Tuple[str, Optional[str]]]
result_dict = {}
result_dict = {} # type: Dict[str, Dict[str, dict]]
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
@ -380,10 +388,14 @@ class E2eKeysHandler:
log_kv(results)
return result_dict
async def on_federation_query_client_keys(self, query_body):
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
@ -397,31 +409,34 @@ class E2eKeysHandler:
return ret
@trace
async def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query = [] # type: List[Tuple[str, str, str]]
remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]]
for user_id, device_keys in query.get("one_time_keys", {}).items():
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
for device_id, algorithm in one_time_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
# A map of user ID -> device ID -> key ID -> key.
json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
failures = {} # type: Dict[str, JsonDict]
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_bytes)
key_id: json_decoder.decode(json_str)
}
@trace
@ -468,7 +483,9 @@ class E2eKeysHandler:
return {"one_time_keys": json_result, "failures": failures}
@tag_args
async def upload_keys_for_user(self, user_id, device_id, keys):
async def upload_keys_for_user(
self, user_id: str, device_id: str, keys: JsonDict
) -> JsonDict:
time_now = self.clock.time_msec()
@ -543,8 +560,8 @@ class E2eKeysHandler:
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
) -> None:
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
one_time_keys.keys(),
@ -585,12 +602,14 @@ class E2eKeysHandler:
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
async def upload_signing_keys_for_user(self, user_id, keys):
async def upload_signing_keys_for_user(
self, user_id: str, keys: JsonDict
) -> JsonDict:
"""Upload signing keys for cross-signing
Args:
user_id (string): the user uploading the keys
keys (dict[string, dict]): the signing keys
user_id: the user uploading the keys
keys: the signing keys
"""
# if a master key is uploaded, then check it. Otherwise, load the
@ -667,16 +686,17 @@ class E2eKeysHandler:
return {}
async def upload_signatures_for_device_keys(self, user_id, signatures):
async def upload_signatures_for_device_keys(
self, user_id: str, signatures: JsonDict
) -> JsonDict:
"""Upload device signatures for cross-signing
Args:
user_id (string): the user uploading the signatures
signatures (dict[string, dict[string, dict]]): map of users to
devices to signed keys. This is the submission from the user; an
exception will be raised if it is malformed.
user_id: the user uploading the signatures
signatures: map of users to devices to signed keys. This is the submission
from the user; an exception will be raised if it is malformed.
Returns:
dict: response to be sent back to the client. The response will have
The response to be sent back to the client. The response will have
a "failures" key, which will be a dict mapping users to devices
to errors for the signatures that failed.
Raises:
@ -719,7 +739,9 @@ class E2eKeysHandler:
return {"failures": failures}
async def _process_self_signatures(self, user_id, signatures):
async def _process_self_signatures(
self, user_id: str, signatures: JsonDict
) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@ -731,15 +753,14 @@ class E2eKeysHandler:
signatures (dict[string, dict]): map of devices to signed keys
Returns:
(list[SignatureListItem], dict[string, dict[string, dict]]):
a list of signatures to store, and a map of users to devices to failure
reasons
A tuple of a list of signatures to store, and a map of users to
devices to failure reasons
Raises:
SynapseError: if the input is malformed
"""
signature_list = []
failures = {}
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@ -834,19 +855,24 @@ class E2eKeysHandler:
return signature_list, failures
def _check_master_key_signature(
self, user_id, master_key_id, signed_master_key, stored_master_key, devices
):
self,
user_id: str,
master_key_id: str,
signed_master_key: JsonDict,
stored_master_key: JsonDict,
devices: Dict[str, Dict[str, JsonDict]],
) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
Args:
user_id (string): the user whose master key is being checked
master_key_id (string): the ID of the user's master key
signed_master_key (dict): the user's signed master key that was uploaded
stored_master_key (dict): our previously-stored copy of the user's master key
devices (iterable(dict)): the user's devices
user_id: the user whose master key is being checked
master_key_id: the ID of the user's master key
signed_master_key: the user's signed master key that was uploaded
stored_master_key: our previously-stored copy of the user's master key
devices: the user's devices
Returns:
list[SignatureListItem]: a list of signatures to store
A list of signatures to store
Raises:
SynapseError: if a signature is invalid
@ -877,25 +903,26 @@ class E2eKeysHandler:
return master_key_signature_list
async def _process_other_signatures(self, user_id, signatures):
async def _process_other_signatures(
self, user_id: str, signatures: Dict[str, dict]
) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
Args:
user_id (string): the user uploading the keys
signatures (dict[string, dict]): map of users to devices to signed keys
user_id: the user uploading the keys
signatures: map of users to devices to signed keys
Returns:
(list[SignatureListItem], dict[string, dict[string, dict]]):
a list of signatures to store, and a map of users to devices to failure
A list of signatures to store, and a map of users to devices to failure
reasons
Raises:
SynapseError: if the input is malformed
"""
signature_list = []
failures = {}
signature_list = [] # type: List[SignatureListItem]
failures = {} # type: Dict[str, Dict[str, JsonDict]]
if not signatures:
return signature_list, failures
@ -983,7 +1010,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
):
) -> Tuple[JsonDict, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@ -997,8 +1024,7 @@ class E2eKeysHandler:
This affects what signatures are fetched.
Returns:
dict, str, VerifyKey: the raw key data, the key ID, and the
signedjson verify key
The raw key data, the key ID, and the signedjson verify key
Raises:
NotFoundError: if the key is not found
@ -1135,16 +1161,18 @@ class E2eKeysHandler:
return desired_key, desired_key_id, desired_verify_key
def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
def _check_cross_signing_key(
key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
) -> None:
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
checking, and ensures that it is signed, if a signature is required.
Args:
key (dict): the key data to verify
user_id (str): the user whose key is being checked
key_type (str): the type of key that the key should be
signing_key (VerifyKey): (optional) the signing key that the key should
be signed with. If omitted, signatures will not be checked.
key: the key data to verify
user_id: the user whose key is being checked
key_type: the type of key that the key should be
signing_key: the signing key that the key should be signed with. If
omitted, signatures will not be checked.
"""
if (
key.get("user_id") != user_id
@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
)
def _check_device_signature(user_id, verify_key, signed_device, stored_device):
def _check_device_signature(
user_id: str,
verify_key: VerifyKey,
signed_device: JsonDict,
stored_device: JsonDict,
) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
exception if an error is detected.
Args:
user_id (str): the user ID whose signature is being checked
verify_key (VerifyKey): the key to verify the device with
signed_device (dict): the uploaded signed device data
stored_device (dict): our previously stored copy of the device
user_id: the user ID whose signature is being checked
verify_key: the key to verify the device with
signed_device: the uploaded signed device data
stored_device: our previously stored copy of the device
Raises:
SynapseError: if the signature was invalid or the sent device is not the
@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
def _exception_to_failure(e):
def _exception_to_failure(e: Exception) -> JsonDict:
if isinstance(e, SynapseError):
return {"status": e.code, "errcode": e.errcode, "message": str(e)}
@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key):
def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
@ -1239,16 +1272,16 @@ class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
signing_key_id = attr.ib()
target_user_id = attr.ib()
target_device_id = attr.ib()
signature = attr.ib()
signing_key_id = attr.ib(type=str)
target_user_id = attr.ib(type=str)
target_device_id = attr.ib(type=str)
signature = attr.ib(type=JsonDict)
class SigningKeyEduUpdater:
"""Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler):
def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
iterable=True,
)
async def incoming_signing_key_update(self, origin, edu_content):
async def incoming_signing_key_update(
self, origin: str, edu_content: JsonDict
) -> None:
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
origin (string): the server that sent the EDU
edu_content (dict): the contents of the EDU
origin: the server that sent the EDU
edu_content: the contents of the EDU
"""
user_id = edu_content.pop("user_id")
@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
await self._handle_signing_key_updates(user_id)
async def _handle_signing_key_updates(self, user_id):
async def _handle_signing_key_updates(self, user_id: str) -> None:
"""Actually handle pending updates.
Args:
user_id (string): the user whose updates we are processing
user_id: the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
# This can happen since we batch updates
return
device_ids = []
device_ids = [] # type: List[str]
logger.info("pending updates: %r", pending_updates)

View file

@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional
from synapse.api.errors import (
Codes,
@ -24,8 +25,12 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
The actual payload of the encrypted keys is completely opaque to the handler.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
# Used to lock whenever a client is uploading key data. This prevents collisions
@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
async def get_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> List[JsonDict]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
Args:
user_id(str): the user whose keys we're getting
version(str): the version ID of the backup we're getting keys from
room_id(string): room ID to get keys for, for None to get keys for all rooms
session_id(string): session ID to get keys for, for None to get keys for all
user_id: the user whose keys we're getting
version: the version ID of the backup we're getting keys from
room_id: room ID to get keys for, for None to get keys for all rooms
session_id: session ID to get keys for, for None to get keys for all
sessions
Raises:
NotFoundError: if the backup version does not exist
Returns:
A deferred list of dicts giving the session_data and message metadata for
A list of dicts giving the session_data and message metadata for
these room keys.
"""
@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
return results
@trace
async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
async def delete_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> JsonDict:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
Args:
user_id(str): the user whose backup we're deleting
version(str): the version ID of the backup we're deleting
room_id(string): room ID to delete keys for, for None to delete keys for all
user_id: the user whose backup we're deleting
version: the version ID of the backup we're deleting
room_id: room ID to delete keys for, for None to delete keys for all
rooms
session_id(string): session ID to delete keys for, for None to delete keys
session_id: session ID to delete keys for, for None to delete keys
for all sessions
Raises:
NotFoundError: if the backup version does not exist
@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@trace
async def upload_room_keys(self, user_id, version, room_keys):
async def upload_room_keys(
self, user_id: str, version: str, room_keys: JsonDict
) -> JsonDict:
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
Args:
user_id(str): the user whose backup we're setting
version(str): the version ID of the backup we're updating
room_keys(dict): a nested dict describing the room_keys we're setting:
user_id: the user whose backup we're setting
version: the version ID of the backup we're updating
room_keys: a nested dict describing the room_keys we're setting:
{
"rooms": {
@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
return {"etag": str(version_etag), "count": count}
@staticmethod
def _should_replace_room_key(current_room_key, room_key):
def _should_replace_room_key(
current_room_key: Optional[JsonDict], room_key: JsonDict
) -> bool:
"""
Determine whether to replace a given current_room_key (if any)
with a newly uploaded room_key backup
Args:
current_room_key (dict): Optional, the current room_key dict if any
room_key (dict): The new room_key dict which may or may not be fit to
current_room_key: Optional, the current room_key dict if any
room_key : The new room_key dict which may or may not be fit to
replace the current_room_key
Returns:
@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
return True
@trace
async def create_version(self, user_id, version_info):
async def create_version(self, user_id: str, version_info: JsonDict) -> str:
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable to.
Args:
user_id(str): the user whose backup version we're creating
version_info(dict): metadata about the new version being created
user_id: the user whose backup version we're creating
version_info: metadata about the new version being created
{
"algorithm": "m.megolm_backup.v1",
@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
}
Returns:
A deferred of a string that gives the new version number.
The new version number.
"""
# TODO: Validate the JSON to make sure it has the right keys.
@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
)
return new_version
async def get_version_info(self, user_id, version=None):
async def get_version_info(
self, user_id: str, version: Optional[str] = None
) -> JsonDict:
"""Get the info about a given version of the user's backup
Args:
user_id(str): the user whose current backup version we're querying
version(str): Optional; if None gives the most recent version
user_id: the user whose current backup version we're querying
version: Optional; if None gives the most recent version
otherwise a historical one.
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
A deferred of a info dict that gives the info about the new version.
A info dict that gives the info about the new version.
{
"version": "1234",
@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
return res
@trace
async def delete_version(self, user_id, version=None):
async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
raise
@trace
async def update_version(self, user_id, version, version_info):
async def update_version(
self, user_id: str, version: str, version_info: JsonDict
) -> JsonDict:
"""Update the info about a given version of the user's backup
Args:
user_id(str): the user whose current backup version we're updating
version(str): the backup version we're updating
version_info(dict): the new information about the backup
user_id: the user whose current backup version we're updating
version: the backup version we're updating
version_info: the new information about the backup
Raises:
NotFoundError: if the requested backup version doesn't exist
Returns:
A deferred of an empty dict.
An empty dict.
"""
if "version" not in version_info:
version_info["version"] = version

View file

@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler):
if event.state_key == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
member_handler.ratelimit_invite(event.room_id, event.state_key)
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _check_for_soft_fail(

View file

@ -15,9 +15,13 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, get_domain_from_id
from synapse.types import GroupID, JsonDict, get_domain_from_id
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
async def get_group_summary(self, group_id, requester_user_id):
async def get_group_summary(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
return res
async def get_users_in_group(self, group_id, requester_user_id):
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
res = await self.groups_server_handler.get_users_in_group(
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res
group_server_name = get_domain_from_id(group_id)
@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
return res
async def get_joined_groups(self, user_id):
async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
async def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
failed_results = []
failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
async def create_group(self, group_id, user_id, content):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Create a group
"""
@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
raise e.to_synapse_error()
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def join_group(self, group_id, user_id, content):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
async def accept_invite(self, group_id, user_id, content):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
async def invite(self, group_id, user_id, requester_user_id, config):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def on_invite(self, group_id, user_id, content):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def user_removed_from_group(self, group_id, user_id, content):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group

View file

@ -27,9 +27,11 @@ from synapse.api.errors import (
HttpResponseException,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
self._web_client_location = hs.config.invite_client_location
# Ratelimiters for `/requestToken` endpoints.
self._3pid_validation_ratelimiter_ip = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
self._3pid_validation_ratelimiter_address = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
)
def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.
Args:
request: The associated request
medium: The type of threepid, e.g. "msisdn" or "email"
address: The actual threepid ID, e.g. the phone number or email address
"""
self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
async def threepid_from_creds(
self, id_server: str, creds: Dict[str, str]
) -> Optional[JsonDict]:

View file

@ -174,7 +174,7 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False
self.storage, user_id, last_events, filter_send_to_client=False,
)
event = last_events[0]
@ -434,6 +434,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._external_cache = hs.get_external_cache()
async def create_event(
self,
requester: Requester,
@ -941,6 +943,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
await self.cache_joined_hosts_for_event(event)
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@ -980,6 +984,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
"""Precalculate the joined hosts at the event, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
if not self._external_cache.is_enabled():
return
# We actually store two mappings, event ID -> prev state group,
# state group -> joined hosts, which is much more space efficient
# than event ID -> joined hosts.
#
# Note: We have to cache event ID -> prev state group, as we don't
# store that in the DB.
#
# Note: We always set the state group -> joined hosts cache, even if
# we already set it, so that the expiry time is reset.
state_entry = await self.state.resolve_state_groups_for_events(
event.room_id, event_ids=event.prev_event_ids()
)
if state_entry.state_group:
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
await self._external_cache.set(
"event_to_prev_state_group",
event.event_id,
state_entry.state_group,
expiry_ms=60 * 60 * 1000,
)
await self._external_cache.set(
"get_joined_hosts",
str(state_entry.state_group),
list(joined_hosts),
expiry_ms=60 * 60 * 1000,
)
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:

View file

@ -102,7 +102,7 @@ class OidcHandler:
) from e
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
"""Handle an incoming request to /_synapse/client/oidc/callback
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
@ -274,6 +274,9 @@ class OidcProvider:
# MXC URI for icon for this auth provider
self.idp_icon = provider.idp_icon
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
@ -640,7 +643,7 @@ class OidcProvider:
- ``client_id``: the client ID set in ``oidc_config.client_id``
- ``response_type``: ``code``
- ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
- ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
- ``scope``: the list of scopes set in ``oidc_config.scopes``
- ``state``: a random string
- ``nonce``: a random string
@ -681,7 +684,7 @@ class OidcProvider:
request.addCookie(
SESSION_COOKIE_NAME,
cookie,
path="/_synapse/oidc",
path="/_synapse/client/oidc",
max_age="3600",
httpOnly=True,
sameSite="lax",
@ -702,7 +705,7 @@ class OidcProvider:
async def handle_oidc_callback(
self, request: SynapseRequest, session_data: "OidcSessionData", code: str
) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
"""Handle an incoming request to /_synapse/client/oidc/callback
By this time we have already validated the session on the synapse side, and
now need to do the provider-specific operations. This includes:
@ -1056,7 +1059,8 @@ class OidcSessionData:
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
"UserAttributeDict",
{"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
)
C = TypeVar("C")
@ -1135,11 +1139,12 @@ def jinja_finalize(thing):
env = Environment(finalize=jinja_finalize)
@attr.s
@attr.s(slots=True, frozen=True)
class JinjaOidcMappingConfig:
subject_claim = attr.ib(type=str)
localpart_template = attr.ib(type=Optional[Template])
display_name_template = attr.ib(type=Optional[Template])
email_template = attr.ib(type=Optional[Template])
extra_attributes = attr.ib(type=Dict[str, Template])
@ -1156,23 +1161,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
def parse_config(config: dict) -> JinjaOidcMappingConfig:
subject_claim = config.get("subject_claim", "sub")
localpart_template = None # type: Optional[Template]
if "localpart_template" in config:
def parse_template_config(option_name: str) -> Optional[Template]:
if option_name not in config:
return None
try:
localpart_template = env.from_string(config["localpart_template"])
return env.from_string(config[option_name])
except Exception as e:
raise ConfigError(
"invalid jinja template", path=["localpart_template"]
) from e
raise ConfigError("invalid jinja template", path=[option_name]) from e
display_name_template = None # type: Optional[Template]
if "display_name_template" in config:
try:
display_name_template = env.from_string(config["display_name_template"])
except Exception as e:
raise ConfigError(
"invalid jinja template", path=["display_name_template"]
) from e
localpart_template = parse_template_config("localpart_template")
display_name_template = parse_template_config("display_name_template")
email_template = parse_template_config("email_template")
extra_attributes = {} # type Dict[str, Template]
if "extra_attributes" in config:
@ -1192,6 +1191,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
subject_claim=subject_claim,
localpart_template=localpart_template,
display_name_template=display_name_template,
email_template=email_template,
extra_attributes=extra_attributes,
)
@ -1213,16 +1213,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
# a usable mxid.
localpart += str(failures) if failures else ""
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
user=userinfo
).strip()
def render_template_field(template: Optional[Template]) -> Optional[str]:
if template is None:
return None
return template.render(user=userinfo).strip()
if display_name == "":
display_name = None
display_name = render_template_field(self._config.display_name_template)
if display_name == "":
display_name = None
return UserAttributeDict(localpart=localpart, display_name=display_name)
emails = [] # type: List[str]
email = render_template_field(self._config.email_template)
if email:
emails.append(email)
return UserAttributeDict(
localpart=localpart, display_name=display_name, emails=emails
)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]

View file

@ -14,8 +14,9 @@
# limitations under the License.
"""Contains functions for registering clients."""
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
@ -157,7 +158,7 @@ class RegistrationHandler(BaseHandler):
user_type: Optional[str] = None,
default_display_name: Optional[str] = None,
address: Optional[str] = None,
bind_emails: List[str] = [],
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
) -> str:
@ -700,6 +701,8 @@ class RegistrationHandler(BaseHandler):
access_token: The access token of the newly logged in device, or
None if `inhibit_login` enabled.
"""
# TODO: 3pid registration can actually happen on the workers. Consider
# refactoring it.
if self.hs.config.worker_app:
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token

View file

@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
self._invite_burst_count = (
hs.config.ratelimiting.rc_invites_per_room.burst_count
)
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
) -> str:
@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
invite_3pid_list = []
invite_list = []
if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
raise SynapseError(400, "Cannot invite so many users at once")
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")

View file

@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
self._invites_per_room_limiter = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
)
self._invites_per_user_limiter = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
)
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
# it doesn't store state.
@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
def ratelimit_invite(self, room_id: str, invitee_user_id: str):
"""Ratelimit invites by room and by target user.
"""
self._invites_per_room_limiter.ratelimit(room_id)
self._invites_per_user_limiter.ratelimit(invitee_user_id)
async def _local_membership_update(
self,
requester: Requester,
@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
if effective_membership_state == Membership.INVITE:
target_id = target.to_string()
if ratelimit:
self.ratelimit_invite(room_id, target_id)
# block any attempts to invite the server notices mxid
if target.to_string() == self._server_notices_mxid:
if target_id == self._server_notices_mxid:
raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
block_invite = False
@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
block_invite = True
if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
requester.user.to_string(), target_id, room_id
):
logger.info("Blocking invite due to spam checker")
block_invite = True

View file

@ -78,9 +78,10 @@ class SamlHandler(BaseHandler):
# user-facing name of this auth provider
self.idp_name = "SAML"
# we do not currently support icons for SAML auth, but this is required by
# we do not currently support icons/brands for SAML auth, but this is required by
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
@ -132,7 +133,7 @@ class SamlHandler(BaseHandler):
raise Exception("prepare_for_authenticate didn't return a Location header")
async def handle_saml_response(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_matrix/saml2/authn_response
"""Handle an incoming request to /_synapse/client/saml2/authn_response
Args:
request: the incoming request from the browser. We'll

View file

@ -15,23 +15,28 @@
import itertools
import logging
from typing import Iterable
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
async def search(self, user, content, batch=None):
async def search(
self, user: UserID, content: JsonDict, batch: Optional[str] = None
) -> JsonDict:
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
batch (str): The next_batch parameter. Used for pagination.
user
content: Search parameters
batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids = []
historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable
sender_group = {} # Holds result of grouping by sender, if applicable
# Holds result of grouping by room, if applicable
room_groups = {} # type: Dict[str, JsonDict]
# Holds result of grouping by sender, if applicable
sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
room_events = []
room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
state, time_now
state_events, time_now
)
rooms_cat_res["state"] = s

View file

@ -13,24 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
):
) -> None:
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)

View file

@ -14,21 +14,31 @@
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Set,
)
from urllib.parse import urlencode
import attr
from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from twisted.web.iweb import IRequest
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.server import respond_with_html, respond_with_redirect
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
from synapse.util.async_helpers import Linearizer
from synapse.util.stringutils import random_string
@ -80,6 +90,11 @@ class SsoIdentityProvider(Protocol):
"""Optional MXC URI for user-facing icon"""
return None
@property
def idp_brand(self) -> Optional[str]:
"""Optional branding identifier"""
return None
@abc.abstractmethod
async def handle_redirect_request(
self,
@ -109,7 +124,7 @@ class UserAttributes:
# enter one.
localpart = attr.ib(type=Optional[str])
display_name = attr.ib(type=Optional[str], default=None)
emails = attr.ib(type=List[str], default=attr.Factory(list))
emails = attr.ib(type=Collection[str], default=attr.Factory(list))
@attr.s(slots=True)
@ -124,7 +139,7 @@ class UsernameMappingSession:
# attributes returned by the ID mapper
display_name = attr.ib(type=Optional[str])
emails = attr.ib(type=List[str])
emails = attr.ib(type=Collection[str])
# An optional dictionary of extra attributes to be provided to the client in the
# login response.
@ -136,6 +151,12 @@ class UsernameMappingSession:
# expiry time for the session, in milliseconds
expiry_time_ms = attr.ib(type=int)
# choices made by the user
chosen_localpart = attr.ib(type=Optional[str], default=None)
use_display_name = attr.ib(type=bool, default=True)
emails_to_use = attr.ib(type=Collection[str], default=())
terms_accepted_version = attr.ib(type=Optional[str], default=None)
# the HTTP cookie used to track the mapping session id
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
@ -170,6 +191,8 @@ class SsoHandler:
# map from idp_id to SsoIdentityProvider
self._identity_providers = {} # type: Dict[str, SsoIdentityProvider]
self._consent_at_registration = hs.config.consent.user_consent_at_registration
def register_identity_provider(self, p: SsoIdentityProvider):
p_id = p.idp_id
assert p_id not in self._identity_providers
@ -235,7 +258,10 @@ class SsoHandler:
respond_with_html(request, code, html)
async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes,
self,
request: SynapseRequest,
client_redirect_url: bytes,
idp_id: Optional[str],
) -> str:
"""Handle a request to /login/sso/redirect
@ -243,6 +269,7 @@ class SsoHandler:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
idp_id: optional identity provider chosen by the client
Returns:
the URI to redirect to
@ -252,10 +279,19 @@ class SsoHandler:
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)
# if the client chose an IdP, use that
idp = None # type: Optional[SsoIdentityProvider]
if idp_id:
idp = self._identity_providers.get(idp_id)
if not idp:
raise NotFoundError("Unknown identity provider")
# if we only have one auth provider, redirect to it directly
if len(self._identity_providers) == 1:
ap = next(iter(self._identity_providers.values()))
return await ap.handle_redirect_request(request, client_redirect_url)
elif len(self._identity_providers) == 1:
idp = next(iter(self._identity_providers.values()))
if idp:
return await idp.handle_redirect_request(request, client_redirect_url)
# otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode(
@ -369,6 +405,8 @@ class SsoHandler:
to an additional page. (e.g. to prompt for more information)
"""
new_user = False
# grab a lock while we try to find a mapping for this user. This seems...
# optimistic, especially for implementations that end up redirecting to
# interstitial pages.
@ -409,9 +447,14 @@ class SsoHandler:
get_request_user_agent(request),
request.getClientIP(),
)
new_user = True
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url, extra_login_attributes
user_id,
request,
client_redirect_url,
extra_login_attributes,
new_user=new_user,
)
async def _call_attribute_mapper(
@ -501,7 +544,7 @@ class SsoHandler:
logger.info("Recorded registration session id %s", session_id)
# Set the cookie and redirect to the username picker
e = RedirectException(b"/_synapse/client/pick_username")
e = RedirectException(b"/_synapse/client/pick_username/account_details")
e.cookies.append(
b"%s=%s; path=/"
% (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
@ -629,6 +672,25 @@ class SsoHandler:
)
respond_with_html(request, 200, html)
def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
"""Look up the given username mapping session
If it is not found, raises a SynapseError with an http code of 400
Args:
session_id: session to look up
Returns:
active mapping session
Raises:
SynapseError if the session is not found/has expired
"""
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if session:
return session
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
async def check_username_availability(
self, localpart: str, session_id: str,
) -> bool:
@ -645,12 +707,7 @@ class SsoHandler:
# make sure that there is a valid mapping session, to stop people dictionary-
# scanning for accounts
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
self.get_mapping_session(session_id)
logger.info(
"[session %s] Checking for availability of username %s",
@ -667,7 +724,12 @@ class SsoHandler:
return not user_infos
async def handle_submit_username_request(
self, request: SynapseRequest, localpart: str, session_id: str
self,
request: SynapseRequest,
session_id: str,
localpart: str,
use_display_name: bool,
emails_to_use: Iterable[str],
) -> None:
"""Handle a request to the username-picker 'submit' endpoint
@ -677,21 +739,90 @@ class SsoHandler:
request: HTTP request
localpart: localpart requested by the user
session_id: ID of the username mapping session, extracted from a cookie
use_display_name: whether the user wants to use the suggested display name
emails_to_use: emails that the user would like to use
"""
self._expire_old_sessions()
session = self._username_mapping_sessions.get(session_id)
if not session:
logger.info("Couldn't find session id %s", session_id)
raise SynapseError(400, "unknown session")
session = self.get_mapping_session(session_id)
logger.info("[session %s] Registering localpart %s", session_id, localpart)
# update the session with the user's choices
session.chosen_localpart = localpart
session.use_display_name = use_display_name
emails_from_idp = set(session.emails)
filtered_emails = set() # type: Set[str]
# we iterate through the list rather than just building a set conjunction, so
# that we can log attempts to use unknown addresses
for email in emails_to_use:
if email in emails_from_idp:
filtered_emails.add(email)
else:
logger.warning(
"[session %s] ignoring user request to use unknown email address %r",
session_id,
email,
)
session.emails_to_use = filtered_emails
# we may now need to collect consent from the user, in which case, redirect
# to the consent-extraction-unit
if self._consent_at_registration:
redirect_url = b"/_synapse/client/new_user_consent"
# otherwise, redirect to the completion page
else:
redirect_url = b"/_synapse/client/sso_register"
respond_with_redirect(request, redirect_url)
async def handle_terms_accepted(
self, request: Request, session_id: str, terms_version: str
):
"""Handle a request to the new-user 'consent' endpoint
Will serve an HTTP response to the request.
Args:
request: HTTP request
session_id: ID of the username mapping session, extracted from a cookie
terms_version: the version of the terms which the user viewed and consented
to
"""
logger.info(
"[session %s] User consented to terms version %s",
session_id,
terms_version,
)
session = self.get_mapping_session(session_id)
session.terms_accepted_version = terms_version
# we're done; now we can register the user
respond_with_redirect(request, b"/_synapse/client/sso_register")
async def register_sso_user(self, request: Request, session_id: str) -> None:
"""Called once we have all the info we need to register a new user.
Does so and serves an HTTP response
Args:
request: HTTP request
session_id: ID of the username mapping session, extracted from a cookie
"""
session = self.get_mapping_session(session_id)
logger.info(
"[session %s] Registering localpart %s",
session_id,
session.chosen_localpart,
)
attributes = UserAttributes(
localpart=localpart,
display_name=session.display_name,
emails=session.emails,
localpart=session.chosen_localpart, emails=session.emails_to_use,
)
if session.use_display_name:
attributes.display_name = session.display_name
# the following will raise a 400 error if the username has been taken in the
# meantime.
user_id = await self._register_mapped_user(
@ -702,7 +833,12 @@ class SsoHandler:
request.getClientIP(),
)
logger.info("[session %s] Registered userid %s", session_id, user_id)
logger.info(
"[session %s] Registered userid %s with attributes %s",
session_id,
user_id,
attributes,
)
# delete the mapping session and the cookie
del self._username_mapping_sessions[session_id]
@ -715,11 +851,21 @@ class SsoHandler:
path=b"/",
)
auth_result = {}
if session.terms_accepted_version:
# TODO: make this less awful.
auth_result[LoginType.TERMS] = True
await self._registration_handler.post_registration_actions(
user_id, auth_result, access_token=None
)
await self._auth_handler.complete_sso_login(
user_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
new_user=True,
)
def _expire_old_sessions(self):
@ -733,3 +879,14 @@ class SsoHandler:
for session_id in to_expire:
logger.info("Expiring mapping session %s", session_id)
del self._username_mapping_sessions[session_id]
def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
"""Extract the session ID from the cookie
Raises a SynapseError if the cookie isn't found
"""
session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
if not session_id:
raise SynapseError(code=400, msg="missing session_id")
return session_id.decode("ascii", errors="replace")

View file

@ -14,15 +14,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
async def _get_key_change(
self,
prev_event_id: Optional[str],
event_id: Optional[str],
key_name: str,
public_value: str,
) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.

View file

@ -12,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -31,7 +37,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@ -44,7 +50,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
self.pos = None
self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@ -56,7 +62,7 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self):
def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.stats_enabled or self._is_processing:
@ -72,7 +78,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
async def _unsafe_process(self):
async def _unsafe_process(self) -> None:
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@ -110,10 +116,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
room_deltas.setdefault(room_id, {}).update(fields)
room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
user_deltas.setdefault(user_id, {}).update(fields)
user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@ -131,19 +137,20 @@ class StatsHandler:
self.pos = max_pos
async def _handle_deltas(self, deltas):
async def _handle_deltas(
self, deltas: Iterable[JsonDict]
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
room_to_stats_deltas = {}
user_to_stats_deltas = {}
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
room_to_state_updates = {}
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@ -173,7 +180,7 @@ class StatsHandler:
)
continue
event_content = {}
event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@ -257,13 +264,13 @@ class StatsHandler:
)
if has_changed_joinedness:
delta = +1 if membership == Membership.JOIN else -1
membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
] += delta
] += membership_delta
room_stats_delta["local_users_in_room"] += delta
room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (

View file

@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@ -65,17 +65,17 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
self._room_serials = {}
self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
self._room_typing = {}
self._room_typing = {} # type: Dict[str, Set[str]]
self._member_last_federation_poke = {}
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
def _handle_timeouts(self):
def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
def _handle_timeout_for_member(self, now: int, member: RoomMember):
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
async def _push_remote(self, member, typing):
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@ -148,7 +148,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
@ -178,7 +178,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
):
) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@ -194,12 +194,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
def get_current_token(self):
def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
# clock time we expect to stop
self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
def _handle_timeout_for_member(self, now: int, member: RoomMember):
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
async def started_typing(self, target_user, requester, room_id, timeout):
async def started_typing(
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
return None
return
self._push_update(member=member, typing=True)
async def stopped_typing(self, target_user, requester, room_id):
async def stopped_typing(
self, target_user: UserID, requester: Requester, room_id: str
) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
def user_left_room(self, user, room_id):
def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
def _stopped_typing(self, member):
def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
return None
return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
def _push_update(self, member, typing):
def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
async def _recv_edu(self, origin, content):
async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
def _push_update_local(self, member, typing):
def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
)
) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@ -427,7 +432,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
def _make_event_for(self, room_id):
def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@ -462,7 +467,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs):
async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@ -478,5 +485,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
def get_current_key(self):
def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial

View file

@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
# If still None then the initial background update hasn't happened yet
if self.pos is None:
return None
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
# It isn't expected for this event to not exist, but we
# don't want the entire background process to break.
if event is None:
continue
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),