Merge remote-tracking branch 'upstream/release-v1.27.0'
This commit is contained in: commit 7f7fb9b566
146 changed files with 4893 additions and 1484 deletions
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import twisted
 import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
 
 from synapse.app import check_bind_error
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
 
 
 class AcmeHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.reactor = hs.get_reactor()
         self._acme_domain = hs.config.acme_domain
 
-    async def start_listening(self):
+    async def start_listening(self) -> None:
         from synapse.handlers import acme_issuing_service
 
         # Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
            logger.error(ACME_REGISTER_FAIL_ERROR)
            raise
 
-    async def provision_certificate(self):
+    async def provision_certificate(self) -> None:
 
        logger.warning("Reprovisioning %s", self._acme_domain)
 
@@ -110,5 +114,3 @@ class AcmeHandler:
        except Exception:
            logger.exception("Failed saving!")
            raise
-
-        return True

@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
 imported conditionally.
 """
 import logging
+from typing import Dict, Iterable, List
 
 import attr
+import pem
 from cryptography.hazmat.backends import default_backend
 from cryptography.hazmat.primitives import serialization
 from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
 from zope.interface import implementer
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTCP
 from twisted.python.filepath import FilePath
 from twisted.python.url import URL
+from twisted.web.resource import IResource
 
 logger = logging.getLogger(__name__)
 
 
-def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+def create_issuing_service(
+    reactor: IReactorTCP,
+    acme_url: str,
+    account_key_file: str,
+    well_known_resource: IResource,
+) -> AcmeIssuingService:
     """Create an ACME issuing service, and attach it to a web Resource
 
     Args:
        reactor: twisted reactor
-        acme_url (str): URL to use to request certificates
-        account_key_file (str): where to store the account key
-        well_known_resource (twisted.web.IResource): web resource for .well-known.
+        acme_url: URL to use to request certificates
+        account_key_file: where to store the account key
+        well_known_resource: web resource for .well-known.
            we will attach a child resource for "acme-challenge".
 
     Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
     A store that only stores in memory.
     """
 
-    certs = attr.ib(default=attr.Factory(dict))
+    certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
 
-    def store(self, server_name, pem_objects):
+    def store(
+        self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
+    ) -> defer.Deferred:
        self.certs[server_name] = [o.as_bytes() for o in pem_objects]
        return defer.succeed(None)
 
 
-def load_or_create_client_key(key_file):
+def load_or_create_client_key(key_file: str) -> JWKRSA:
    """Load the ACME account key from a file, creating it if it does not exist.
 
    Args:
-        key_file (str): name of the file to use as the account key
+        key_file: name of the file to use as the account key
    """
    # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
    # hardcode the 'client.key' filename

@@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import ModuleApi
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import maybe_awaitable
@@ -567,16 +568,6 @@ class AuthHandler(BaseHandler):
                        session.session_id, login_type, result
                    )
            except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise
-
                # this step failed. Merge the error dict into the response
                # so that the client can have another go.
                errordict = e.error_dict()
@@ -1387,7 +1378,9 @@ class AuthHandler(BaseHandler):
        )
 
        return self._sso_auth_confirm_template.render(
-            description=session.description, redirect_url=redirect_url,
+            description=session.description,
+            redirect_url=redirect_url,
+            idp=sso_auth_provider,
        )
 
    async def complete_sso_login(
@@ -1396,6 +1389,7 @@ class AuthHandler(BaseHandler):
        request: Request,
        client_redirect_url: str,
        extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
    ):
        """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1406,6 +1400,8 @@ class AuthHandler(BaseHandler):
                process.
            extra_attributes: Extra attributes which will be passed to the client
                during successful login. Must be JSON serializable.
+            new_user: True if we should use wording appropriate to a user who has just
+                registered.
        """
        # If the account has been deactivated, do not proceed with the login
        # flow.
@@ -1414,8 +1410,17 @@ class AuthHandler(BaseHandler):
            respond_with_html(request, 403, self._sso_account_deactivated_template)
            return
 
+        profile = await self.store.get_profileinfo(
+            UserID.from_string(registered_user_id).localpart
+        )
+
        self._complete_sso_login(
-            registered_user_id, request, client_redirect_url, extra_attributes
+            registered_user_id,
+            request,
+            client_redirect_url,
+            extra_attributes,
+            new_user=new_user,
+            user_profile_data=profile,
        )
 
    def _complete_sso_login(
@@ -1424,12 +1429,18 @@ class AuthHandler(BaseHandler):
        request: Request,
        client_redirect_url: str,
        extra_attributes: Optional[JsonDict] = None,
+        new_user: bool = False,
+        user_profile_data: Optional[ProfileInfo] = None,
    ):
        """
        The synchronous portion of complete_sso_login.
 
        This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
        """
 
+        if user_profile_data is None:
+            user_profile_data = ProfileInfo(None, None)
+
        # Store any extra attributes which will be passed in the login response.
        # Note that this is per-user so it may overwrite a previous value, this
        # is considered OK since the newest SSO attributes should be most valid.
@@ -1467,6 +1478,9 @@ class AuthHandler(BaseHandler):
            display_url=redirect_url_no_params,
            redirect_url=redirect_url,
            server_name=self._server_name,
+            new_user=new_user,
+            user_id=registered_user_id,
+            user_profile=user_profile_data,
        )
        respond_with_html(request, 200, html)
 

@@ -80,9 +80,10 @@ class CasHandler:
        # user-facing name of this auth provider
        self.idp_name = "CAS"
 
-        # we do not currently support icons for CAS auth, but this is required by
+        # we do not currently support brands/icons for CAS auth, but this is required by
        # the SsoIdentityProvider protocol type.
        self.idp_icon = None
+        self.idp_brand = None
 
        self._sso_handler = hs.get_sso_handler()
 
@@ -99,11 +100,7 @@ class CasHandler:
        Returns:
            The URL to use as a "service" parameter.
        """
-        return "%s%s?%s" % (
-            self._cas_service_url,
-            "/_matrix/client/r0/login/cas/ticket",
-            urllib.parse.urlencode(args),
-        )
+        return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
 
    async def _validate_ticket(
        self, ticket: str, service_args: Dict[str, str]

@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler):
        self._auth_handler = hs.get_auth_handler()
 
    @trace
-    async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
+    async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
        """
        Retrieve the given user's devices
 
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
        return devices
 
    @trace
-    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
+    async def get_device(self, user_id: str, device_id: str) -> JsonDict:
        """ Retrieve the given device
 
        Args:
@@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
 ) -> None:
    ip = client_ips.get((device["user_id"], device["device_id"]), {})
    device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
@@ -946,8 +946,8 @@ class DeviceListUpdater:
    async def process_cross_signing_key_update(
        self,
        user_id: str,
-        master_key: Optional[Dict[str, Any]],
-        self_signing_key: Optional[Dict[str, Any]],
+        master_key: Optional[JsonDict],
+        self_signing_key: Optional[JsonDict],
    ) -> List[str]:
        """Process the given new master and self-signing key for the given remote user.
 

@@ -16,7 +16,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
+    JsonDict,
    UserID,
    get_domain_from_id,
    get_verify_key_from_cross_signing_key,
@@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class E2eKeysHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.device_handler = hs.get_device_handler()
@@ -78,7 +82,9 @@ class E2eKeysHandler:
        )
 
    @trace
-    async def query_devices(self, query_body, timeout, from_user_id):
+    async def query_devices(
+        self, query_body: JsonDict, timeout: int, from_user_id: str
+    ) -> JsonDict:
        """ Handle a device key query from a client
 
        {
@@ -98,12 +104,14 @@ class E2eKeysHandler:
        }
 
        Args:
-            from_user_id (str): the user making the query. This is used when
+            from_user_id: the user making the query. This is used when
                adding cross-signing signatures to limit what signatures users
                can see.
        """
 
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Iterable[str]]
 
        # separate users by domain.
        # make a map from domain to user_id to device_ids
@@ -121,7 +129,8 @@ class E2eKeysHandler:
        set_tag("remote_key_query", remote_queries)
 
        # First get local devices.
-        failures = {}
+        # A map of destination -> failure response.
+        failures = {}  # type: Dict[str, JsonDict]
        results = {}
        if local_query:
            local_result = await self.query_local_devices(local_query)
@@ -135,9 +144,10 @@ class E2eKeysHandler:
            )
 
        # Now attempt to get any remote devices from our local cache.
-        remote_queries_not_in_cache = {}
+        # A map of destination -> user ID -> device IDs.
+        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
        if remote_queries:
-            query_list = []
+            query_list = []  # type: List[Tuple[str, Optional[str]]]
            for user_id, device_ids in remote_queries.items():
                if device_ids:
                    query_list.extend((user_id, device_id) for device_id in device_ids)
@@ -284,15 +294,15 @@ class E2eKeysHandler:
        return ret
 
    async def get_cross_signing_keys_from_cache(
-        self, query, from_user_id
+        self, query: Iterable[str], from_user_id: Optional[str]
    ) -> Dict[str, Dict[str, dict]]:
        """Get cross-signing keys for users from the database
 
        Args:
-            query (Iterable[string]) an iterable of user IDs. A dict whose keys
+            query: an iterable of user IDs. A dict whose keys
                are user IDs satisfies this, so the query format used for
                query_devices can be used here.
-            from_user_id (str): the user making the query. This is used when
+            from_user_id: the user making the query. This is used when
                adding cross-signing signatures to limit what signatures users
                can see.
 
@@ -315,14 +325,12 @@ class E2eKeysHandler:
            if "self_signing" in user_info:
                self_signing_keys[user_id] = user_info["self_signing"]
 
-        if (
-            from_user_id in keys
-            and keys[from_user_id] is not None
-            and "user_signing" in keys[from_user_id]
-        ):
-            # users can see other users' master and self-signing keys, but can
-            # only see their own user-signing keys
-            user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+        # users can see other users' master and self-signing keys, but can
+        # only see their own user-signing keys
+        if from_user_id:
+            from_user_key = keys.get(from_user_id)
+            if from_user_key and "user_signing" in from_user_key:
+                user_signing_keys[from_user_id] = from_user_key["user_signing"]
 
        return {
            "master_keys": master_keys,
@@ -344,9 +352,9 @@ class E2eKeysHandler:
            A map from user_id -> device_id -> device details
        """
        set_tag("local_query", query)
-        local_query = []
+        local_query = []  # type: List[Tuple[str, Optional[str]]]
 
-        result_dict = {}
+        result_dict = {}  # type: Dict[str, Dict[str, dict]]
        for user_id, device_ids in query.items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
@@ -380,10 +388,14 @@ class E2eKeysHandler:
        log_kv(results)
        return result_dict
 
-    async def on_federation_query_client_keys(self, query_body):
+    async def on_federation_query_client_keys(
+        self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
+    ) -> JsonDict:
        """ Handle a device key query from a federated server
        """
-        device_keys_query = query_body.get("device_keys", {})
+        device_keys_query = query_body.get(
+            "device_keys", {}
+        )  # type: Dict[str, Optional[List[str]]]
        res = await self.query_local_devices(device_keys_query)
        ret = {"device_keys": res}
 
@@ -397,31 +409,34 @@ class E2eKeysHandler:
        return ret
 
    @trace
-    async def claim_one_time_keys(self, query, timeout):
-        local_query = []
-        remote_queries = {}
+    async def claim_one_time_keys(
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+    ) -> JsonDict:
+        local_query = []  # type: List[Tuple[str, str, str]]
+        remote_queries = {}  # type: Dict[str, Dict[str, Dict[str, str]]]
 
-        for user_id, device_keys in query.get("one_time_keys", {}).items():
+        for user_id, one_time_keys in query.get("one_time_keys", {}).items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
-                for device_id, algorithm in device_keys.items():
+                for device_id, algorithm in one_time_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_keys
+                remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
        set_tag("local_key_query", local_query)
        set_tag("remote_key_query", remote_queries)
 
        results = await self.store.claim_e2e_one_time_keys(local_query)
 
-        json_result = {}
-        failures = {}
+        # A map of user ID -> device ID -> key ID -> key.
+        json_result = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
+        failures = {}  # type: Dict[str, JsonDict]
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
-                for key_id, json_bytes in keys.items():
+                for key_id, json_str in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_bytes)
+                        key_id: json_decoder.decode(json_str)
                    }
 
    @trace
@@ -468,7 +483,9 @@ class E2eKeysHandler:
        return {"one_time_keys": json_result, "failures": failures}
 
    @tag_args
-    async def upload_keys_for_user(self, user_id, device_id, keys):
+    async def upload_keys_for_user(
+        self, user_id: str, device_id: str, keys: JsonDict
+    ) -> JsonDict:
 
        time_now = self.clock.time_msec()
 
@@ -543,8 +560,8 @@ class E2eKeysHandler:
        return {"one_time_key_counts": result}
 
    async def _upload_one_time_keys_for_user(
-        self, user_id, device_id, time_now, one_time_keys
-    ):
+        self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict
+    ) -> None:
        logger.info(
            "Adding one_time_keys %r for device %r for user %r at %d",
            one_time_keys.keys(),
@@ -585,12 +602,14 @@ class E2eKeysHandler:
        log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
        await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
 
-    async def upload_signing_keys_for_user(self, user_id, keys):
+    async def upload_signing_keys_for_user(
+        self, user_id: str, keys: JsonDict
+    ) -> JsonDict:
        """Upload signing keys for cross-signing
 
        Args:
-            user_id (string): the user uploading the keys
-            keys (dict[string, dict]): the signing keys
+            user_id: the user uploading the keys
+            keys: the signing keys
        """
 
        # if a master key is uploaded, then check it. Otherwise, load the
@@ -667,16 +686,17 @@ class E2eKeysHandler:
 
        return {}
 
-    async def upload_signatures_for_device_keys(self, user_id, signatures):
+    async def upload_signatures_for_device_keys(
+        self, user_id: str, signatures: JsonDict
+    ) -> JsonDict:
        """Upload device signatures for cross-signing
 
        Args:
-            user_id (string): the user uploading the signatures
-            signatures (dict[string, dict[string, dict]]): map of users to
-                devices to signed keys. This is the submission from the user; an
-                exception will be raised if it is malformed.
+            user_id: the user uploading the signatures
+            signatures: map of users to devices to signed keys. This is the submission
+                from the user; an exception will be raised if it is malformed.
        Returns:
-            dict: response to be sent back to the client. The response will have
+            The response to be sent back to the client. The response will have
                a "failures" key, which will be a dict mapping users to devices
                to errors for the signatures that failed.
        Raises:
@@ -719,7 +739,9 @@ class E2eKeysHandler:
 
        return {"failures": failures}
 
-    async def _process_self_signatures(self, user_id, signatures):
+    async def _process_self_signatures(
+        self, user_id: str, signatures: JsonDict
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
        """Process uploaded signatures of the user's own keys.
 
        Signatures of the user's own keys from this API come in two forms:
@@ -731,15 +753,14 @@ class E2eKeysHandler:
            signatures (dict[string, dict]): map of devices to signed keys
 
        Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-                a list of signatures to store, and a map of users to devices to failure
-                reasons
+            A tuple of a list of signatures to store, and a map of users to
+            devices to failure reasons
 
        Raises:
            SynapseError: if the input is malformed
        """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
        if not signatures:
            return signature_list, failures
 
@@ -834,19 +855,24 @@ class E2eKeysHandler:
        return signature_list, failures
 
    def _check_master_key_signature(
-        self, user_id, master_key_id, signed_master_key, stored_master_key, devices
-    ):
+        self,
+        user_id: str,
+        master_key_id: str,
+        signed_master_key: JsonDict,
+        stored_master_key: JsonDict,
+        devices: Dict[str, Dict[str, JsonDict]],
+    ) -> List["SignatureListItem"]:
        """Check signatures of a user's master key made by their devices.
 
        Args:
-            user_id (string): the user whose master key is being checked
-            master_key_id (string): the ID of the user's master key
-            signed_master_key (dict): the user's signed master key that was uploaded
-            stored_master_key (dict): our previously-stored copy of the user's master key
-            devices (iterable(dict)): the user's devices
+            user_id: the user whose master key is being checked
+            master_key_id: the ID of the user's master key
+            signed_master_key: the user's signed master key that was uploaded
+            stored_master_key: our previously-stored copy of the user's master key
+            devices: the user's devices
 
        Returns:
-            list[SignatureListItem]: a list of signatures to store
+            A list of signatures to store
 
        Raises:
            SynapseError: if a signature is invalid
@@ -877,25 +903,26 @@ class E2eKeysHandler:
 
        return master_key_signature_list
 
-    async def _process_other_signatures(self, user_id, signatures):
+    async def _process_other_signatures(
+        self, user_id: str, signatures: Dict[str, dict]
+    ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]:
        """Process uploaded signatures of other users' keys. These will be the
        target user's master keys, signed by the uploading user's user-signing
        key.
 
        Args:
-            user_id (string): the user uploading the keys
-            signatures (dict[string, dict]): map of users to devices to signed keys
+            user_id: the user uploading the keys
+            signatures: map of users to devices to signed keys
 
        Returns:
-            (list[SignatureListItem], dict[string, dict[string, dict]]):
-                a list of signatures to store, and a map of users to devices to failure
+            A list of signatures to store, and a map of users to devices to failure
                reasons
 
        Raises:
            SynapseError: if the input is malformed
        """
-        signature_list = []
-        failures = {}
+        signature_list = []  # type: List[SignatureListItem]
+        failures = {}  # type: Dict[str, Dict[str, JsonDict]]
        if not signatures:
            return signature_list, failures
 
@@ -983,7 +1010,7 @@ class E2eKeysHandler:
 
    async def _get_e2e_cross_signing_verify_key(
        self, user_id: str, key_type: str, from_user_id: str = None
-    ):
+    ) -> Tuple[JsonDict, str, VerifyKey]:
        """Fetch locally or remotely query for a cross-signing public key.
 
        First, attempt to fetch the cross-signing public key from storage.
@@ -997,8 +1024,7 @@ class E2eKeysHandler:
            This affects what signatures are fetched.
 
        Returns:
-            dict, str, VerifyKey: the raw key data, the key ID, and the
-                signedjson verify key
+            The raw key data, the key ID, and the signedjson verify key
 
        Raises:
            NotFoundError: if the key is not found
@@ -1135,16 +1161,18 @@ class E2eKeysHandler:
        return desired_key, desired_key_id, desired_verify_key
 
 
-def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+def _check_cross_signing_key(
+    key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
+) -> None:
    """Check a cross-signing key uploaded by a user. Performs some basic sanity
    checking, and ensures that it is signed, if a signature is required.
 
    Args:
-        key (dict): the key data to verify
-        user_id (str): the user whose key is being checked
-        key_type (str): the type of key that the key should be
-        signing_key (VerifyKey): (optional) the signing key that the key should
-            be signed with. If omitted, signatures will not be checked.
+        key: the key data to verify
+        user_id: the user whose key is being checked
+        key_type: the type of key that the key should be
+        signing_key: the signing key that the key should be signed with. If
+            omitted, signatures will not be checked.
    """
    if (
        key.get("user_id") != user_id
@@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
        )
 
 
-def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+def _check_device_signature(
+    user_id: str,
+    verify_key: VerifyKey,
+    signed_device: JsonDict,
+    stored_device: JsonDict,
+) -> None:
    """Check that a signature on a device or cross-signing key is correct and
    matches the copy of the device/key that we have stored. Throws an
    exception if an error is detected.
 
    Args:
-        user_id (str): the user ID whose signature is being checked
-        verify_key (VerifyKey): the key to verify the device with
-        signed_device (dict): the uploaded signed device data
-        stored_device (dict): our previously stored copy of the device
+        user_id: the user ID whose signature is being checked
+        verify_key: the key to verify the device with
+        signed_device: the uploaded signed device data
+        stored_device: our previously stored copy of the device
 
    Raises:
        SynapseError: if the signature was invalid or the sent device is not the
@@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device):
        raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
 
 
-def _exception_to_failure(e):
+def _exception_to_failure(e: Exception) -> JsonDict:
    if isinstance(e, SynapseError):
        return {"status": e.code, "errcode": e.errcode, "message": str(e)}
 
@@ -1218,7 +1251,7 @@ def _exception_to_failure(e):
    return {"status": 503, "message": str(e)}
 
 
-def _one_time_keys_match(old_key_json, new_key):
+def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
    old_key = json_decoder.decode(old_key_json)
 
    # if either is a string rather than an object, they must match exactly
@@ -1239,16 +1272,16 @@ class SignatureListItem:
    """An item in the signature list as used by upload_signatures_for_device_keys.
    """
 
-    signing_key_id = attr.ib()
-    target_user_id = attr.ib()
-    target_device_id = attr.ib()
-    signature = attr.ib()
+    signing_key_id = attr.ib(type=str)
+    target_user_id = attr.ib(type=str)
+    target_device_id = attr.ib(type=str)
+    signature = attr.ib(type=JsonDict)
 
 
 class SigningKeyEduUpdater:
    """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs, e2e_keys_handler):
+    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.clock = hs.get_clock()
@@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater:
        self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
        # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = {}  # type: Dict[str, List[Tuple[JsonDict, JsonDict]]]
 
        # Recently seen stream ids. We don't bother keeping these in the DB,
        # but they're useful to have them about to reduce the number of spurious
@@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater:
            iterable=True,
        )
 
-    async def incoming_signing_key_update(self, origin, edu_content):
+    async def incoming_signing_key_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
        """Called on incoming signing key update from federation. Responsible for
        parsing the EDU and adding to pending updates list.
 
        Args:
-            origin (string): the server that sent the EDU
-            edu_content (dict): the contents of the EDU
+            origin: the server that sent the EDU
+            edu_content: the contents of the EDU
        """
 
        user_id = edu_content.pop("user_id")
@@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater:
 
        await self._handle_signing_key_updates(user_id)
 
-    async def _handle_signing_key_updates(self, user_id):
+    async def _handle_signing_key_updates(self, user_id: str) -> None:
        """Actually handle pending updates.
 
        Args:
-            user_id (string): the user whose updates we are processing
+            user_id: the user whose updates we are processing
        """
 
        device_handler = self.e2e_keys_handler.device_handler
@@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater:
                # This can happen since we batch updates
                return
 
-            device_ids = []
+            device_ids = []  # type: List[str]
 
            logger.info("pending updates: %r", pending_updates)
 

@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, List, Optional
 
 from synapse.api.errors import (
    Codes,
@@ -24,8 +25,12 @@ from synapse.api.errors import (
    SynapseError,
 )
 from synapse.logging.opentracing import log_kv, trace
+from synapse.types import JsonDict
 from synapse.util.async_helpers import Linearizer
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -37,7 +42,7 @@ class E2eRoomKeysHandler:
    The actual payload of the encrypted keys is completely opaque to the handler.
    """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastore()
 
        # Used to lock whenever a client is uploading key data. This prevents collisions
@@ -48,21 +53,27 @@ class E2eRoomKeysHandler:
        self._upload_linearizer = Linearizer("upload_room_keys_lock")
 
    @trace
-    async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> List[JsonDict]:
        """Bulk get the E2E room keys for a given backup, optionally filtered to a given
        room, or a given session.
        See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
 
        Args:
-            user_id(str): the user whose keys we're getting
-            version(str): the version ID of the backup we're getting keys from
-            room_id(string): room ID to get keys for, for None to get keys for all rooms
-            session_id(string): session ID to get keys for, for None to get keys for all
+            user_id: the user whose keys we're getting
+            version: the version ID of the backup we're getting keys from
+            room_id: room ID to get keys for, for None to get keys for all rooms
+            session_id: session ID to get keys for, for None to get keys for all
                sessions
        Raises:
            NotFoundError: if the backup version does not exist
        Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
            these room keys.
        """
 
@@ -86,17 +97,23 @@ class E2eRoomKeysHandler:
        return results
 
    @trace
-    async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_room_keys(
+        self,
+        user_id: str,
+        version: str,
+        room_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+    ) -> JsonDict:
        """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
        room or a given session.
        See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
 
        Args:
-            user_id(str): the user whose backup we're deleting
-            version(str): the version ID of the backup we're deleting
-            room_id(string): room ID to delete keys for, for None to delete keys for all
+            user_id: the user whose backup we're deleting
+            version: the version ID of the backup we're deleting
+            room_id: room ID to delete keys for, for None to delete keys for all
                rooms
-            session_id(string): session ID to delete keys for, for None to delete keys
+            session_id: session ID to delete keys for, for None to delete keys
                for all sessions
        Raises:
            NotFoundError: if the backup version does not exist
@@ -128,15 +145,17 @@ class E2eRoomKeysHandler:
        return {"etag": str(version_etag), "count": count}
 
    @trace
-    async def upload_room_keys(self, user_id, version, room_keys):
+    async def upload_room_keys(
+        self, user_id: str, version: str, room_keys: JsonDict
+    ) -> JsonDict:
        """Bulk upload a list of room keys into a given backup version, asserting
        that the given version is the current backup version. room_keys are merged
        into the current backup as described in RoomKeysServlet.on_PUT().
 
        Args:
-            user_id(str): the user whose backup we're setting
-            version(str): the version ID of the backup we're updating
-            room_keys(dict): a nested dict describing the room_keys we're setting:
+            user_id: the user whose backup we're setting
+            version: the version ID of the backup we're updating
+            room_keys: a nested dict describing the room_keys we're setting:
 
        {
            "rooms": {
@@ -254,14 +273,16 @@ class E2eRoomKeysHandler:
        return {"etag": str(version_etag), "count": count}
 
    @staticmethod
-    def _should_replace_room_key(current_room_key, room_key):
+    def _should_replace_room_key(
+        current_room_key: Optional[JsonDict], room_key: JsonDict
+    ) -> bool:
        """
        Determine whether to replace a given current_room_key (if any)
        with a newly uploaded room_key backup
 
        Args:
-            current_room_key (dict): Optional, the current room_key dict if any
-            room_key (dict): The new room_key dict which may or may not be fit to
+            current_room_key: Optional, the current room_key dict if any
+            room_key : The new room_key dict which may or may not be fit to
                replace the current_room_key
 
        Returns:
@@ -286,14 +307,14 @@ class E2eRoomKeysHandler:
        return True
 
    @trace
-    async def create_version(self, user_id, version_info):
+    async def create_version(self, user_id: str, version_info: JsonDict) -> str:
        """Create a new backup version. This automatically becomes the new
        backup version for the user's keys; previous backups will no longer be
        writeable to.
 
        Args:
-            user_id(str): the user whose backup version we're creating
-            version_info(dict): metadata about the new version being created
+            user_id: the user whose backup version we're creating
+            version_info: metadata about the new version being created
 
        {
            "algorithm": "m.megolm_backup.v1",
@@ -301,7 +322,7 @@ class E2eRoomKeysHandler:
        }
 
        Returns:
-            A deferred of a string that gives the new version number.
+            The new version number.
        """
 
        # TODO: Validate the JSON to make sure it has the right keys.
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler:
        )
        return new_version
 
-    async def get_version_info(self, user_id, version=None):
+    async def get_version_info(
+        self, user_id: str, version: Optional[str] = None
+    ) -> JsonDict:
        """Get the info about a given version of the user's backup
 
        Args:
-            user_id(str): the user whose current backup version we're querying
-            version(str): Optional; if None gives the most recent version
+            user_id: the user whose current backup version we're querying
+            version: Optional; if None gives the most recent version
                otherwise a historical one.
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
-            A deferred of a info dict that gives the info about the new version.
+            A info dict that gives the info about the new version.
 
        {
            "version": "1234",
@@ -346,7 +369,7 @@ class E2eRoomKeysHandler:
        return res
 
    @trace
-    async def delete_version(self, user_id, version=None):
+    async def delete_version(self, user_id: str, version: Optional[str] = None) -> None:
        """Deletes a given version of the user's e2e_room_keys backup
 
        Args:
@@ -366,17 +389,19 @@ class E2eRoomKeysHandler:
                raise
 
    @trace
-    async def update_version(self, user_id, version, version_info):
+    async def update_version(
+        self, user_id: str, version: str, version_info: JsonDict
+    ) -> JsonDict:
        """Update the info about a given version of the user's backup
 
        Args:
-            user_id(str): the user whose current backup version we're updating
-            version(str): the backup version we're updating
-            version_info(dict): the new information about the backup
+            user_id: the user whose current backup version we're updating
+            version: the backup version we're updating
+            version_info: the new information about the backup
        Raises:
            NotFoundError: if the requested backup version doesn't exist
        Returns:
-            A deferred of an empty dict.
+            An empty dict.
        """
        if "version" not in version_info:
            version_info["version"] = version

@@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler):
        if event.state_key == self._server_notices_mxid:
            raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")
 
+        # We retrieve the room member handler here as to not cause a cyclic dependency
+        member_handler = self.hs.get_room_member_handler()
+        member_handler.ratelimit_invite(event.room_id, event.state_key)
+
        # keep a record of the room version, if we don't yet know it.
        # (this may get overwritten if we later get a different room version in a
        # join dance).
@@ -2093,6 +2097,11 @@ class FederationHandler(BaseHandler):
        if event.type == EventTypes.GuestAccess and not context.rejected:
            await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precaclculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
        return context
 
    async def _check_for_soft_fail(

@@ -15,9 +15,13 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import GroupID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
 
 
 class GroupsLocalWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
        self.hs = hs
        self.store = hs.get_datastore()
        self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
    get_group_role = _create_rerouter("get_group_role")
    get_group_roles = _create_rerouter("get_group_roles")
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
        """Get the group summary for a group.
 
        If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
 
        return res
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
        """Get users in a group
        """
        if self.is_mine_id(group_id):
-            res = await self.groups_server_handler.get_users_in_group(
+            return await self.groups_server_handler.get_users_in_group(
                group_id, requester_user_id
            )
-            return res
 
        group_server_name = get_domain_from_id(group_id)
 
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
 
        return res
 
-    async def get_joined_groups(self, user_id):
+    async def get_joined_groups(self, user_id: str) -> JsonDict:
        group_ids = await self.store.get_joined_groups(user_id)
        return {"groups": group_ids}
 
-    async def get_publicised_groups_for_user(self, user_id):
+    async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
        if self.hs.is_mine_id(user_id):
            result = await self.store.get_publicised_groups_for_user(user_id)
 
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
            # TODO: Verify attestations
        return {"groups": result}
 
-    async def bulk_get_publicised_groups(self, user_ids, proxy=True):
-        destinations = {}
+    async def bulk_get_publicised_groups(
+        self, user_ids: Iterable[str], proxy: bool = True
+    ) -> JsonDict:
+        destinations = {}  # type: Dict[str, Set[str]]
        local_users = set()
 
        for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
            raise SynapseError(400, "Some user_ids are not local")
 
        results = {}
-        failed_results = []
+        failed_results = []  # type: List[str]
        for destination, dest_user_ids in destinations.items():
            try:
                r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
 
 
 class GroupsLocalHandler(GroupsLocalWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
 
        # Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
    set_group_join_policy = _create_rerouter("set_group_join_policy")
 
-    async def create_group(self, group_id, user_id, content):
+    async def create_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
        """Create a group
        """
 
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
            local_attestation = None
            remote_attestation = None
        else:
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
-            content["attestation"] = local_attestation
-
-            content["user_profile"] = await self.profile_handler.get_profile(user_id)
-
-            try:
-                res = await self.transport_client.create_group(
-                    get_domain_from_id(group_id), group_id, user_id, content
-                )
-            except HttpResponseException as e:
-                raise e.to_synapse_error()
-            except RequestSendFailed:
-                raise SynapseError(502, "Failed to contact group server")
-
-            remote_attestation = res["attestation"]
-            await self.attestations.verify_attestation(
-                remote_attestation,
-                group_id=group_id,
-                user_id=user_id,
-                server_name=get_domain_from_id(group_id),
-            )
+            raise SynapseError(400, "Unable to create remote groups")
 
        is_publicised = content.get("publicise", False)
        token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
        return res
 
-    async def join_group(self, group_id, user_id, content):
+    async def join_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
        """Request to join a group
        """
        if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
        return {}
 
-    async def accept_invite(self, group_id, user_id, content):
+    async def accept_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
        """Accept an invite to a group
        """
        if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
        return {}
 
-    async def invite(self, group_id, user_id, requester_user_id, config):
+    async def invite(
+        self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
+    ) -> JsonDict:
        """Invite a user to a group
        """
        content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
        return res
 
-    async def on_invite(self, group_id, user_id, content):
+    async def on_invite(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
        """One of our users were invited to a group
        """
        # TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
        return {"state": "invite", "user_profile": user_profile}
 
    async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
        """Remove a user from a group
        """
        if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
 
        return res
 
-    async def user_removed_from_group(self, group_id, user_id, content):
+    async def user_removed_from_group(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> None:
        """One of our users was removed/kicked from a group
        """
        # TODO: Check if user in group

@@ -27,9 +27,11 @@ from synapse.api.errors import (
    HttpResponseException,
    SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
 
        self._web_client_location = hs.config.invite_client_location
 
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self, request: SynapseRequest, medium: str, address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
    async def threepid_from_creds(
        self, id_server: str, creds: Dict[str, str]
    ) -> Optional[JsonDict]:

@@ -174,7 +174,7 @@ class MessageHandler:
            raise NotFoundError("Can't find event for token %s" % (at_token,))
 
        visible_events = await filter_events_for_client(
-            self.storage, user_id, last_events, filter_send_to_client=False
+            self.storage, user_id, last_events, filter_send_to_client=False,
        )
 
        event = last_events[0]
@@ -434,6 +434,8 @@ class EventCreationHandler:
 
        self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
    async def create_event(
        self,
        requester: Requester,
@@ -941,6 +943,8 @@ class EventCreationHandler:
 
        await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
        try:
            # If we're a worker we need to hit out to the master.
            writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -980,6 +984,44 @@ class EventCreationHandler:
            await self.store.remove_push_actions_from_staging(event.event_id)
            raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
    async def _validate_canonical_alias(
        self, directory_handler, room_alias_str: str, expected_room_id: str
    ) -> None:

@@ -102,7 +102,7 @@ class OidcHandler:
         ) from e

     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback

         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call

@@ -274,6 +274,9 @@ class OidcProvider:
         # MXC URI for icon for this auth provider
         self.idp_icon = provider.idp_icon

+        # optional brand identifier for this auth provider
+        self.idp_brand = provider.idp_brand
+
         self._sso_handler = hs.get_sso_handler()

         self._sso_handler.register_identity_provider(self)

@@ -640,7 +643,7 @@ class OidcProvider:
         - ``client_id``: the client ID set in ``oidc_config.client_id``
         - ``response_type``: ``code``
-        - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+        - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback``
         - ``scope``: the list of scopes set in ``oidc_config.scopes``
         - ``state``: a random string
         - ``nonce``: a random string

@@ -681,7 +684,7 @@ class OidcProvider:
         request.addCookie(
             SESSION_COOKIE_NAME,
             cookie,
-            path="/_synapse/oidc",
+            path="/_synapse/client/oidc",
             max_age="3600",
             httpOnly=True,
             sameSite="lax",

@@ -702,7 +705,7 @@ class OidcProvider:
     async def handle_oidc_callback(
         self, request: SynapseRequest, session_data: "OidcSessionData", code: str
     ) -> None:
-        """Handle an incoming request to /_synapse/oidc/callback
+        """Handle an incoming request to /_synapse/client/oidc/callback

         By this time we have already validated the session on the synapse side, and
         now need to do the provider-specific operations. This includes:

@@ -1056,7 +1059,8 @@ class OidcSessionData:

 UserAttributeDict = TypedDict(
-    "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
+    "UserAttributeDict",
+    {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]},
 )
 C = TypeVar("C")

@@ -1135,11 +1139,12 @@ def jinja_finalize(thing):
 env = Environment(finalize=jinja_finalize)


-@attr.s
+@attr.s(slots=True, frozen=True)
 class JinjaOidcMappingConfig:
     subject_claim = attr.ib(type=str)
     localpart_template = attr.ib(type=Optional[Template])
     display_name_template = attr.ib(type=Optional[Template])
+    email_template = attr.ib(type=Optional[Template])
     extra_attributes = attr.ib(type=Dict[str, Template])

@@ -1156,23 +1161,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")

-        localpart_template = None  # type: Optional[Template]
-        if "localpart_template" in config:
+        def parse_template_config(option_name: str) -> Optional[Template]:
+            if option_name not in config:
+                return None
             try:
-                localpart_template = env.from_string(config["localpart_template"])
+                return env.from_string(config[option_name])
             except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["localpart_template"]
-                ) from e
+                raise ConfigError("invalid jinja template", path=[option_name]) from e

-        display_name_template = None  # type: Optional[Template]
-        if "display_name_template" in config:
-            try:
-                display_name_template = env.from_string(config["display_name_template"])
-            except Exception as e:
-                raise ConfigError(
-                    "invalid jinja template", path=["display_name_template"]
-                ) from e
+        localpart_template = parse_template_config("localpart_template")
+        display_name_template = parse_template_config("display_name_template")
+        email_template = parse_template_config("email_template")

         extra_attributes = {}  # type Dict[str, Template]
         if "extra_attributes" in config:

@@ -1192,6 +1191,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            email_template=email_template,
             extra_attributes=extra_attributes,
         )

@@ -1213,16 +1213,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             # a usable mxid.
             localpart += str(failures) if failures else ""

-        display_name = None  # type: Optional[str]
-        if self._config.display_name_template is not None:
-            display_name = self._config.display_name_template.render(
-                user=userinfo
-            ).strip()
+        def render_template_field(template: Optional[Template]) -> Optional[str]:
+            if template is None:
+                return None
+            return template.render(user=userinfo).strip()

-            if display_name == "":
-                display_name = None
+        display_name = render_template_field(self._config.display_name_template)
+        if display_name == "":
+            display_name = None

-        return UserAttributeDict(localpart=localpart, display_name=display_name)
+        emails = []  # type: List[str]
+        email = render_template_field(self._config.email_template)
+        if email:
+            emails.append(email)
+
+        return UserAttributeDict(
+            localpart=localpart, display_name=display_name, emails=emails
+        )

     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
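Outside Synapse, the parse_template_config/render_template_field pair introduced above boils down to a small jinja2 round trip. A self-contained sketch (the config key mirrors the hunk; the sample userinfo values are invented):

from typing import Optional

from jinja2 import Environment, Template

env = Environment()
config = {"localpart_template": "{{ user.preferred_username }}"}


def parse_template_config(option_name: str) -> Optional[Template]:
    # a missing option simply means "no template configured"
    if option_name not in config:
        return None
    return env.from_string(config[option_name])


def render_template_field(template: Optional[Template], userinfo: dict) -> Optional[str]:
    if template is None:
        return None
    return template.render(user=userinfo).strip()


userinfo = {"preferred_username": "alice"}
print(render_template_field(parse_template_config("localpart_template"), userinfo))  # alice
print(render_template_field(parse_template_config("email_template"), userinfo))  # None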
@@ -14,8 +14,9 @@
 # limitations under the License.

 """Contains functions for registering clients."""
+
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType

@@ -157,7 +158,7 @@ class RegistrationHandler(BaseHandler):
         user_type: Optional[str] = None,
         default_display_name: Optional[str] = None,
         address: Optional[str] = None,
-        bind_emails: List[str] = [],
+        bind_emails: Iterable[str] = [],
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
     ) -> str:

@@ -700,6 +701,8 @@ class RegistrationHandler(BaseHandler):
             access_token: The access token of the newly logged in device, or
                 None if `inhibit_login` enabled.
         """
+        # TODO: 3pid registration can actually happen on the workers. Consider
+        # refactoring it.
         if self.hs.config.worker_app:
             await self._post_registration_client(
                 user_id=user_id, auth_result=auth_result, access_token=access_token
@@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler):

         self.third_party_event_rules = hs.get_third_party_event_rules()

+        self._invite_burst_count = (
+            hs.config.ratelimiting.rc_invites_per_room.burst_count
+        )
+
     async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ) -> str:

@@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler):
             invite_3pid_list = []
             invite_list = []

+        if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count:
+            raise SynapseError(400, "Cannot invite so many users at once")
+
         await self.event_creation_handler.assert_accepted_privacy_policy(requester)

         power_level_content_override = config.get("power_level_content_override")
@@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )

+        self._invites_per_room_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
+        )
+        self._invites_per_user_limiter = Ratelimiter(
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
+            burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
+        )
+
         # This is only used to get at ratelimit function, and
         # maybe_kick_guest_users. It's fine there are multiple of these as
         # it doesn't store state.

@@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """
         raise NotImplementedError()

+    def ratelimit_invite(self, room_id: str, invitee_user_id: str):
+        """Ratelimit invites by room and by target user.
+        """
+        self._invites_per_room_limiter.ratelimit(room_id)
+        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+
     async def _local_membership_update(
         self,
         requester: Requester,

@@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             raise SynapseError(403, "This room has been blocked on this server")

         if effective_membership_state == Membership.INVITE:
+            target_id = target.to_string()
+            if ratelimit:
+                self.ratelimit_invite(room_id, target_id)
+
             # block any attempts to invite the server notices mxid
-            if target.to_string() == self._server_notices_mxid:
+            if target_id == self._server_notices_mxid:
                 raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user")

             block_invite = False

@@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     block_invite = True

             if not await self.spam_checker.user_may_invite(
-                requester.user.to_string(), target.to_string(), room_id
+                requester.user.to_string(), target_id, room_id
             ):
                 logger.info("Blocking invite due to spam checker")
                 block_invite = True
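Each limiter above is parameterised by a refill rate (per_second) and a cap (burst_count); the general scheme is a token bucket per key. A toy version to show the mechanics (this class is illustrative only and is not Synapse's Ratelimiter; the rate values are arbitrary examples):

import time
from typing import Dict


class LimitExceededToy(Exception):
    pass


class ToyRatelimiter:
    def __init__(self, rate_hz: float, burst_count: float):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        self._tokens = {}  # type: Dict[str, float]
        self._last_seen = {}  # type: Dict[str, float]

    def ratelimit(self, key: str) -> None:
        now = time.monotonic()
        elapsed = now - self._last_seen.get(key, now)
        self._last_seen[key] = now
        # refill at rate_hz tokens/second, capped at the burst size
        tokens = min(
            self._tokens.get(key, self.burst_count) + elapsed * self.rate_hz,
            self.burst_count,
        )
        if tokens < 1:
            raise LimitExceededToy(key)
        self._tokens[key] = tokens - 1


limiter = ToyRatelimiter(rate_hz=0.3, burst_count=10)
limiter.ratelimit("!room:example.com")  # the 11th rapid call for this key would raise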
@@ -78,9 +78,10 @@ class SamlHandler(BaseHandler):
         # user-facing name of this auth provider
         self.idp_name = "SAML"

-        # we do not currently support icons for SAML auth, but this is required by
+        # we do not currently support icons/brands for SAML auth, but this is required by
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
+        self.idp_brand = None

         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]

@@ -132,7 +133,7 @@ class SamlHandler(BaseHandler):
             raise Exception("prepare_for_authenticate didn't return a Location header")

     async def handle_saml_response(self, request: SynapseRequest) -> None:
-        """Handle an incoming request to /_matrix/saml2/authn_response
+        """Handle an incoming request to /_synapse/client/saml2/authn_response

         Args:
             request: the incoming request from the browser. We'll
@@ -15,23 +15,28 @@

 import itertools
 import logging
-from typing import Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional

 from unpaddedbase64 import decode_base64, encode_base64

 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
+from synapse.events import EventBase
 from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, UserID
 from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)


 class SearchHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()

@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):

         return historical_room_ids

-    async def search(self, user, content, batch=None):
+    async def search(
+        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+    ) -> JsonDict:
         """Performs a full text search for a user.

         Args:
-            user (UserID)
-            content (dict): Search parameters
-            batch (str): The next_batch parameter. Used for pagination.
+            user
+            content: Search parameters
+            batch: The next_batch parameter. Used for pagination.

         Returns:
             dict to be returned to the client with results of search

@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
         # If doing a subset of all rooms seearch, check if any of the rooms
         # are from an upgraded room, and search their contents as well
         if search_filter.rooms:
-            historical_room_ids = []
+            historical_room_ids = []  # type: List[str]
             for room_id in search_filter.rooms:
                 # Add any previous rooms to the search if they exist
                 ids = await self.get_old_rooms_from_upgraded_room(room_id)

@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):

         rank_map = {}  # event_id -> rank of event
         allowed_events = []
-        room_groups = {}  # Holds result of grouping by room, if applicable
-        sender_group = {}  # Holds result of grouping by sender, if applicable
+        # Holds result of grouping by room, if applicable
+        room_groups = {}  # type: Dict[str, JsonDict]
+        # Holds result of grouping by sender, if applicable
+        sender_group = {}  # type: Dict[str, JsonDict]

         # Holds the next_batch for the entire result set if one of those exists
         global_next_batch = None

@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
                 s["results"].append(e.event_id)

         elif order_by == "recent":
-            room_events = []
+            room_events = []  # type: List[EventBase]
             i = 0

             pagination_token = batch_token

@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):

         state_results = {}
         if include_state:
-            rooms = {e.room_id for e in allowed_events}
-            for room_id in rooms:
+            for room_id in {e.room_id for e in allowed_events}:
                 state = await self.state_handler.get_current_state(room_id)
                 state_results[room_id] = list(state.values())

-        state_results.values()
-
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise the 'age' will be wrong

@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):

         if state_results:
             s = {}
-            for room_id, state in state_results.items():
+            for room_id, state_events in state_results.items():
                 s[room_id] = await self._event_serializer.serialize_events(
-                    state, time_now
+                    state_events, time_now
                 )

             rooms_cat_res["state"] = s
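For reference, the content argument that search now types as a JsonDict is the body of a client POST /_matrix/client/r0/search request. A representative example following the Matrix client-server spec (the search term and grouping values here are invented):

content = {
    "search_categories": {
        "room_events": {
            "search_term": "pizza",
            "order_by": "recent",  # or "rank"; both branches are handled above
            "keys": ["content.body"],
            "include_state": True,
            "groupings": {"group_by": [{"key": "room_id"}]},
        }
    }
}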
@@ -13,24 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.types import Requester

 from ._base import BaseHandler

+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)


 class SetPasswordHandler(BaseHandler):
     """Handler which deals with changing user account passwords"""

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
         self._password_policy_handler = hs.get_password_policy_handler()

     async def set_password(
         self,

@@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler):
         password_hash: str,
         logout_devices: bool,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         if not self.hs.config.password_localdb_enabled:
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
@@ -14,21 +14,31 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Mapping,
+    Optional,
+    Set,
+)
 from urllib.parse import urlencode

 import attr
 from typing_extensions import NoReturn, Protocol

 from twisted.web.http import Request
+from twisted.web.iweb import IRequest

 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
-from synapse.http.server import respond_with_html
+from synapse.http.server import respond_with_html, respond_with_redirect
 from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
+from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters
 from synapse.util.async_helpers import Linearizer
 from synapse.util.stringutils import random_string

@@ -80,6 +90,11 @@ class SsoIdentityProvider(Protocol):
         """Optional MXC URI for user-facing icon"""
         return None

+    @property
+    def idp_brand(self) -> Optional[str]:
+        """Optional branding identifier"""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,

@@ -109,7 +124,7 @@ class UserAttributes:
     # enter one.
     localpart = attr.ib(type=Optional[str])
     display_name = attr.ib(type=Optional[str], default=None)
-    emails = attr.ib(type=List[str], default=attr.Factory(list))
+    emails = attr.ib(type=Collection[str], default=attr.Factory(list))


 @attr.s(slots=True)

@@ -124,7 +139,7 @@ class UsernameMappingSession:

     # attributes returned by the ID mapper
     display_name = attr.ib(type=Optional[str])
-    emails = attr.ib(type=List[str])
+    emails = attr.ib(type=Collection[str])

     # An optional dictionary of extra attributes to be provided to the client in the
     # login response.

@@ -136,6 +151,12 @@ class UsernameMappingSession:
     # expiry time for the session, in milliseconds
     expiry_time_ms = attr.ib(type=int)

+    # choices made by the user
+    chosen_localpart = attr.ib(type=Optional[str], default=None)
+    use_display_name = attr.ib(type=bool, default=True)
+    emails_to_use = attr.ib(type=Collection[str], default=())
+    terms_accepted_version = attr.ib(type=Optional[str], default=None)
+

 # the HTTP cookie used to track the mapping session id
 USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"

@@ -170,6 +191,8 @@ class SsoHandler:
         # map from idp_id to SsoIdentityProvider
         self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]

+        self._consent_at_registration = hs.config.consent.user_consent_at_registration
+
     def register_identity_provider(self, p: SsoIdentityProvider):
         p_id = p.idp_id
         assert p_id not in self._identity_providers

@@ -235,7 +258,10 @@ class SsoHandler:
         respond_with_html(request, code, html)

     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes,
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
     ) -> str:
         """Handle a request to /login/sso/redirect

@@ -243,6 +269,7 @@ class SsoHandler:
             request: incoming HTTP request
             client_redirect_url: the URL that we should redirect the
                 client to after login.
+            idp_id: optional identity provider chosen by the client

         Returns:
             the URI to redirect to

@@ -252,10 +279,19 @@ class SsoHandler:
                 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
             )

+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
         # if we only have one auth provider, redirect to it directly
-        if len(self._identity_providers) == 1:
-            ap = next(iter(self._identity_providers.values()))
-            return await ap.handle_redirect_request(request, client_redirect_url)
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)

         # otherwise, redirect to the IDP picker
         return "/_synapse/client/pick_idp?" + urlencode(
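The selection logic above reduces to a three-way choice. A pure-function sketch over a plain mapping of idp_id to redirect URL (names and URLs invented; the real method awaits each provider's own handle_redirect_request rather than returning stored strings):

from typing import Dict, Optional
from urllib.parse import urlencode


def choose_sso_redirect(
    providers: Dict[str, str], client_redirect_url: str, idp_id: Optional[str] = None
) -> str:
    if idp_id:
        if idp_id not in providers:
            raise KeyError("Unknown identity provider")
        return providers[idp_id]
    if len(providers) == 1:
        # only one choice: skip the picker entirely
        return next(iter(providers.values()))
    # several providers and no preference: send the user to the picker
    return "/_synapse/client/pick_idp?" + urlencode({"redirectUrl": client_redirect_url})


assert choose_sso_redirect({"oidc": "/sso/oidc"}, "https://client.example") == "/sso/oidc"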
@@ -369,6 +405,8 @@ class SsoHandler:
             to an additional page. (e.g. to prompt for more information)

         """
+        new_user = False
+
         # grab a lock while we try to find a mapping for this user. This seems...
         # optimistic, especially for implementations that end up redirecting to
         # interstitial pages.

@@ -409,9 +447,14 @@ class SsoHandler:
                 get_request_user_agent(request),
                 request.getClientIP(),
             )
+            new_user = True

         await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url, extra_login_attributes
+            user_id,
+            request,
+            client_redirect_url,
+            extra_login_attributes,
+            new_user=new_user,
         )

     async def _call_attribute_mapper(

@@ -501,7 +544,7 @@ class SsoHandler:
         logger.info("Recorded registration session id %s", session_id)

         # Set the cookie and redirect to the username picker
-        e = RedirectException(b"/_synapse/client/pick_username")
+        e = RedirectException(b"/_synapse/client/pick_username/account_details")
         e.cookies.append(
             b"%s=%s; path=/"
             % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))

@@ -629,6 +672,25 @@ class SsoHandler:
         )
         respond_with_html(request, 200, html)

+    def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
+        """Look up the given username mapping session
+
+        If it is not found, raises a SynapseError with an http code of 400
+
+        Args:
+            session_id: session to look up
+        Returns:
+            active mapping session
+        Raises:
+            SynapseError if the session is not found/has expired
+        """
+        self._expire_old_sessions()
+        session = self._username_mapping_sessions.get(session_id)
+        if session:
+            return session
+        logger.info("Couldn't find session id %s", session_id)
+        raise SynapseError(400, "unknown session")
+
     async def check_username_availability(
         self, localpart: str, session_id: str,
     ) -> bool:

@@ -645,12 +707,7 @@ class SsoHandler:

         # make sure that there is a valid mapping session, to stop people dictionary-
         # scanning for accounts
-
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        self.get_mapping_session(session_id)

         logger.info(
             "[session %s] Checking for availability of username %s",

@@ -667,7 +724,12 @@ class SsoHandler:
         return not user_infos

     async def handle_submit_username_request(
-        self, request: SynapseRequest, localpart: str, session_id: str
+        self,
+        request: SynapseRequest,
+        session_id: str,
+        localpart: str,
+        use_display_name: bool,
+        emails_to_use: Iterable[str],
     ) -> None:
         """Handle a request to the username-picker 'submit' endpoint

@@ -677,21 +739,90 @@ class SsoHandler:
             request: HTTP request
             localpart: localpart requested by the user
             session_id: ID of the username mapping session, extracted from a cookie
+            use_display_name: whether the user wants to use the suggested display name
+            emails_to_use: emails that the user would like to use
         """
-        self._expire_old_sessions()
-        session = self._username_mapping_sessions.get(session_id)
-        if not session:
-            logger.info("Couldn't find session id %s", session_id)
-            raise SynapseError(400, "unknown session")
+        session = self.get_mapping_session(session_id)

-        logger.info("[session %s] Registering localpart %s", session_id, localpart)
+        # update the session with the user's choices
+        session.chosen_localpart = localpart
+        session.use_display_name = use_display_name
+
+        emails_from_idp = set(session.emails)
+        filtered_emails = set()  # type: Set[str]
+
+        # we iterate through the list rather than just building a set conjunction, so
+        # that we can log attempts to use unknown addresses
+        for email in emails_to_use:
+            if email in emails_from_idp:
+                filtered_emails.add(email)
+            else:
+                logger.warning(
+                    "[session %s] ignoring user request to use unknown email address %r",
+                    session_id,
+                    email,
+                )
+        session.emails_to_use = filtered_emails
+
+        # we may now need to collect consent from the user, in which case, redirect
+        # to the consent-extraction-unit
+        if self._consent_at_registration:
+            redirect_url = b"/_synapse/client/new_user_consent"
+
+        # otherwise, redirect to the completion page
+        else:
+            redirect_url = b"/_synapse/client/sso_register"
+
+        respond_with_redirect(request, redirect_url)
+
+    async def handle_terms_accepted(
+        self, request: Request, session_id: str, terms_version: str
+    ):
+        """Handle a request to the new-user 'consent' endpoint
+
+        Will serve an HTTP response to the request.
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+            terms_version: the version of the terms which the user viewed and consented
+                to
+        """
+        logger.info(
+            "[session %s] User consented to terms version %s",
+            session_id,
+            terms_version,
+        )
+        session = self.get_mapping_session(session_id)
+        session.terms_accepted_version = terms_version
+
+        # we're done; now we can register the user
+        respond_with_redirect(request, b"/_synapse/client/sso_register")
+
+    async def register_sso_user(self, request: Request, session_id: str) -> None:
+        """Called once we have all the info we need to register a new user.
+
+        Does so and serves an HTTP response
+
+        Args:
+            request: HTTP request
+            session_id: ID of the username mapping session, extracted from a cookie
+        """
+        session = self.get_mapping_session(session_id)
+
+        logger.info(
+            "[session %s] Registering localpart %s",
+            session_id,
+            session.chosen_localpart,
+        )

         attributes = UserAttributes(
-            localpart=localpart,
-            display_name=session.display_name,
-            emails=session.emails,
+            localpart=session.chosen_localpart, emails=session.emails_to_use,
         )

+        if session.use_display_name:
+            attributes.display_name = session.display_name
+
         # the following will raise a 400 error if the username has been taken in the
         # meantime.
         user_id = await self._register_mapped_user(
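The email filtering above is the security-relevant step: only addresses the identity provider actually asserted may be carried into registration. As a pure function (a sketch only, with logging reduced to a print):

from typing import Iterable, Set


def filter_emails(emails_from_idp: Iterable[str], emails_to_use: Iterable[str]) -> Set[str]:
    known = set(emails_from_idp)
    chosen = set()  # type: Set[str]
    # iterate rather than intersect, so unknown addresses can be reported
    for email in emails_to_use:
        if email in known:
            chosen.add(email)
        else:
            print("ignoring unknown email address %r" % (email,))
    return chosen


assert filter_emails(["a@x.org"], ["a@x.org", "evil@y.org"]) == {"a@x.org"}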
@@ -702,7 +833,12 @@ class SsoHandler:
             request.getClientIP(),
         )

-        logger.info("[session %s] Registered userid %s", session_id, user_id)
+        logger.info(
+            "[session %s] Registered userid %s with attributes %s",
+            session_id,
+            user_id,
+            attributes,
+        )

         # delete the mapping session and the cookie
         del self._username_mapping_sessions[session_id]

@@ -715,11 +851,21 @@ class SsoHandler:
             path=b"/",
         )

+        auth_result = {}
+        if session.terms_accepted_version:
+            # TODO: make this less awful.
+            auth_result[LoginType.TERMS] = True
+
+        await self._registration_handler.post_registration_actions(
+            user_id, auth_result, access_token=None
+        )
+
         await self._auth_handler.complete_sso_login(
             user_id,
             request,
             session.client_redirect_url,
             session.extra_login_attributes,
+            new_user=True,
         )

     def _expire_old_sessions(self):

@@ -733,3 +879,14 @@ class SsoHandler:
         for session_id in to_expire:
             logger.info("Expiring mapping session %s", session_id)
             del self._username_mapping_sessions[session_id]
+
+
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
+    """Extract the session ID from the cookie
+
+    Raises a SynapseError if the cookie isn't found
+    """
+    session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
+    if not session_id:
+        raise SynapseError(code=400, msg="missing session_id")
+    return session_id.decode("ascii", errors="replace")
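A usage sketch for the new module-level cookie helper: the SSO servlets are expected to pull the session ID out of the request and hand it to the handler methods that take a session_id. The FakeRequest below is a toy stand-in for a twisted IRequest, just enough to show the flow:

class FakeRequest:
    def __init__(self, cookies):
        self._cookies = cookies

    def getCookie(self, name: bytes):
        return self._cookies.get(name)


request = FakeRequest({b"username_mapping_session": b"abc123"})
session_id = request.getCookie(b"username_mapping_session").decode("ascii")
# a servlet would then call, e.g., sso_handler.get_mapping_session(session_id)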
@@ -14,15 +14,25 @@
 # limitations under the License.

 import logging
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer

 logger = logging.getLogger(__name__)


 class StateDeltasHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()

-    async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
+    async def _get_key_change(
+        self,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        key_name: str,
+        public_value: str,
+    ) -> Optional[bool]:
         """Given two events check if the `key_name` field in content changed
         from not matching `public_value` to doing so.
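A simplified stand-in for the comparison _get_key_change performs: the real method loads both events from the store, whereas here their content dicts are passed in directly, and the None-for-"no change" convention is an assumption of this sketch rather than taken from Synapse:

from typing import Optional


def key_change(
    prev_content: Optional[dict],
    content: Optional[dict],
    key_name: str,
    public_value: str,
) -> Optional[bool]:
    prev_value = (prev_content or {}).get(key_name)
    value = (content or {}).get(key_name)
    if prev_value == value:
        return None  # nothing changed
    # True: flipped to the public value; False: flipped away from it
    return value == public_value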
@@ -12,13 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import logging
 from collections import Counter
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
+
+from typing_extensions import Counter as CounterType

 from synapse.api.constants import EventTypes, Membership
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer

 logger = logging.getLogger(__name__)

@@ -31,7 +37,7 @@ class StatsHandler:
     Heavily derived from UserDirectoryHandler
     """

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()

@@ -44,7 +50,7 @@ class StatsHandler:
         self.stats_enabled = hs.config.stats_enabled

         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]

         # Guard to ensure we only process deltas one at a time
         self._is_processing = False

@@ -56,7 +62,7 @@ class StatsHandler:
             # we start populating stats
             self.clock.call_later(0, self.notify_new_event)

-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """
         if not self.stats_enabled or self._is_processing:

@@ -72,7 +78,7 @@ class StatsHandler:

         run_as_background_process("stats.notify_new_event", process)

-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
             self.pos = await self.store.get_stats_positions()

@@ -110,10 +116,10 @@ class StatsHandler:
                 )

                 for room_id, fields in room_count.items():
-                    room_deltas.setdefault(room_id, {}).update(fields)
+                    room_deltas.setdefault(room_id, Counter()).update(fields)

                 for user_id, fields in user_count.items():
-                    user_deltas.setdefault(user_id, {}).update(fields)
+                    user_deltas.setdefault(user_id, Counter()).update(fields)

                 logger.debug("room_deltas: %s", room_deltas)
                 logger.debug("user_deltas: %s", user_deltas)

@@ -131,19 +137,20 @@ class StatsHandler:

         self.pos = max_pos

-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(
+        self, deltas: Iterable[JsonDict]
+    ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
         """Called with the state deltas to process

         Returns:
-            tuple[dict[str, Counter], dict[str, counter]]
             Two dicts: the room deltas and the user deltas,
             mapping from room/user ID to changes in the various fields.
         """

-        room_to_stats_deltas = {}
-        user_to_stats_deltas = {}
+        room_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]
+        user_to_stats_deltas = {}  # type: Dict[str, CounterType[str]]

-        room_to_state_updates = {}
+        room_to_state_updates = {}  # type: Dict[str, Dict[str, Any]]

         for delta in deltas:
             typ = delta["type"]

@@ -173,7 +180,7 @@ class StatsHandler:
                 )
                 continue

-            event_content = {}
+            event_content = {}  # type: JsonDict

             sender = None
             if event_id is not None:

@@ -257,13 +264,13 @@ class StatsHandler:
                 )

                 if has_changed_joinedness:
-                    delta = +1 if membership == Membership.JOIN else -1
+                    membership_delta = +1 if membership == Membership.JOIN else -1

                     user_to_stats_deltas.setdefault(user_id, Counter())[
                         "joined_rooms"
-                    ] += delta
+                    ] += membership_delta

-                    room_stats_delta["local_users_in_room"] += delta
+                    room_stats_delta["local_users_in_room"] += membership_delta

             elif typ == EventTypes.Create:
                 room_state["is_federatable"] = (
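Worth noting about the setdefault default changing from {} to Counter() above: dict.update overwrites values, while Counter.update adds them, so the Counter form is the one that safely accumulates if the same key turns up more than once. A quick standalone demonstration:

from collections import Counter

room_deltas = {}
room_deltas.setdefault("!room:a", Counter()).update({"joined_members": 2})
room_deltas.setdefault("!room:a", Counter()).update({"joined_members": 3})
assert room_deltas["!room:a"]["joined_members"] == 5  # summed, not overwritten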
@@ -15,13 +15,13 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer

@@ -65,17 +65,17 @@ class FollowerTypingHandler:
         )

         # map room IDs to serial numbers
-        self._room_serials = {}
+        self._room_serials = {}  # type: Dict[str, int]
         # map room IDs to sets of users currently typing
-        self._room_typing = {}
+        self._room_typing = {}  # type: Dict[str, Set[str]]

-        self._member_last_federation_poke = {}
+        self._member_last_federation_poke = {}  # type: Dict[RoomMember, int]
         self.wheel_timer = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0

         self.clock.looping_call(self._handle_timeouts, 5000)

-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the typing handler's data caches.
         """
         # map room IDs to serial numbers

@@ -86,7 +86,7 @@ class FollowerTypingHandler:
         self._member_last_federation_poke = {}
         self.wheel_timer = WheelTimer(bucket_size=5000)

-    def _handle_timeouts(self):
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")

         now = self.clock.time_msec()

@@ -96,7 +96,7 @@ class FollowerTypingHandler:
         for member in members:
             self._handle_timeout_for_member(now, member)

-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         if not self.is_typing(member):
             # Nothing to do if they're no longer typing
             return

@@ -114,10 +114,10 @@ class FollowerTypingHandler:
         # each person typing.
         self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)

-    def is_typing(self, member):
+    def is_typing(self, member: RoomMember) -> bool:
         return member.user_id in self._room_typing.get(member.room_id, [])

-    async def _push_remote(self, member, typing):
+    async def _push_remote(self, member: RoomMember, typing: bool) -> None:
         if not self.federation:
             return
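The wheel timer relied on above groups deadlines into coarse buckets so that expired typing notifications can be harvested in batches rather than with one timer per member. A minimal sketch mirroring the insert call seen in this diff (the fetch side is assumed; this toy class is not Synapse's WheelTimer):

from typing import Dict, List


class ToyWheelTimer:
    def __init__(self, bucket_size: int = 5000):
        self.bucket_size = bucket_size
        self._buckets = {}  # type: Dict[int, List[object]]

    def insert(self, now: int, obj: object, then: int) -> None:
        # file the object under the bucket containing its deadline
        self._buckets.setdefault(then // self.bucket_size, []).append(obj)

    def fetch(self, now: int) -> List[object]:
        # pop every entry whose bucket has fully elapsed
        current = now // self.bucket_size
        out = []  # type: List[object]
        for k in sorted(k for k in self._buckets if k < current):
            out.extend(self._buckets.pop(k))
        return out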
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         """Should be called whenever we receive updates for typing stream.
         """

@@ -178,7 +178,7 @@ class FollowerTypingHandler:
     async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
-    ):
+    ) -> None:
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """

@@ -194,12 +194,12 @@ class FollowerTypingHandler:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), False)

-    def get_current_token(self):
+    def get_current_token(self) -> int:
         return self._latest_room_serial


 class TypingWriterHandler(FollowerTypingHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)

         assert hs.config.worker.writers.typing == hs.get_instance_name()

@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
         hs.get_distributor().observe("user_left_room", self.user_left_room)

-        self._member_typing_until = {}  # clock time we expect to stop
+        # clock time we expect to stop
+        self._member_typing_until = {}  # type: Dict[RoomMember, int]

         # caches which room_ids changed at which serials
         self._typing_stream_change_cache = StreamChangeCache(
             "TypingStreamChangeCache", self._latest_room_serial
         )

-    def _handle_timeout_for_member(self, now: int, member: RoomMember):
+    def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
         super()._handle_timeout_for_member(now, member)

         if not self.is_typing(member):

@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             self._stopped_typing(member)
             return

-    async def started_typing(self, target_user, requester, room_id, timeout):
+    async def started_typing(
+        self, target_user: UserID, requester: Requester, room_id: str, timeout: int
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()

@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
         if was_present:
             # No point sending another notification
-            return None
+            return

         self._push_update(member=member, typing=True)

-    async def stopped_typing(self, target_user, requester, room_id):
+    async def stopped_typing(
+        self, target_user: UserID, requester: Requester, room_id: str
+    ) -> None:
         target_user_id = target_user.to_string()
         auth_user_id = requester.user.to_string()

@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
         self._stopped_typing(member)

-    def user_left_room(self, user, room_id):
+    def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         if self.is_mine_id(user_id):
             member = RoomMember(room_id=room_id, user_id=user_id)
             self._stopped_typing(member)

-    def _stopped_typing(self, member):
+    def _stopped_typing(self, member: RoomMember) -> None:
         if member.user_id not in self._room_typing.get(member.room_id, set()):
             # No point
-            return None
+            return

         self._member_typing_until.pop(member, None)
         self._member_last_federation_poke.pop(member, None)

         self._push_update(member=member, typing=False)

-    def _push_update(self, member, typing):
+    def _push_update(self, member: RoomMember, typing: bool) -> None:
         if self.hs.is_mine_id(member.user_id):
             # Only send updates for changes to our own users.
             run_as_background_process(

@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         self._push_update_local(member=member, typing=typing)

-    async def _recv_edu(self, origin, content):
+    async def _recv_edu(self, origin: str, content: JsonDict) -> None:
         room_id = content["room_id"]
         user_id = content["user_id"]

@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
             self._push_update_local(member=member, typing=content["typing"])

-    def _push_update_local(self, member, typing):
+    def _push_update_local(self, member: RoomMember, typing: bool) -> None:
         room_set = self._room_typing.setdefault(member.room_id, set())
         if typing:
             room_set.add(member.user_id)

@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
             last_id
-        )
+        )  # type: Optional[Iterable[str]]

         if changed_rooms is None:
             changed_rooms = self._room_serials

@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
     def process_replication_rows(
         self, token: int, rows: List[TypingStream.TypingStreamRow]
-    ):
+    ) -> None:
         # The writing process should never get updates from replication.
         raise Exception("Typing writer instance got typing info over replication")


 class TypingNotificationEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         # We can't call get_typing_handler here because there's a cycle:

@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
         #
         self.get_typing_handler = hs.get_typing_handler

-    def _make_event_for(self, room_id):
+    def _make_event_for(self, room_id: str) -> JsonDict:
         typing = self.get_typing_handler()._room_typing[room_id]
         return {
             "type": "m.typing",

@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
         return (events, handler._latest_room_serial)

-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: Iterable[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
             handler = self.get_typing_handler()

@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
         return (events, handler._latest_room_serial)

-    def get_current_key(self):
+    def get_current_key(self) -> int:
         return self.get_typing_handler()._latest_room_serial
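For reference, the two payload shapes flowing through this file: the m.typing EDU consumed by _recv_edu, and the ephemeral event _make_event_for builds for local clients (IDs invented):

edu_content = {
    "room_id": "!foo:example.com",
    "user_id": "@alice:example.com",
    "typing": True,
}

typing_event = {
    "type": "m.typing",
    "room_id": "!foo:example.com",
    "content": {"user_ids": ["@alice:example.com"]},
}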
@@ -145,10 +145,6 @@ class UserDirectoryHandler:
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()

-        # If still None then the initial background update hasn't happened yet
-        if self.pos is None:
-            return None
-
         # Loop round handling deltas until we're up to date
         while True:
             with Measure(self.clock, "user_dir_delta"):

@@ -233,6 +229,11 @@ class UserDirectoryHandler:
                 if change:  # The user joined
                     event = await self.store.get_event(event_id, allow_none=True)
+                    # It isn't expected for this event to not exist, but we
+                    # don't want the entire background process to break.
+                    if event is None:
+                        continue
+
                     profile = ProfileInfo(
                         avatar_url=event.content.get("avatar_url"),
                         display_name=event.content.get("displayname"),