Merge branch 'develop' into babolivier/msc3026

This commit is contained in:
Brendan Abolivier 2021-03-19 16:12:40 +01:00
commit 592d6305fd
No known key found for this signature in database
GPG Key ID: 1E015C145F1916CD
18 changed files with 574 additions and 165 deletions

1
changelog.d/9636.bugfix Normal file
View File

@ -0,0 +1 @@
Checks if passwords are allowed before setting it for the user.

1
changelog.d/9640.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of federation catch up by sending events the latest events in the room to the remote, rather than just the last event sent by the local server.

1
changelog.d/9643.feature Normal file
View File

@ -0,0 +1 @@
Add initial experimental support for a "space summary" API.

1
changelog.d/9645.misc Normal file
View File

@ -0,0 +1 @@
In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested.

1
changelog.d/9647.misc Normal file
View File

@ -0,0 +1 @@
In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`.

View File

@ -22,8 +22,8 @@ import sys
from typing import Any, Optional from typing import Any, Optional
from urllib import parse as urlparse from urllib import parse as urlparse
import nacl.signing
import requests import requests
import signedjson.key
import signedjson.types import signedjson.types
import srvlookup import srvlookup
import yaml import yaml
@ -44,18 +44,6 @@ def encode_base64(input_bytes):
return output_string return output_string
def decode_base64(input_string):
"""Decode a base64 string to bytes inferring padding from the length of the
string."""
input_bytes = input_string.encode("ascii")
input_len = len(input_bytes)
padding = b"=" * (3 - ((input_len + 3) % 4))
output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
output_bytes = base64.b64decode(input_bytes + padding)
return output_bytes[:output_len]
def encode_canonical_json(value): def encode_canonical_json(value):
return json.dumps( return json.dumps(
value, value,
@ -88,42 +76,6 @@ def sign_json(
return json_object return json_object
NACL_ED25519 = "ed25519"
def decode_signing_key_base64(algorithm, version, key_base64):
"""Decode a base64 encoded signing key
Args:
algorithm (str): The algorithm the key is for (currently "ed25519").
version (str): Identifies this key out of the keys for this entity.
key_base64 (str): Base64 encoded bytes of the key.
Returns:
A SigningKey object.
"""
if algorithm == NACL_ED25519:
key_bytes = decode_base64(key_base64)
key = nacl.signing.SigningKey(key_bytes)
key.version = version
key.alg = NACL_ED25519
return key
else:
raise ValueError("Unsupported algorithm %s" % (algorithm,))
def read_signing_keys(stream):
"""Reads a list of keys from a stream
Args:
stream : A stream to iterate for keys.
Returns:
list of SigningKey objects.
"""
keys = []
for line in stream:
algorithm, version, key_base64 = line.split()
keys.append(decode_signing_key_base64(algorithm, version, key_base64))
return keys
def request( def request(
method: Optional[str], method: Optional[str],
origin_name: str, origin_name: str,
@ -223,23 +175,28 @@ def main():
parser.add_argument("--body", help="Data to send as the body of the HTTP request") parser.add_argument("--body", help="Data to send as the body of the HTTP request")
parser.add_argument( parser.add_argument(
"path", help="request path. We will add '/_matrix/federation/v1/' to this." "path", help="request path, including the '/_matrix/federation/...' prefix."
) )
args = parser.parse_args() args = parser.parse_args()
if not args.server_name or not args.signing_key_path: args.signing_key = None
if args.signing_key_path:
with open(args.signing_key_path) as f:
args.signing_key = f.readline()
if not args.server_name or not args.signing_key:
read_args_from_config(args) read_args_from_config(args)
with open(args.signing_key_path) as f: algorithm, version, key_base64 = args.signing_key.split()
key = read_signing_keys(f)[0] key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
result = request( result = request(
args.method, args.method,
args.server_name, args.server_name,
key, key,
args.destination, args.destination,
"/_matrix/federation/v1/" + args.path, args.path,
content=args.body, content=args.body,
) )
@ -255,10 +212,16 @@ def main():
def read_args_from_config(args): def read_args_from_config(args):
with open(args.config, "r") as fh: with open(args.config, "r") as fh:
config = yaml.safe_load(fh) config = yaml.safe_load(fh)
if not args.server_name: if not args.server_name:
args.server_name = config["server_name"] args.server_name = config["server_name"]
if not args.signing_key_path:
args.signing_key_path = config["signing_key_path"] if not args.signing_key:
if "signing_key" in config:
args.signing_key = config["signing_key"]
else:
with open(config["signing_key_path"]) as f:
args.signing_key = f.readline()
class MatrixConnectionAdapter(HTTPAdapter): class MatrixConnectionAdapter(HTTPAdapter):

View File

@ -101,6 +101,9 @@ class EventTypes:
Dummy = "org.matrix.dummy_event" Dummy = "org.matrix.dummy_event"
MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
class EduTypes: class EduTypes:
Presence = "m.presence" Presence = "m.presence"
@ -161,6 +164,9 @@ class EventContentFields:
# cf https://github.com/matrix-org/matrix-doc/pull/2228 # cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
# cf https://github.com/matrix-org/matrix-doc/pull/1772
MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
class RoomEncryptionAlgorithms: class RoomEncryptionAlgorithms:
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2" MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"

View File

@ -27,5 +27,7 @@ class ExperimentalConfig(Config):
# MSC2858 (multiple SSO identity providers) # MSC2858 (multiple SSO identity providers)
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
# Spaces (MSC1772, MSC2946, etc)
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
# MSC3026 (busy presence state) # MSC3026 (busy presence state)
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool

View File

@ -35,7 +35,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.constants import EduTypes, EventTypes
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -63,7 +63,7 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet, ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet, ReplicationGetQueryRestServlet,
) )
from synapse.types import JsonDict, get_domain_from_id from synapse.types import JsonDict
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -727,27 +727,6 @@ class FederationServer(FederationBase):
if the event was unacceptable for any other reason (eg, too large, if the event was unacceptable for any other reason (eg, too large,
too many prev_events, couldn't find the prev_events) too many prev_events, couldn't find the prev_events)
""" """
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.sender):
# We continue to accept join events from any server; this is
# necessary for the federation join dance to work correctly.
# (When we join over federation, the "helper" server is
# responsible for sending out the join event, rather than the
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
pdu.type == "m.room.member"
and pdu.content
and pdu.content.get("membership", None)
in (Membership.JOIN, Membership.INVITE)
):
logger.info(
"Discarding PDU %s from invalid origin %s", pdu.event_id, origin
)
return
else:
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point # We've already checked that we know the room version by this point
room_version = await self.store.get_room_version(pdu.room_id) room_version = await self.store.get_room_version(pdu.room_id)

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import datetime import datetime
import logging import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -77,6 +77,7 @@ class PerDestinationQueue:
self._transaction_manager = transaction_manager self._transaction_manager = transaction_manager
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._federation_shard_config = hs.config.worker.federation_shard_config self._federation_shard_config = hs.config.worker.federation_shard_config
self._state = hs.get_state_handler()
self._should_send_on_this_instance = True self._should_send_on_this_instance = True
if not self._federation_shard_config.should_handle( if not self._federation_shard_config.should_handle(
@ -415,22 +416,95 @@ class PerDestinationQueue:
"This should not happen." % event_ids "This should not happen." % event_ids
) )
if logger.isEnabledFor(logging.INFO): # We send transactions with events from one room only, as its likely
rooms = [p.room_id for p in catchup_pdus] # that the remote will have to do additional processing, which may
logger.info("Catching up rooms to %s: %r", self._destination, rooms) # take some time. It's better to give it small amounts of work
# rather than risk the request timing out and repeatedly being
# retried, and not making any progress.
#
# Note: `catchup_pdus` will have exactly one PDU per room.
for pdu in catchup_pdus:
# The PDU from the DB will be the last PDU in the room from
# *this server* that wasn't sent to the remote. However, other
# servers may have sent lots of events since then, and we want
# to try and tell the remote only about the *latest* events in
# the room. This is so that it doesn't get inundated by events
# from various parts of the DAG, which all need to be processed.
#
# Note: this does mean that in large rooms a server coming back
# online will get sent the same events from all the different
# servers, but the remote will correctly deduplicate them and
# handle it only once.
await self._transaction_manager.send_new_transaction( # Step 1, fetch the current extremities
self._destination, catchup_pdus, [] extrems = await self._store.get_prev_events_for_room(pdu.room_id)
)
sent_transactions_counter.inc() if pdu.event_id in extrems:
final_pdu = catchup_pdus[-1] # If the event is in the extremities, then great! We can just
self._last_successful_stream_ordering = cast( # use that without having to do further checks.
int, final_pdu.internal_metadata.stream_ordering room_catchup_pdus = [pdu]
) else:
await self._store.set_destination_last_successful_stream_ordering( # If not, fetch the extremities and figure out which we can
self._destination, self._last_successful_stream_ordering # send.
) extrem_events = await self._store.get_events_as_list(extrems)
new_pdus = []
for p in extrem_events:
# We pulled this from the DB, so it'll be non-null
assert p.internal_metadata.stream_ordering
# Filter out events that happened before the remote went
# offline
if (
p.internal_metadata.stream_ordering
< self._last_successful_stream_ordering
):
continue
# Filter out events where the server is not in the room,
# e.g. it may have left/been kicked. *Ideally* we'd pull
# out the kick and send that, but it's a rare edge case
# so we don't bother for now (the server that sent the
# kick should send it out if its online).
hosts = await self._state.get_hosts_in_room_at_events(
p.room_id, [p.event_id]
)
if self._destination not in hosts:
continue
new_pdus.append(p)
# If we've filtered out all the extremities, fall back to
# sending the original event. This should ensure that the
# server gets at least some of missed events (especially if
# the other sending servers are up).
if new_pdus:
room_catchup_pdus = new_pdus
logger.info(
"Catching up rooms to %s: %r", self._destination, pdu.room_id
)
await self._transaction_manager.send_new_transaction(
self._destination, room_catchup_pdus, []
)
sent_transactions_counter.inc()
# We pulled this from the DB, so it'll be non-null
assert pdu.internal_metadata.stream_ordering
# Note that we mark the last successful stream ordering as that
# from the *original* PDU, rather than the PDU(s) we actually
# send. This is because we use it to mark our position in the
# queue of missed PDUs to process.
self._last_successful_stream_ordering = (
pdu.internal_metadata.stream_ordering
)
await self._store.set_destination_last_successful_stream_ordering(
self._destination, self._last_successful_stream_ordering
)
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]: def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs: if not self._pending_rrs:

View File

@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
logout_devices: bool, logout_devices: bool,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
) -> None: ) -> None:
if not self.hs.config.password_localdb_enabled: if not self._auth_handler.can_change_password():
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
try: try:

View File

@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import deque
from typing import TYPE_CHECKING, Iterable, List, Optional, Set
from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.events.utils import format_event_for_client_v2
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# number of rooms to return. We'll stop once we hit this limit.
# TODO: allow clients to reduce this with a request param.
MAX_ROOMS = 50
# max number of events to return per room.
MAX_ROOMS_PER_SPACE = 50
class SpaceSummaryHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._auth = hs.get_auth()
self._room_list_handler = hs.get_room_list_handler()
self._state_handler = hs.get_state_handler()
self._store = hs.get_datastore()
self._event_serializer = hs.get_event_client_serializer()
async def get_space_summary(
self,
requester: str,
room_id: str,
suggested_only: bool = False,
max_rooms_per_space: Optional[int] = None,
) -> JsonDict:
"""
Implementation of the space summary API
Args:
requester: user id of the user making this request
room_id: room id to start the summary at
suggested_only: whether we should only return children with the "suggested"
flag set.
max_rooms_per_space: an optional limit on the number of child rooms we will
return. This does not apply to the root room (ie, room_id), and
is overridden by ROOMS_PER_SPACE_LIMIT.
Returns:
summary dict to return
"""
# first of all, check that the user is in the room in question (or it's
# world-readable)
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
# the queue of rooms to process
room_queue = deque((room_id,))
processed_rooms = set() # type: Set[str]
rooms_result = [] # type: List[JsonDict]
events_result = [] # type: List[JsonDict]
now = self._clock.time_msec()
while room_queue and len(rooms_result) < MAX_ROOMS:
room_id = room_queue.popleft()
logger.debug("Processing room %s", room_id)
processed_rooms.add(room_id)
try:
await self._auth.check_user_in_room_or_world_readable(
room_id, requester
)
except AuthError:
logger.info(
"user %s cannot view room %s, omitting from summary",
requester,
room_id,
)
continue
room_entry = await self._build_room_entry(room_id)
rooms_result.append(room_entry)
# look for child rooms/spaces.
child_events = await self._get_child_events(room_id)
if suggested_only:
# we only care about suggested children
child_events = filter(_is_suggested_child_event, child_events)
# The client-specified max_rooms_per_space limit doesn't apply to the
# room_id specified in the request, so we ignore it if this is the
# first room we are processing. Otherwise, apply any client-specified
# limit, capping to our built-in limit.
if max_rooms_per_space is not None and len(processed_rooms) > 1:
max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space)
else:
max_rooms = MAX_ROOMS_PER_SPACE
for edge_event in itertools.islice(child_events, max_rooms):
edge_room_id = edge_event.state_key
events_result.append(
await self._event_serializer.serialize_event(
edge_event,
time_now=now,
event_format=format_event_for_client_v2,
)
)
# if we haven't yet visited the target of this link, add it to the queue
if edge_room_id not in processed_rooms:
room_queue.append(edge_room_id)
return {"rooms": rooms_result, "events": events_result}
async def _build_room_entry(self, room_id: str) -> JsonDict:
"""Generate en entry suitable for the 'rooms' list in the summary response"""
stats = await self._store.get_room_with_stats(room_id)
# currently this should be impossible because we call
# check_user_in_room_or_world_readable on the room before we get here, so
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
current_state_ids = await self._store.get_current_state_ids(room_id)
create_event = await self._store.get_event(
current_state_ids[(EventTypes.Create, "")]
)
# TODO: update once MSC1772 lands
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
entry = {
"room_id": stats["room_id"],
"name": stats["name"],
"topic": stats["topic"],
"canonical_alias": stats["canonical_alias"],
"num_joined_members": stats["joined_members"],
"avatar_url": stats["avatar"],
"world_readable": (
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
),
"guest_can_join": stats["guest_access"] == "can_join",
"room_type": room_type,
}
# Filter out Nones rather omit the field altogether
room_entry = {k: v for k, v in entry.items() if v is not None}
return room_entry
async def _get_child_events(self, room_id: str) -> Iterable[EventBase]:
# look for child rooms/spaces.
current_state_ids = await self._store.get_current_state_ids(room_id)
events = await self._store.get_events_as_list(
[
event_id
for key, event_id in current_state_ids.items()
# TODO: update once MSC1772 lands
if key[0] == EventTypes.MSC1772_SPACE_CHILD
]
)
# filter out any events without a "via" (which implies it has been redacted)
return (e for e in events if e.content.get("via"))
def _is_suggested_child_event(edge_event: EventBase) -> bool:
suggested = edge_event.content.get("suggested")
if isinstance(suggested, bool) and suggested:
return True
logger.debug("Ignorning not-suggested child %s", edge_event.state_key)
return False

View File

@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
elif not deactivate and user["deactivated"]: elif not deactivate and user["deactivated"]:
if ( if (
"password" not in body "password" not in body
and self.hs.config.password_localdb_enabled and self.auth_handler.can_change_password()
): ):
raise SynapseError( raise SynapseError(
400, "Must provide a password to re-activate an account." 400, "Must provide a password to re-activate an account."

View File

@ -18,7 +18,7 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib import parse as urlparse from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -35,16 +35,25 @@ from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_boolean,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import (
JsonDict,
RoomAlias,
RoomID,
StreamToken,
ThirdPartyInstanceID,
UserID,
)
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.stringutils import parse_and_validate_server_name, random_string from synapse.util.stringutils import parse_and_validate_server_name, random_string
@ -987,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
) )
def register_servlets(hs, http_server, is_worker=False): class RoomSpaceSummaryRestServlet(RestServlet):
PATTERNS = (
re.compile(
"^/_matrix/client/unstable/org.matrix.msc2946"
"/rooms/(?P<room_id>[^/]*)/spaces$"
),
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self._auth = hs.get_auth()
self._space_summary_handler = hs.get_space_summary_handler()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request, allow_guest=True)
return 200, await self._space_summary_handler.get_space_summary(
requester.user.to_string(),
room_id,
suggested_only=parse_boolean(request, "suggested_only", default=False),
max_rooms_per_space=parse_integer(request, "max_rooms_per_space"),
)
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
suggested_only = content.get("suggested_only", False)
if not isinstance(suggested_only, bool):
raise SynapseError(
400, "'suggested_only' must be a boolean", Codes.BAD_JSON
)
max_rooms_per_space = content.get("max_rooms_per_space")
if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
raise SynapseError(
400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
)
return 200, await self._space_summary_handler.get_space_summary(
requester.user.to_string(),
room_id,
suggested_only=suggested_only,
max_rooms_per_space=max_rooms_per_space,
)
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomStateEventRestServlet(hs).register(http_server) RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server)
@ -1001,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False):
RoomTypingRestServlet(hs).register(http_server) RoomTypingRestServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server) RoomEventContextServlet(hs).register(http_server)
if hs.config.experimental.spaces_enabled:
RoomSpaceSummaryRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process. # Some servlets only get registered for the main process.
if not is_worker: if not is_worker:
RoomCreateRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server)

View File

@ -100,6 +100,7 @@ from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHand
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.space_summary import SpaceSummaryHandler
from synapse.handlers.sso import SsoHandler from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
@ -732,6 +733,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler: def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self) return AccountDataHandler(self)
@cache_in_self
def get_space_summary_handler(self) -> SpaceSummaryHandler:
return SpaceSummaryHandler(self)
@cache_in_self @cache_in_self
def get_external_cache(self) -> ExternalCache: def get_external_cache(self) -> ExternalCache:
return ExternalCache(self) return ExternalCache(self)

View File

@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,) txn, self.get_user_deactivated_status, (user_id,)
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
@cached() @cached()

View File

@ -2,6 +2,7 @@ from typing import List, Tuple
from mock import Mock from mock import Mock
from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu from synapse.federation.units import Edu
@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertNotIn("zzzerver", woken) self.assertNotIn("zzzerver", woken)
# - all destinations are woken exactly once; they appear once in woken. # - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1]) self.assertCountEqual(woken, server_names[:-1])
@override_config({"send_federation": True})
def test_not_latest_event(self):
"""Test that we send the latest event in the room even if its not ours."""
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
# Make a room with a local user, and two servers. One will go offline
# and one will send some events.
self.register_user("u1", "you the one")
u1_token = self.login("u1", "you the one")
room_1 = self.helper.create_room_as("u1", tok=u1_token)
self.get_success(
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
)
event_1 = self.get_success(
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
)
# First we send something from the local server, so that we notice the
# remote is down and go into catchup mode.
self.helper.send(room_1, "you hear me!!", tok=u1_token)
# Now simulate us receiving an event from the still online remote.
event_2 = self.get_success(
event_injection.inject_event(
self.hs,
type=EventTypes.Message,
sender="@user:host3",
room_id=room_1,
content={"msgtype": "m.text", "body": "Hello"},
)
)
self.get_success(
self.hs.get_datastore().set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
)
)
self.get_success(per_dest_queue._catch_up_transmission_loop())
# We expect only the last message from the remote, event_2, to have been
# sent, rather than the last *local* event that was sent.
self.assertEqual(len(sent_pdus), 1)
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
self.assertFalse(per_dest_queue._catching_up)

View File

@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
# create users and get access tokens
# regardless of whether password login or SSO is allowed
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.get_success(
self.auth_handler.get_access_token_for_user_id(
self.admin_user, device_id=None, valid_until_ms=None
)
)
self.other_user = self.register_user("user", "pass", displayname="User") self.other_user = self.register_user("user", "pass", displayname="User")
self.other_user_token = self.login("user", "pass") self.other_user_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
self.other_user, device_id=None, valid_until_ms=None
)
)
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
self.other_user self.other_user
) )
@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user # Get user
@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"]) self.assertFalse(channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
def test_create_user(self): def test_create_user(self):
@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
# Get user # Get user
@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
self.assertEqual(False, channel.json_body["is_guest"]) self.assertFalse(channel.json_body["is_guest"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual(False, channel.json_body["shadow_banned"]) self.assertFalse(channel.json_body["shadow_banned"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
@override_config( @override_config(
@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@override_config( @override_config(
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Admin user is not blocked by mau anymore # Admin user is not blocked by mau anymore
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@override_config( @override_config(
{ {
@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
# Deactivate user # Deactivate user
body = json.dumps({"deactivated": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"deactivated": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual(0, len(channel.json_body["threepids"]))
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
self.assertEqual("User", channel.json_body["displayname"]) self.assertEqual("User", channel.json_body["displayname"])
@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertTrue(profile["display_name"] == "User") self.assertTrue(profile["display_name"] == "User")
# Deactivate user # Deactivate user
body = json.dumps({"deactivated": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"deactivated": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
# is not in user directory # is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user)) profile = self.get_success(self.store.get_user_in_directory(self.other_user))
self.assertTrue(profile is None) self.assertIsNone(profile)
# Set new displayname user # Set new displayname user
body = json.dumps({"displayname": "Foobar"})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"displayname": "Foobar"},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"]) self.assertTrue(channel.json_body["deactivated"])
self.assertEqual("Foobar", channel.json_body["displayname"]) self.assertEqual("Foobar", channel.json_body["displayname"])
# is not in user directory # is not in user directory
profile = self.get_success(self.store.get_user_in_directory(self.other_user)) profile = self.get_success(self.store.get_user_in_directory(self.other_user))
self.assertTrue(profile is None) self.assertIsNone(profile)
def test_reactivate_user(self): def test_reactivate_user(self):
""" """
@ -1520,24 +1527,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
""" """
# Deactivate the user. # Deactivate the user.
channel = self.make_request( self._deactivate_user("@user:test")
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self._is_erased("@user:test", False)
d = self.store.mark_user_erased("@user:test")
self.assertIsNone(self.get_success(d))
self._is_erased("@user:test", True)
# Attempt to reactivate the user (without a password). # Attempt to reactivate the user (without a password).
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), content={"deactivated": False},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@ -1546,22 +1543,76 @@ class UserRestTestCase(unittest.HomeserverTestCase):
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False, "password": "foo"}).encode( content={"deactivated": False, "password": "foo"},
encoding="utf_8"
),
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"]) self.assertFalse(channel.json_body["deactivated"])
self.assertIsNotNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
@override_config({"password_config": {"localdb_enabled": False}})
def test_reactivate_user_localdb_disabled(self):
"""
Test reactivating another user when using SSO.
"""
# Deactivate the user.
self._deactivate_user("@user:test")
# Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False)
@override_config({"password_config": {"enabled": False}})
def test_reactivate_user_password_disabled(self):
"""
Test reactivating another user when using SSO.
"""
# Deactivate the user.
self._deactivate_user("@user:test")
# Reactivate the user with a password
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False, "password": "foo"},
)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Reactivate the user without a password.
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"deactivated": False},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertFalse(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased("@user:test", False) self._is_erased("@user:test", False)
def test_set_user_as_admin(self): def test_set_user_as_admin(self):
@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
""" """
# Set a user as an admin # Set a user as an admin
body = json.dumps({"admin": True})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
self.url_other_user, self.url_other_user,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"admin": True},
) )
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"]) self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["admin"]) self.assertTrue(channel.json_body["admin"])
def test_accidental_deactivation_prevention(self): def test_accidental_deactivation_prevention(self):
""" """
@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
url = "/_synapse/admin/v2/users/@bob:test" url = "/_synapse/admin/v2/users/@bob:test"
# Create user # Create user
body = json.dumps({"password": "abc123"})
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"password": "abc123"},
) )
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
# Change password (and use a str for deactivate instead of a bool) # Change password (and use a str for deactivate instead of a bool)
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"), content={"password": "abc123", "deactivated": "false"},
) )
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive # Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
def _is_erased(self, user_id, expect): def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not""" """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id) d = self.store.is_user_erased(user_id)
if expect: if expect:
@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
else: else:
self.assertFalse(self.get_success(d)) self.assertFalse(self.get_success(d))
def _deactivate_user(self, user_id: str) -> None:
"""Deactivate user and set as erased"""
# Deactivate the user.
channel = self.make_request(
"PUT",
"/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
access_token=self.admin_user_tok,
content={"deactivated": True},
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertTrue(channel.json_body["deactivated"])
self.assertIsNone(channel.json_body["password_hash"])
self._is_erased(user_id, False)
d = self.store.mark_user_erased(user_id)
self.assertIsNone(self.get_success(d))
self._is_erased(user_id, True)
class UserMembershipRestTestCase(unittest.HomeserverTestCase): class UserMembershipRestTestCase(unittest.HomeserverTestCase):