Merge branch 'develop' of github.com:matrix-org/synapse into erikj/createroom_content

This commit is contained in:
Erik Johnston 2017-12-07 14:24:01 +00:00
commit d8a6c734fa
73 changed files with 1778 additions and 822 deletions

View File

@ -5,7 +5,8 @@ To use it, first install prometheus by following the instructions at
http://prometheus.io/ http://prometheus.io/
Then add a new job to the main prometheus.conf file: ### for Prometheus v1
Add a new job to the main prometheus.conf file:
job: { job: {
name: "synapse" name: "synapse"
@ -15,6 +16,22 @@ Then add a new job to the main prometheus.conf file:
} }
} }
### for Prometheus v2
Add a new job to the main prometheus.yml file:
- job_name: "synapse"
metrics_path: "/_synapse/metrics"
# when endpoint uses https:
scheme: "https"
static_configs:
- targets: ['SERVER.LOCATION:PORT']
To use `synapse.rules` add
rule_files:
- "/PATH/TO/synapse-v2.rules"
Metrics are disabled by default when running synapse; they must be enabled Metrics are disabled by default when running synapse; they must be enabled
with the 'enable-metrics' option, either in the synapse config file or as a with the 'enable-metrics' option, either in the synapse config file or as a
command-line option. command-line option.

View File

@ -0,0 +1,60 @@
groups:
- name: synapse
rules:
- record: "synapse_federation_transaction_queue_pendingEdus:total"
expr: "sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)"
- record: "synapse_federation_transaction_queue_pendingPdus:total"
expr: "sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)"
- record: 'synapse_http_server_requests:method'
labels:
servlet: ""
expr: "sum(synapse_http_server_requests) by (method)"
- record: 'synapse_http_server_requests:servlet'
labels:
method: ""
expr: 'sum(synapse_http_server_requests) by (servlet)'
- record: 'synapse_http_server_requests:total'
labels:
servlet: ""
expr: 'sum(synapse_http_server_requests:by_method) by (servlet)'
- record: 'synapse_cache:hit_ratio_5m'
expr: 'rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])'
- record: 'synapse_cache:hit_ratio_30s'
expr: 'rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])'
- record: 'synapse_federation_client_sent'
labels:
type: "EDU"
expr: 'synapse_federation_client_sent_edus + 0'
- record: 'synapse_federation_client_sent'
labels:
type: "PDU"
expr: 'synapse_federation_client_sent_pdu_destinations:count + 0'
- record: 'synapse_federation_client_sent'
labels:
type: "Query"
expr: 'sum(synapse_federation_client_sent_queries) by (job)'
- record: 'synapse_federation_server_received'
labels:
type: "EDU"
expr: 'synapse_federation_server_received_edus + 0'
- record: 'synapse_federation_server_received'
labels:
type: "PDU"
expr: 'synapse_federation_server_received_pdus + 0'
- record: 'synapse_federation_server_received'
labels:
type: "Query"
expr: 'sum(synapse_federation_server_received_queries) by (job)'
- record: 'synapse_federation_transaction_queue_pending'
labels:
type: "EDU"
expr: 'synapse_federation_transaction_queue_pending_edus + 0'
- record: 'synapse_federation_transaction_queue_pending'
labels:
type: "PDU"
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'

View File

@ -298,10 +298,6 @@ It can be used like this:
# this will now be logged against the request context # this will now be logged against the request context
logger.debug("Request handling complete") logger.debug("Request handling complete")
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
but the fact that it does ``preserve_context_over_deferred`` on its results
means that its use is fraught with difficulty.
Passing synapse deferreds into third-party functions Passing synapse deferreds into third-party functions
---------------------------------------------------- ----------------------------------------------------

View File

@ -1,11 +1,15 @@
Scaling synapse via workers Scaling synapse via workers
--------------------------- ===========================
Synapse has experimental support for splitting out functionality into Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale processes are called 'workers', and are (eventually) intended to scale
horizontally independently. horizontally independently.
All of the below is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!
All processes continue to share the same database instance, and as such, workers All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using across multiple processes is a recipe for disaster, plus you should be using
@ -16,6 +20,16 @@ TCP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state. can be kept in sync with the main synapse process and database state.
Configuration
-------------
To make effective use of the workers, you will need to configure an HTTP
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
the correct worker, or to the main synapse instance. Note that this includes
requests made to the federation port. The caveats regarding running a
reverse-proxy on the federation port still apply (see
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
To enable workers, you need to add a replication listener to the master synapse, e.g.:: To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners: listeners:
@ -27,26 +41,19 @@ Under **no circumstances** should this replication API listener be exposed to th
public internet; it currently implements no authentication whatsoever and is public internet; it currently implements no authentication whatsoever and is
unencrypted. unencrypted.
You then create a set of configs for the various worker processes. These should be You then create a set of configs for the various worker processes. These
worker configuration files should be stored in a dedicated subdirectory, to allow should be worker configuration files, and should be stored in a dedicated
synctl to manipulate them. subdirectory, to allow synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms
Each worker configuration file inherits the configuration of the main homeserver Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker, configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc. e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config. You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication You must specify the type of worker application (``worker_app``). The currently
endpoint that it's talking to on the main synapse process (worker_replication_host available worker applications are listed below. You must also specify the
and worker_replication_port). replication endpoint that it's talking to on the main synapse process
(``worker_replication_host`` and ``worker_replication_port``).
For instance:: For instance::
@ -68,11 +75,11 @@ For instance::
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a ...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided plain HTTP ``/sync`` endpoint on port 8083 separately from the ``/sync`` endpoint provided
by the main synapse. by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to Obviously you should configure your reverse-proxy to route the relevant
the synchrotron instance(s) in this instance. endpoints to the worker (``localhost:8083`` in the above example).
Finally, to actually run your worker-based synapse, you must pass synctl the -a Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found commandline option to tell it to operate on all the worker configurations found
@ -89,6 +96,114 @@ To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar Available worker applications
to the one running matrix.org! -----------------------------
``synapse.app.pusher``
~~~~~~~~~~~~~~~~~~~~~~
Handles sending push notifications to sygnal and email. Doesn't handle any
REST endpoints itself, but you should set ``start_pushers: False`` in the
shared configuration file to stop the main synapse sending these notifications.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.synchrotron``
~~~~~~~~~~~~~~~~~~~~~~~~~~~
The synchrotron handles ``sync`` requests from clients. In particular, it can
handle REST endpoints matching the following regular expressions::
^/_matrix/client/(v2_alpha|r0)/sync$
^/_matrix/client/(api/v1|v2_alpha|r0)/events$
^/_matrix/client/(api/v1|r0)/initialSync$
^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
The above endpoints should all be routed to the synchrotron worker by the
reverse-proxy configuration.
It is possible to run multiple instances of the synchrotron to scale
horizontally. In this case the reverse-proxy should be configured to
load-balance across the instances, though it will be more efficient if all
requests from a particular user are routed to a single instance. Extracting
a userid from the access token is currently left as an exercise for the reader.
``synapse.app.appservice``
~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles sending output traffic to Application Services. Doesn't handle any
REST endpoints itself, but you should set ``notify_appservices: False`` in the
shared configuration file to stop the main synapse sending these notifications.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.federation_reader``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles a subset of federation endpoints. In particular, it can handle REST
endpoints matching the following regular expressions::
^/_matrix/federation/v1/event/
^/_matrix/federation/v1/state/
^/_matrix/federation/v1/state_ids/
^/_matrix/federation/v1/backfill/
^/_matrix/federation/v1/get_missing_events/
^/_matrix/federation/v1/publicRooms
The above endpoints should all be routed to the federation_reader worker by the
reverse-proxy configuration.
``synapse.app.federation_sender``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles sending federation traffic to other servers. Doesn't handle any
REST endpoints itself, but you should set ``send_federation: False`` in the
shared configuration file to stop the main synapse sending this traffic.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.media_repository``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles the media repository. It can handle all endpoints starting with::
/_matrix/media/
You should also set ``enable_media_repo: False`` in the shared configuration
file to stop the main synapse running background jobs related to managing the
media repository.
Note this worker cannot be load-balanced: only one instance should be active.
``synapse.app.client_reader``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Handles client API endpoints. It can handle REST endpoints matching the
following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
``synapse.app.user_dir``
~~~~~~~~~~~~~~~~~~~~~~~~
Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
``synapse.app.frontend_proxy``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Proxies some frequently-requested client endpoints to add caching and remove
load from the main synapse. It can handle REST endpoints matching the following
regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
It will proxy any requests it cannot handle to the main synapse instance. It
must therefore be configured with the location of the main instance, via
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
file. For example::
worker_main_http_uri: http://127.0.0.1:8008

View File

@ -123,15 +123,25 @@ def lookup(destination, path):
except: except:
return "https://%s:%d%s" % (destination, 8448, path) return "https://%s:%d%s" % (destination, 8448, path)
def get_json(origin_name, origin_key, destination, path):
request_json = { def request_json(method, origin_name, origin_key, destination, path, content):
"method": "GET", if method is None:
if content is None:
method = "GET"
else:
method = "POST"
json_to_sign = {
"method": method,
"uri": path, "uri": path,
"origin": origin_name, "origin": origin_name,
"destination": destination, "destination": destination,
} }
signed_json = sign_json(request_json, origin_key, origin_name) if content is not None:
json_to_sign["content"] = json.loads(content)
signed_json = sign_json(json_to_sign, origin_key, origin_name)
authorization_headers = [] authorization_headers = []
@ -145,10 +155,12 @@ def get_json(origin_name, origin_key, destination, path):
dest = lookup(destination, path) dest = lookup(destination, path)
print ("Requesting %s" % dest, file=sys.stderr) print ("Requesting %s" % dest, file=sys.stderr)
result = requests.get( result = requests.request(
dest, method=method,
url=dest,
headers={"Authorization": authorization_headers[0]}, headers={"Authorization": authorization_headers[0]},
verify=False, verify=False,
data=content,
) )
sys.stderr.write("Status Code: %d\n" % (result.status_code,)) sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json() return result.json()
@ -186,6 +198,17 @@ def main():
"connect appropriately.", "connect appropriately.",
) )
parser.add_argument(
"-X", "--method",
help="HTTP method to use for the request. Defaults to GET if --data is"
"unspecified, POST if it is."
)
parser.add_argument(
"--body",
help="Data to send as the body of the HTTP request"
)
parser.add_argument( parser.add_argument(
"path", "path",
help="request path. We will add '/_matrix/federation/v1/' to this." help="request path. We will add '/_matrix/federation/v1/' to this."
@ -199,8 +222,11 @@ def main():
with open(args.signing_key_path) as f: with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0] key = read_signing_keys(f)[0]
result = get_json( result = request_json(
args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path args.method,
args.server_name, key, args.destination,
"/_matrix/federation/v1/" + args.path,
content=args.body,
) )
json.dump(result, sys.stdout) json.dump(result, sys.stdout)

45
scripts/sync_room_to_group.pl Executable file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env perl
use strict;
use warnings;
use JSON::XS;
use LWP::UserAgent;
use URI::Escape;
if (@ARGV < 4) {
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
}
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
my $ua = LWP::UserAgent->new();
$ua->timeout(10);
if ($room_id =~ /^#/) {
$room_id = uri_escape($room_id);
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
}
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
my $group_users = [
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
];
die "refusing to sync from empty room" unless (@$room_users);
die "refusing to sync to empty group" unless (@$group_users);
my $diff = {};
foreach my $user (@$room_users) { $diff->{$user}++ }
foreach my $user (@$group_users) { $diff->{$user}-- }
foreach my $user (keys %$diff) {
if ($diff->{$user} == 1) {
warn "inviting $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
elsif ($diff->{$user} == -1) {
warn "removing $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
}

View File

@ -270,7 +270,11 @@ class Auth(object):
rights (str): The operation being performed; the access token must rights (str): The operation being performed; the access token must
allow this. allow this.
Returns: Returns:
dict : dict that includes the user and the ID of their access token. Deferred[dict]: dict that includes:
`user` (UserID)
`is_guest` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """

View File

@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
pass pass
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
(This indicates we should return a 401 with 'result' as the body)
Attributes:
result (dict): the server response to the request, which should be
passed back to the client
"""
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
"Interactive auth not yet complete",
)
self.result = result
class UnrecognizedRequestError(SynapseError): class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make""" """An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View File

@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource
from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain from synapse.storage import are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
@ -195,14 +194,19 @@ class SynapseHomeServer(HomeServer):
}) })
if name in ["media", "federation", "client"]: if name in ["media", "federation", "client"]:
media_repo = MediaRepositoryResource(self) if self.get_config().enable_media_repo:
resources.update({ media_repo = self.get_media_repository_resource()
MEDIA_PREFIX: media_repo, resources.update({
LEGACY_MEDIA_PREFIX: media_repo, MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource( LEGACY_MEDIA_PREFIX: media_repo,
self, self.config.uploads_path CONTENT_REPO_PREFIX: ContentRepoResource(
), self, self.config.uploads_path
}) ),
})
elif name == "media":
raise ConfigError(
"'media' resource conflicts with enable_media_repo=False",
)
if name in ["keys", "federation"]: if name in ["keys", "federation"]:
resources.update({ resources.update({

View File

@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore from synapse.storage.media_repository import MediaRepositoryStore
@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer):
if name == "metrics": if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self) resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media": elif name == "media":
media_repo = MediaRepositoryResource(self) media_repo = self.get_media_repository_resource()
resources.update({ resources.update({
MEDIA_PREFIX: media_repo, MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo,
@ -151,6 +150,13 @@ def start(config_options):
assert config.worker_app == "synapse.app.media_repository" assert config.worker_app == "synapse.app.media_repository"
if config.enable_media_repo:
_base.quit_with_error(
"enable_media_repo must be disabled in the main synapse process\n"
"before the media repo can be run in a separate worker.\n"
"Please add ``enable_media_repo: false`` to the main config\n"
)
setup_logging(config, use_worker_options=True) setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts

View File

@ -340,11 +340,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
# NB this is a SynchrotronPresence, not a normal PresenceHandler
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.presence_handler.sync_callback = self.send_user_sync
def on_rdata(self, stream_name, token, rows): def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows) super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import GroupID, get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
@ -81,12 +82,13 @@ class ApplicationService(object):
# values. # values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None, rate_limited=True): sender=None, id=None, protocols=None, rate_limited=True):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
@ -125,6 +127,24 @@ class ApplicationService(object):
raise ValueError( raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns "Expected bool for 'exclusive' in ns '%s'" % ns
) )
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
raise ValueError(
"Expected string for 'group_id' in ns '%s'" % ns
)
try:
GroupID.from_string(group_id)
except Exception:
raise ValueError(
"Expected valid group ID for 'group_id' in ns '%s'" % ns
)
if get_domain_from_id(group_id) != self.server_name:
raise ValueError(
"Expected 'group_id' to be this host in ns '%s'" % ns
)
regex = regex_obj.get("regex") regex = regex_obj.get("regex")
if isinstance(regex, basestring): if isinstance(regex, basestring):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex regex_obj["regex"] = re.compile(regex) # Pre-compile regex
@ -251,6 +271,21 @@ class ApplicationService(object):
if regex_obj["exclusive"] if regex_obj["exclusive"]
] ]
def get_groups_for_user(self, user_id):
"""Get the groups that this user is associated with by this AS
Args:
user_id (str): The ID of the user.
Returns:
iterable[str]: an iterable that yields group_id strings.
"""
return (
regex_obj["group_id"]
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
)
def is_rate_limited(self): def is_rate_limited(self):
return self.rate_limited return self.rate_limited

View File

@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename):
) )
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
hostname=hostname,
url=as_info["url"], url=as_info["url"],
namespaces=as_info["namespaces"], namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],

View File

@ -36,6 +36,7 @@ from .workers import WorkerConfig
from .push import PushConfig from .push import PushConfig
from .spam_checker import SpamCheckerConfig from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig from .groups import GroupsConfig
from .user_directory import UserDirectoryConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig, JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig, WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig,): SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
pass pass

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,28 +19,43 @@ from ._base import Config
class PushConfig(Config): class PushConfig(Config):
def read_config(self, config): def read_config(self, config):
self.push_redact_content = False push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
if push_config.get("redact_content") is not None:
print(
"The push.redact_content content option has never worked. "
"Please set push.include_content if you want this behaviour"
)
# Now check for the one in the 'email' section and honour it,
# with a warning.
push_config = config.get("email", {}) push_config = config.get("email", {})
self.push_redact_content = push_config.get("redact_content", False) redact_content = push_config.get("redact_content")
if redact_content is not None:
print(
"The 'email.redact_content' option is deprecated: "
"please set push.include_content instead"
)
self.push_include_content = not redact_content
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Control how push messages are sent to google/apple to notifications. # Clients requesting push notifications can either have the body of
# Normally every message said in a room with one or more people using # the message sent in the notification poke along with other details
# mobile devices will be posted to a push server hosted by matrix.org # like the sender, or just the event ID and room ID (`event_id_only`).
# which is registered with google and apple in order to allow push # If clients choose the former, this option controls whether the
# notifications to be sent to these mobile devices. # notification request includes the content of the event (other details
# # like the sender are still included). For `event_id_only` push, it
# Setting redact_content to true will make the push messages contain no # has no effect.
# message content which will provide increased privacy. This is a
# temporary solution pending improvements to Android and iPhone apps
# to get content from the app rather than the notification.
#
# For modern android devices the notification content will still appear # For modern android devices the notification content will still appear
# because it is loaded by the app. iPhone, however will send a # because it is loaded by the app. iPhone, however will send a
# notification saying only that a message arrived and who it came from. # notification saying only that a message arrived and who it came from.
# #
#push: #push:
# redact_content: false # include_content: true
""" """

View File

@ -41,6 +41,12 @@ class ServerConfig(Config):
# false only if we are updating the user directory in a worker # false only if we are updating the user directory in a worker
self.update_user_directory = config.get("update_user_directory", True) self.update_user_directory = config.get("update_user_directory", True)
# whether to enable the media repository endpoints. This should be set
# to false if the media repository is running as a separate endpoint;
# doing so ensures that we will not run cache cleanup jobs on the
# master, potentially causing inconsistency.
self.enable_media_repo = config.get("enable_media_repo", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1) self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
# Whether we should block invites sent to users on this server # Whether we should block invites sent to users on this server

View File

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class UserDirectoryConfig(Config):
"""User Directory Configuration
Configuration for the behaviour of the /user_directory API
"""
def read_config(self, config):
self.user_directory_search_all_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_all_users = (
user_directory_config.get("search_all_users", False)
)
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# search_all_users: false
"""

View File

@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm) name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash)) logger.debug("Expecting hash: %s", encode_base64(expected_hash))
if name not in event.hashes:
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
hashes = event.get("hashes")
if not isinstance(hashes, dict):
raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
if name not in hashes:
raise SynapseError( raise SynapseError(
400, 400,
"Algorithm %s not in hashes %s" % ( "Algorithm %s not in hashes %s" % (
name, list(event.hashes), name, list(hashes),
), ),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
message_hash_base64 = event.hashes[name] message_hash_base64 = hashes[name]
try: try:
message_hash_bytes = decode_base64(message_hash_base64) message_hash_bytes = decode_base64(message_hash_base64)
except Exception: except Exception:

View File

@ -25,7 +25,7 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError, logcontext from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.events import FrozenEvent, builder from synapse.events import FrozenEvent, builder
import synapse.metrics import synapse.metrics
@ -420,7 +420,7 @@ class FederationClient(FederationBase):
for e_id in batch for e_id in batch
] ]
res = yield preserve_context_over_deferred( res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True) defer.DeferredList(deferreds, consumeErrors=True)
) )
for success, result in res: for success, result in res:

View File

@ -20,7 +20,7 @@ from .persistence import TransactionActions
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util import logcontext from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
@ -146,7 +146,6 @@ class TransactionQueue(object):
else: else:
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks
def notify_new_events(self, current_id): def notify_new_events(self, current_id):
"""This gets called when we have some new events we might want to """This gets called when we have some new events we might want to
send out to other servers. send out to other servers.
@ -156,6 +155,13 @@ class TransactionQueue(object):
if self._is_processing: if self._is_processing:
return return
# fire off a processing loop in the background. It's likely it will
# outlast the current request, so run it in the sentinel logcontext.
with PreserveLoggingContext():
self._process_event_queue_loop()
@defer.inlineCallbacks
def _process_event_queue_loop(self):
try: try:
self._is_processing = True self._is_processing = True
while True: while True:

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import logging import logging
@ -159,7 +159,7 @@ class ApplicationServicesHandler(object):
def query_3pe(self, kind, protocol, fields): def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol) services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([ results = yield make_deferred_yieldable(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields) preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services for service in services
], consumeErrors=True)) ], consumeErrors=True))

View File

@ -17,7 +17,10 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import (
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
SynapseError,
)
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -46,7 +49,6 @@ class AuthHandler(BaseHandler):
""" """
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = { self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha, LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.MSISDN: self._check_msisdn, LoginType.MSISDN: self._check_msisdn,
@ -75,15 +77,76 @@ class AuthHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled self._password_enabled = hs.config.password_enabled
login_types = set() # we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
login_types = []
if self._password_enabled: if self._password_enabled:
login_types.add(LoginType.PASSWORD) login_types.append(LoginType.PASSWORD)
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"): if hasattr(provider, "get_supported_login_types"):
login_types.update( for t in provider.get_supported_login_types().keys():
provider.get_supported_login_types().keys() if t not in login_types:
) login_types.append(t)
self._supported_login_types = frozenset(login_types) self._supported_login_types = login_types
@defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip):
"""
Checks that the user is who they claim to be, via a UI auth.
This is used for things like device deletion and password reset where
the user already has a valid access token, but we want to double-check
that it isn't stolen by re-authenticating them.
Args:
requester (Requester): The user, as given by the access token
request_body (dict): The body of the request sent by the client
clientip (str): The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
have been given only in a previous call).
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
AuthError if the client has completed a login flow, and it gives
a different user to `requester`
"""
# build a list of supported flows
flows = [
[login_type] for login_type in self._supported_login_types
]
result, params, _ = yield self.check_auth(
flows, request_body, clientip,
)
# find the completed login type
for login_type in self._supported_login_types:
if login_type not in result:
continue
user_id = result[login_type]
break
else:
# this can't happen
raise Exception(
"check_auth returned True but no successful login type",
)
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
defer.returnValue(params)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -95,26 +158,36 @@ class AuthHandler(BaseHandler):
session with a map, which maps each auth-type (str) to the relevant session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool). identity authenticated by that auth-type (mostly str, but for captcha, bool).
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
decorator.
Args: Args:
flows (list): A list of login flows. Each flow is an ordered list of flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full strings representing auth-types. At least one full
flow must be completed in order for auth to be successful. flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client. clientip (str): The IP address of the client.
Returns: Returns:
A tuple of (authed, dict, dict, session_id) where authed is true if defer.Deferred[dict, dict, str]: a deferred tuple of
the client has successfully completed an auth flow. If it is true (creds, params, session_id).
the first dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to 'creds' contains the authenticated credentials of each stage.
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this 'params' contains the parameters for this request (which may
request (which may have been given only in a previous call). have been given only in a previous call).
session_id is the ID of this session, either passed in by the client 'session_id' is the ID of this session, either passed in by the
or assigned by the call to check_auth client or assigned by this call
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
all the stages in any of the permitted flows.
""" """
authdict = None authdict = None
@ -142,11 +215,8 @@ class AuthHandler(BaseHandler):
clientdict = session['clientdict'] clientdict = session['clientdict']
if not authdict: if not authdict:
defer.returnValue( raise InteractiveAuthIncompleteError(
( self._auth_dict_for_flows(flows, session),
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
) )
if 'creds' not in session: if 'creds' not in session:
@ -157,14 +227,12 @@ class AuthHandler(BaseHandler):
errordict = {} errordict = {}
if 'type' in authdict: if 'type' in authdict:
login_type = authdict['type'] login_type = authdict['type']
if login_type not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
try: try:
result = yield self.checkers[login_type](authdict, clientip) result = yield self._check_auth_dict(authdict, clientip)
if result: if result:
creds[login_type] = result creds[login_type] = result
self._save_session(session) self._save_session(session)
except LoginError, e: except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY: if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new # riot used to have a bug where it would request a new
# validation token (thus sending a new email) each time it # validation token (thus sending a new email) each time it
@ -173,7 +241,7 @@ class AuthHandler(BaseHandler):
# #
# Grandfather in the old behaviour for now to avoid # Grandfather in the old behaviour for now to avoid
# breaking old riot deployments. # breaking old riot deployments.
raise e raise
# this step failed. Merge the error dict into the response # this step failed. Merge the error dict into the response
# so that the client can have another go. # so that the client can have another go.
@ -190,12 +258,14 @@ class AuthHandler(BaseHandler):
"Auth completed with creds: %r. Client dict has keys: %r", "Auth completed with creds: %r. Client dict has keys: %r",
creds, clientdict.keys() creds, clientdict.keys()
) )
defer.returnValue((True, creds, clientdict, session['id'])) defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
ret.update(errordict) ret.update(errordict)
defer.returnValue((False, ret, clientdict, session['id'])) raise InteractiveAuthIncompleteError(
ret,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip): def add_oob_auth(self, stagetype, authdict, clientip):
@ -268,17 +338,35 @@ class AuthHandler(BaseHandler):
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_auth_dict(self, authdict, clientip):
if "user" not in authdict or "password" not in authdict: """Attempt to validate the auth dict provided by a client
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"] Args:
password = authdict["password"] authdict (object): auth dict provided by the client
clientip (str): IP address of the client
(canonical_id, callback) = yield self.validate_login(user_id, { Returns:
"type": LoginType.PASSWORD, Deferred: result of the stage verification.
"password": password,
}) Raises:
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
login_type = authdict['type']
checker = self.checkers.get(login_type)
if checker is not None:
res = yield checker(authdict, clientip)
defer.returnValue(res)
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
defer.returnValue(canonical_id) defer.returnValue(canonical_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -649,41 +737,6 @@ class AuthHandler(BaseHandler):
except Exception: except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
yield self.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
"""Invalidate a single access token """Invalidate a single access token
@ -706,6 +759,12 @@ class AuthHandler(BaseHandler):
access_token=access_token, access_token=access_token,
) )
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"], )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_tokens_for_user(self, user_id, except_token_id=None, def delete_access_tokens_for_user(self, user_id, except_token_id=None,
device_id=None): device_id=None):
@ -728,13 +787,18 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): if hasattr(provider, "on_logged_out"):
for token, device_id in tokens_and_devices: for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out( yield provider.on_logged_out(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
access_token=token, access_token=token,
) )
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case. # 'Canonicalise' email addresses down to lower case.

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def deactivate_account(self, user_id):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
Returns:
Deferred
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
# first delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)

View File

@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id]) yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_ids = device_map.keys()
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids): def delete_devices(self, user_id, device_ids):
""" Delete several devices """ Delete several devices
Args: Args:
user_id (str): user_id (str):
device_ids (str): The list of device IDs to delete device_ids (List[str]): The list of device IDs to delete
Returns: Returns:
defer.Deferred: defer.Deferred:

View File

@ -375,6 +375,12 @@ class GroupsLocalHandler(object):
def get_publicised_groups_for_user(self, user_id): def get_publicised_groups_for_user(self, user_id):
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
result = yield self.store.get_publicised_groups_for_user(user_id) result = yield self.store.get_publicised_groups_for_user(user_id)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
for app_service in self.store.get_app_services():
result.extend(app_service.get_groups_for_user(user_id))
defer.returnValue({"groups": result}) defer.returnValue({"groups": result})
else: else:
result = yield self.transport_client.get_publicised_groups_for_user( result = yield self.transport_client.get_publicised_groups_for_user(
@ -415,4 +421,9 @@ class GroupsLocalHandler(object):
uid uid
) )
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
for app_service in self.store.get_app_services():
results[uid].extend(app_service.get_groups_for_user(uid))
defer.returnValue({"users": results}) defer.returnValue({"users": results})

View File

@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -163,7 +163,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id] lambda states: states[event.event_id]
) )
(messages, token), current_state = yield preserve_context_over_deferred( (messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
preserve_fn(self.store.get_recent_events_for_room)( preserve_fn(self.store.get_recent_events_for_room)(

View File

@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
) )
changed = True changed = True
else: else:
# We expect to be poked occaisonally by the other side. # We expect to be poked occasionally by the other side.
# This is to protect against forgetful/buggy servers, so that # This is to protect against forgetful/buggy servers, so that
# no one gets stuck online forever. # no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT: if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:

View File

@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
"profile", self.on_profile_query "profile", self.on_profile_query
) )
self.user_directory_handler = hs.get_user_directory_handler()
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS) self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user) yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url target_user.localpart, new_avatar_url
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user) yield self._update_join_states(requester, target_user)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler):
), ),
admin=admin, admin=admin,
) )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
user_id, profile
)
else: else:
# autogen a sequential user ID # autogen a sequential user ID
attempts = 0 attempts = 0

View File

@ -154,6 +154,8 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users # We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id) rooms_to_order_value[room_id] = (-num_joined_users, room_id)
logger.info("Getting ordering for %i rooms since %s",
len(room_ids), stream_token)
yield concurrently_execute(get_order_for_room, room_ids, 10) yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1]) sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@ -181,34 +183,42 @@ class RoomListHandler(BaseHandler):
rooms_to_scan = rooms_to_scan[:since_token.current_limit] rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse() rooms_to_scan.reverse()
# Actually generate the entries. _append_room_entry_to_chunk will append to logger.info("After sorting and filtering, %i rooms remain",
# chunk but will stop if len(chunk) > limit len(rooms_to_scan))
chunk = []
if limit and not search_filter: # _append_room_entry_to_chunk will append to chunk but will stop if
# len(chunk) > limit
#
# Normally we will generate enough results on the first iteration here,
# but if there is a search filter, _append_room_entry_to_chunk may
# filter some results out, in which case we loop again.
#
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
#
# XXX if there is no limit, we may end up DoSing the server with
# calls to get_current_state_ids for every single room on the
# server. Surely we should cap this somehow?
#
if limit:
step = limit + 1 step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because the vast majority of cases we'll stop
# at first iteration, but occaisonally _append_room_entry_to_chunk
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else: else:
step = len(rooms_to_scan)
chunk = []
for i in xrange(0, len(rooms_to_scan), step):
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute( yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk( lambda r: self._append_room_entry_to_chunk(
r, rooms_to_num_joined[r], r, rooms_to_num_joined[r],
chunk, limit, search_filter chunk, limit, search_filter
), ),
rooms_to_scan, 5 batch, 5,
) )
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
from ._base import BaseHandler
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None
try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id,
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id,
)

View File

@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectoyHandler(object): class UserDirectoryHandler(object):
"""Handles querying of and keeping updated the user_directory. """Handles querying of and keeping updated the user_directory.
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
one public room. one public room.
""" """
INITIAL_SLEEP_MS = 50 INITIAL_ROOM_SLEEP_MS = 50
INITIAL_SLEEP_COUNT = 100 INITIAL_ROOM_SLEEP_COUNT = 100
INITIAL_BATCH_SIZE = 100 INITIAL_ROOM_BATCH_SIZE = 100
INITIAL_USER_SLEEP_MS = 10
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory self.update_user_directory = hs.config.update_user_directory
self.search_all_users = hs.config.user_directory_search_all_users
# When start up for the first time we need to populate the user_directory. # When start up for the first time we need to populate the user_directory.
# This is a set of user_id's we've inserted already # This is a set of user_id's we've inserted already
@ -110,6 +113,15 @@ class UserDirectoyHandler(object):
finally: finally:
self._is_processing = False self._is_processing = False
@defer.inlineCallbacks
def handle_local_profile_change(self, user_id, profile):
"""Called to update index of our local user profiles when they change
irrespective of any rooms the user may be in.
"""
yield self.store.update_profile_in_user_dir(
user_id, profile.display_name, profile.avatar_url, None,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _unsafe_process(self): def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB # If self.pos is None then means we haven't fetched it from DB
@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
room_ids = yield self.store.get_all_rooms() room_ids = yield self.store.get_all_rooms()
logger.info("Doing initial update of user directory. %d rooms", len(room_ids)) logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
num_processed_rooms = 1 num_processed_rooms = 0
for room_id in room_ids: for room_id in room_ids:
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids)) logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
yield self._handle_initial_room(room_id) yield self._handle_initial_room(room_id)
num_processed_rooms += 1 num_processed_rooms += 1
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
logger.info("Processed all rooms.") logger.info("Processed all rooms.")
if self.search_all_users:
num_processed_users = 0
user_ids = yield self.store.get_all_local_users()
logger.info("Doing initial update of user directory. %d users", len(user_ids))
for user_id in user_ids:
# We add profiles for all users even if they don't match the
# include pattern, just in case we want to change it in future
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
yield self._handle_local_user(user_id)
num_processed_users += 1
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
logger.info("Processed all users")
self.initially_handled_users = None self.initially_handled_users = None
self.initially_handled_users_in_public = None self.initially_handled_users_in_public = None
self.initially_handled_users_share = None self.initially_handled_users_share = None
@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
to_update = set() to_update = set()
count = 0 count = 0
for user_id in user_ids: for user_id in user_ids:
if count % self.INITIAL_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):
count += 1 count += 1
@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
if user_id == other_user_id: if user_id == other_user_id:
continue continue
if count % self.INITIAL_SLEEP_COUNT == 0: if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
yield sleep(self.INITIAL_SLEEP_MS / 1000.) yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
count += 1 count += 1
user_set = (user_id, other_user_id) user_set = (user_id, other_user_id)
@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
else: else:
self.initially_handled_users_share_private_room.add(user_set) self.initially_handled_users_share_private_room.add(user_set)
if len(to_insert) > self.INITIAL_BATCH_SIZE: if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.add_users_who_share_room( yield self.store.add_users_who_share_room(
room_id, not is_public, to_insert, room_id, not is_public, to_insert,
) )
to_insert.clear() to_insert.clear()
if len(to_update) > self.INITIAL_BATCH_SIZE: if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
yield self.store.update_users_who_share_room( yield self.store.update_users_who_share_room(
room_id, not is_public, to_update, room_id, not is_public, to_update,
) )
@ -384,15 +410,29 @@ class UserDirectoyHandler(object):
for user_id in users: for user_id in users:
yield self._handle_remove_user(room_id, user_id) yield self._handle_remove_user(room_id, user_id)
@defer.inlineCallbacks
def _handle_local_user(self, user_id):
"""Adds a new local roomless user into the user_directory_search table.
Used to populate up the user index when we have an
user_directory_search_all_users specified.
"""
logger.debug("Adding new local user to dir, %r", user_id)
profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
row = yield self.store.get_user_in_directory(user_id)
if not row:
yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_user(self, room_id, user_id, profile): def _handle_new_user(self, room_id, user_id, profile):
"""Called when we might need to add user to directory """Called when we might need to add user to directory
Args: Args:
room_id (str): room_id that user joined or started being public that room_id (str): room_id that user joined or started being public
user_id (str) user_id (str)
""" """
logger.debug("Adding user to dir, %r", user_id) logger.debug("Adding new user to dir, %r", user_id)
row = yield self.store.get_user_in_directory(user_id) row = yield self.store.get_user_in_directory(user_id)
if not row: if not row:
@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
if not row: if not row:
yield self.store.add_users_to_public_room(room_id, [user_id]) yield self.store.add_users_to_public_room(room_id, [user_id])
else: else:
logger.debug("Not adding user to public dir, %r", user_id) logger.debug("Not adding new user to public dir, %r", user_id)
# Now we update users who share rooms with users. We do this by getting # Now we update users who share rooms with users. We do this by getting
# all the current users in the room and seeing which aren't already # all the current users in the room and seeing which aren't already

View File

@ -362,8 +362,10 @@ def _get_hosts_for_srv_record(dns_client, host):
return res return res
# no logcontexts here, so we can safely fire these off and gatherResults # no logcontexts here, so we can safely fire these off and gatherResults
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) d1 = dns_client.lookupAddress(host).addCallbacks(
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) cb, eb, errbackArgs=("A", ))
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
cb, eb, errbackArgs=("AAAA", ))
results = yield defer.DeferredList( results = yield defer.DeferredList(
[d1, d2], consumeErrors=True) [d1, d2], consumeErrors=True)

View File

@ -28,6 +28,7 @@ from canonicaljson import (
) )
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure
from twisted.web import server, resource from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
@ -131,12 +132,17 @@ def wrap_request_handler(request_handler, include_metrics=False):
version_string=self.version_string, version_string=self.version_string,
) )
except Exception: except Exception:
logger.exception( # failure.Failure() fishes the original Failure out
"Failed handle request %s.%s on %r: %r", # of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.error(
"Failed handle request %s.%s on %r: %r: %s",
request_handler.__module__, request_handler.__module__,
request_handler.__name__, request_handler.__name__,
self, self,
request request,
f.getTraceback().rstrip(),
) )
respond_with_json( respond_with_json(
request, request,

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
@ -81,6 +82,7 @@ class ModuleApi(object):
reg = self.hs.get_handlers().registration_handler reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart) return reg.register(localpart=localpart)
@defer.inlineCallbacks
def invalidate_access_token(self, access_token): def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user """Invalidate an access token for a user
@ -94,8 +96,16 @@ class ModuleApi(object):
Raises: Raises:
synapse.api.errors.AuthError: the access token is invalid synapse.api.errors.AuthError: the access token is invalid
""" """
# see if the access token corresponds to a device
return self._auth_handler.delete_access_token(access_token) user_info = yield self._auth.get_user_by_access_token(access_token)
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
yield self.hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs): def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection """Run a function with a database connection

View File

@ -255,9 +255,7 @@ class Notifier(object):
) )
if self.federation_sender: if self.federation_sender:
preserve_fn(self.federation_sender.notify_new_events)( self.federation_sender.notify_new_events(room_stream_id)
room_stream_id
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN: if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id) self._user_joined_room(event.state_key, event.room_id)
@ -297,8 +295,7 @@ class Notifier(object):
def on_new_replication_data(self): def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
with PreserveLoggingContext(): self.notify_replication()
self.notify_replication()
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@ -516,8 +513,14 @@ class Notifier(object):
self.replication_deferred = ObservableDeferred(defer.Deferred()) self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None) deferred.callback(None)
for cb in self.replication_callbacks: # the callbacks may well outlast the current request, so we run
preserve_fn(cb)() # them in the sentinel logcontext.
#
# (ideally it would be up to the callbacks to know if they were
# starting off background processes and drop the logcontext
# accordingly, but that requires more changes)
for cb in self.replication_callbacks:
cb()
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_replication(self, callback, timeout): def wait_for_replication(self, callback, timeout):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -295,7 +296,7 @@ class HttpPusher(object):
if event.type == 'm.room.member': if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership'] d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id d['notification']['user_is_target'] = event.state_key == self.user_id
if not self.hs.config.push_redact_content and 'content' in event: if self.hs.config.push_include_content and 'content' in event:
d['notification']['content'] = event.content d['notification']['content'] = event.content
# We no longer send aliases separately, instead, we send the human # We no longer send aliases separately, instead, we send the human

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .pusher import PusherFactory from .pusher import PusherFactory
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
import logging import logging
@ -103,19 +103,25 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_access_token_id=None): def remove_pushers_by_access_token(self, user_id, access_tokens):
all = yield self.store.get_all_pushers() """Remove the pushers for a given user corresponding to a set of
logger.info( access_tokens.
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id Args:
) user_id (str): user to remove pushers for
for p in all: access_tokens (Iterable[int]): access token *ids* to remove pushers
if p['user_name'] == user_id and p['access_token'] != except_access_token_id: for
"""
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p['access_token'] in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']
) )
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(
p['app_id'], p['pushkey'], p['user_name'],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id): def on_new_notifications(self, min_stream_id, max_stream_id):
@ -136,7 +142,7 @@ class PusherPool:
) )
) )
yield preserve_context_over_deferred(defer.gatherResults(deferreds)) yield make_deferred_yieldable(defer.gatherResults(deferreds))
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@ -161,7 +167,7 @@ class PusherPool:
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id) preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
) )
yield preserve_context_over_deferred(defer.gatherResults(deferreds)) yield make_deferred_yieldable(defer.gatherResults(deferreds))
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")

View File

@ -12,20 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import BaseSlavedStore import logging
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupReadStore
from synapse.storage.stream import StreamStore from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
import logging from ._slaved_id_tracker import SlavedIdTracker
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class. # the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(BaseSlavedStore): class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs) super(SlavedEventStore, self).__init__(db_conn, hs)
@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore):
_get_unread_counts_by_pos_txn = ( _get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__ DataStore._get_unread_counts_by_pos_txn.__func__
) )
_get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"]
)
_get_state_group_for_event = (
StateStore.__dict__["_get_state_group_for_event"]
)
_get_state_groups_from_groups = (
StateStore.__dict__["_get_state_groups_from_groups"]
)
_get_state_groups_from_groups_txn = (
DataStore._get_state_groups_from_groups_txn.__func__
)
get_recent_event_ids_for_room = ( get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"] StreamStore.__dict__["get_recent_event_ids_for_room"]
) )
get_current_state_ids = (
StateStore.__dict__["get_current_state_ids"]
)
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"] _get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
has_room_changed_since = DataStore.has_room_changed_since.__func__ has_room_changed_since = DataStore.has_room_changed_since.__func__
@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__ DataStore.get_room_events_stream_for_room.__func__
) )
get_events_around = DataStore.get_events_around.__func__ get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__ get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = ( _get_joined_users_from_context = (
@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore):
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )
_get_state_for_groups = DataStore._get_state_for_groups.__func__
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
_get_events_around_txn = DataStore._get_events_around_txn.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__ get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__ _get_backfill_events = DataStore._get_backfill_events.__func__

View File

@ -216,11 +216,12 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token) self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync") @measure_func("repl.on_user_sync")
@defer.inlineCallbacks
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms): def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker. """A client has started/stopped syncing on a worker.
""" """
user_sync_counter.inc() user_sync_counter.inc()
self.presence_handler.update_external_syncs_row( yield self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms, conn_id, user_id, is_syncing, last_sync_ms,
) )
@ -244,11 +245,12 @@ class ReplicationStreamer(object):
getattr(self.store, cache_func).invalidate(tuple(keys)) getattr(self.store, cache_func).invalidate(tuple(keys))
@measure_func("repl.on_user_ip") @measure_func("repl.on_user_ip")
@defer.inlineCallbacks
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen): def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""The client saw a user request """The client saw a user request
""" """
user_ip_cache_counter.inc() user_ip_cache_counter.inc()
self.store.insert_client_ip( yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen, user_id, access_token, ip, user_agent, device_id, last_seen,
) )

View File

@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs) super(DeactivateAccountRestServlet, self).__init__(hs)
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, target_user_id): def on_POST(self, request, target_user_id):
@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin: if not is_admin:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
yield self._auth_handler.deactivate_account(target_user_id) yield self._deactivate_account_handler.deactivate_account(target_user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs) super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, target_user_id): def on_POST(self, request, target_user_id):
@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
logger.info("new_password: %r", new_password) logger.info("new_password: %r", new_password)
yield self.auth_handler.set_password( yield self._set_password_handler.set_password(
target_user_id, new_password, requester target_user_id, new_password, requester
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__(hs)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
access_token = get_access_token_from_request(request) try:
yield self._auth_handler.delete_access_token(access_token) requester = yield self.auth.get_user_by_req(request)
except AuthError:
# this implies the access token has already been deleted.
pass
else:
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
requester.user.to_string(), requester.device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets. """This module contains base REST classes for constructing client v1 servlets.
""" """
import logging
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import re import re
import logging from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min( filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'], filter_json['room']['timeline']['limit'],
filter_timeline_limit) filter_timeline_limit)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
return wrapped
def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result

View File

@ -19,14 +19,14 @@ from twisted.internet import defer
from synapse.api.auth import has_access_token from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, assert_params_in_request, RestServlet, assert_params_in_request,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -98,56 +98,61 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([ # there are two possibilities here. Either the user does not have an
[LoginType.PASSWORD], # access token, and needs to do a password reset; or they have one and
[LoginType.EMAIL_IDENTITY], # need to validate their identity.
[LoginType.MSISDN], #
], body, self.hs.get_ip_from_request(request)) # In the first case, we offer a couple of means of identifying
# themselves (email and msisdn, though it's unclear if msisdn actually
# works).
#
# In the second case, we require a password to confirm their identity.
if not authed: if has_access_token(request):
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() params = yield self.auth_handler.validate_user_via_ui_auth(
if user_id != result[LoginType.PASSWORD]: requester, body, self.hs.get_ip_from_request(request),
raise LoginError(400, "", Codes.UNKNOWN)
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
threepid['medium'], threepid['address']
) )
if not threepid_user_id: user_id = requester.user.to_string()
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else: else:
logger.error("Auth succeeded but no known type!", result.keys()) requester = None
raise SynapseError(500, "", Codes.UNKNOWN) result, params, _ = yield self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
body, self.hs.get_ip_from_request(request),
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params: if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self.auth_handler.set_password( yield self._set_password_handler.set_password(
user_id, new_password, requester user_id, new_password, requester
) )
@ -161,52 +166,32 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# if the caller provides an access token, it ought to be valid. requester = yield self.auth.get_user_by_req(request)
requester = None
if has_access_token(request):
requester = yield self.auth.get_user_by_req(
request,
) # type: synapse.types.Requester
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester and requester.app_service: if requester.app_service:
yield self.auth_handler.deactivate_account( yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string() requester.user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
authed, result, params, _ = yield self.auth_handler.check_auth([ yield self.auth_handler.validate_user_via_ui_auth(
[LoginType.PASSWORD], requester, body, self.hs.get_ip_from_request(request),
], body, self.hs.get_ip_from_request(request)) )
yield self._deactivate_account_handler.deactivate_account(
if not authed: requester.user.to_string(),
defer.returnValue((401, result)) )
if LoginType.PASSWORD in result:
user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
if requester is None:
raise SynapseError(
400,
"Deactivate account requires an access_token",
errcode=Codes.MISSING_TOKEN
)
if requester.user.to_string() != user_id:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
yield self.auth_handler.deactivate_account(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -17,9 +17,9 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api import constants, errors from synapse.api import errors
from synapse.http import servlet from synapse.http import servlet
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
try: try:
body = servlet.parse_json_object_from_request(request) body = servlet.parse_json_object_from_request(request)
except errors.SynapseError as e: except errors.SynapseError as e:
@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
) )
authed, result, params, _ = yield self.auth_handler.check_auth([ yield self.auth_handler.validate_user_via_ui_auth(
[constants.LoginType.PASSWORD], requester, body, self.hs.get_ip_from_request(request),
], body, self.hs.get_ip_from_request(request)) )
if not authed:
defer.returnValue((401, result))
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices( yield self.device_handler.delete_devices(
requester.user.to_string(), requester.user.to_string(),
body['devices'], body['devices'],
@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
) )
defer.returnValue((200, device)) defer.returnValue((200, device))
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, device_id): def on_DELETE(self, request, device_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
else: else:
raise raise
authed, result, params, _ = yield self.auth_handler.check_auth([ yield self.auth_handler.validate_user_via_ui_auth(
[constants.LoginType.PASSWORD], requester, body, self.hs.get_ip_from_request(request),
], body, self.hs.get_ip_from_request(request)) )
if not authed: yield self.device_handler.delete_device(
defer.returnValue((401, result)) requester.user.to_string(), device_id,
)
# check that the UI auth matched the access token
user_id = result[constants.LoginType.PASSWORD]
if user_id != requester.user.to_string():
raise errors.AuthError(403, "Invalid auth")
yield self.device_handler.delete_device(user_id, device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -38,7 +38,7 @@ class GroupServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile( group_description = yield self.groups_handler.get_group_profile(
@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary( get_group_summary = yield self.groups_handler.get_group_summary(
@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id, category_id): def on_GET(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category( category = yield self.groups_handler.get_group_category(
@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories( category = yield self.groups_handler.get_group_categories(
@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id, role_id): def on_GET(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role( category = yield self.groups_handler.get_group_role(
@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles( category = yield self.groups_handler.get_group_roles(
@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, group_id): def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
result = yield self.groups_handler.get_publicised_groups_for_user( result = yield self.groups_handler.get_publicised_groups_for_user(
user_id user_id
@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
user_ids = content["user_ids"] user_ids = content["user_ids"]
@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(requester_user_id) result = yield self.groups_handler.get_joined_groups(requester_user_id)

View File

@ -27,7 +27,7 @@ from synapse.http.servlet import (
) )
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from ._base import client_v2_patterns from ._base import client_v2_patterns, interactive_auth_handler
import logging import logging
import hmac import hmac
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY], [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
]) ])
authed, auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
if not authed:
defer.returnValue((401, auth_result))
return
if registered_user_id is not None: if registered_user_id is not None:
logger.info( logger.info(
"Already registered user ID %r for this session", "Already registered user ID %r for this session",

View File

@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
"r0.0.1", "r0.0.1",
"r0.1.0", "r0.1.0",
"r0.2.0", "r0.2.0",
"r0.3.0",
] ]
}) })

View File

@ -25,7 +25,8 @@ from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient from synapse.http.client import SpiderHttpClient
from synapse.http.server import ( from synapse.http.server import (
request_handler, respond_with_json_bytes request_handler, respond_with_json_bytes,
respond_with_json,
) )
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
@ -78,6 +79,9 @@ class PreviewUrlResource(Resource):
self._expire_url_cache_data, 10 * 1000 self._expire_url_cache_data, 10 * 1000
) )
def render_OPTIONS(self, request):
return respond_with_json(request, 200, {}, send_cors=True)
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -348,11 +352,16 @@ class PreviewUrlResource(Resource):
def _expire_url_cache_data(self): def _expire_url_cache_data(self):
"""Clean up expired url cache content, media and thumbnails. """Clean up expired url cache content, media and thumbnails.
""" """
# TODO: Delete from backup media store # TODO: Delete from backup media store
now = self.clock.time_msec() now = self.clock.time_msec()
logger.info("Running url preview cache expiry")
if not (yield self.store.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
# First we delete expired url cache entries # First we delete expired url cache entries
media_ids = yield self.store.get_expired_url_cache(now) media_ids = yield self.store.get_expired_url_cache(now)
@ -426,8 +435,7 @@ class PreviewUrlResource(Resource):
yield self.store.delete_url_cache_media(removed_media) yield self.store.delete_url_cache_media(removed_media)
if removed_media: logger.info("Deleted %d media from url cache", len(removed_media))
logger.info("Deleted %d media from url cache", len(removed_media))
def decode_and_calc_og(body, media_uri, request_encoding=None): def decode_and_calc_og(body, media_uri, request_encoding=None):

View File

@ -39,18 +39,20 @@ from synapse.federation.transaction_queue import TransactionQueue
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.devicemessage import DeviceMessageHandler from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.sync import SyncHandler from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.receipts import ReceiptsHandler from synapse.handlers.receipts import ReceiptsHandler
from synapse.handlers.read_marker import ReadMarkerHandler from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoyHandler from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
@ -60,7 +62,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
@ -90,17 +95,12 @@ class HomeServer(object):
""" """
DEPENDENCIES = [ DEPENDENCIES = [
'config',
'clock',
'http_client', 'http_client',
'db_pool', 'db_pool',
'persistence_service',
'replication_layer', 'replication_layer',
'datastore',
'handlers', 'handlers',
'v1auth', 'v1auth',
'auth', 'auth',
'rest_servlet_factory',
'state_handler', 'state_handler',
'presence_handler', 'presence_handler',
'sync_handler', 'sync_handler',
@ -117,19 +117,10 @@ class HomeServer(object):
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler', 'profile_handler',
'deactivate_account_handler',
'set_password_handler',
'notifier', 'notifier',
'distributor',
'client_resource',
'resource_for_federation',
'resource_for_static_content',
'resource_for_web_client',
'resource_for_content_repo',
'resource_for_server_key',
'resource_for_server_key_v2',
'resource_for_media_repository',
'resource_for_metrics',
'event_sources', 'event_sources',
'ratelimiter',
'keyring', 'keyring',
'pusherpool', 'pusherpool',
'event_builder_factory', 'event_builder_factory',
@ -137,6 +128,7 @@ class HomeServer(object):
'http_client_context_factory', 'http_client_context_factory',
'simple_http_client', 'simple_http_client',
'media_repository', 'media_repository',
'media_repository_resource',
'federation_transport_client', 'federation_transport_client',
'federation_sender', 'federation_sender',
'receipts_handler', 'receipts_handler',
@ -183,6 +175,21 @@ class HomeServer(object):
def is_mine_id(self, string): def is_mine_id(self, string):
return string.split(":", 1)[1] == self.hostname return string.split(":", 1)[1] == self.hostname
def get_clock(self):
return self.clock
def get_datastore(self):
return self.datastore
def get_config(self):
return self.config
def get_distributor(self):
return self.distributor
def get_ratelimiter(self):
return self.ratelimiter
def build_replication_layer(self): def build_replication_layer(self):
return initialize_http_replication(self) return initialize_http_replication(self)
@ -265,6 +272,12 @@ class HomeServer(object):
def build_profile_handler(self): def build_profile_handler(self):
return ProfileHandler(self) return ProfileHandler(self)
def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self)
def build_set_password_handler(self):
return SetPasswordHandler(self)
def build_event_sources(self): def build_event_sources(self):
return EventSources(self) return EventSources(self)
@ -294,6 +307,11 @@ class HomeServer(object):
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
return MediaRepositoryResource(self)
def build_media_repository(self): def build_media_repository(self):
return MediaRepository(self) return MediaRepository(self)
@ -321,7 +339,7 @@ class HomeServer(object):
return ActionGenerator(self) return ActionGenerator(self)
def build_user_directory_handler(self): def build_user_directory_handler(self):
return UserDirectoyHandler(self) return UserDirectoryHandler(self)
def build_groups_local_handler(self): def build_groups_local_handler(self):
return GroupsLocalHandler(self) return GroupsLocalHandler(self)

View File

@ -3,10 +3,14 @@ import synapse.federation.transaction_queue
import synapse.federation.transport.client import synapse.federation.transport.client
import synapse.handlers import synapse.handlers
import synapse.handlers.auth import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device import synapse.handlers.device
import synapse.handlers.e2e_keys import synapse.handlers.e2e_keys
import synapse.storage import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
import synapse.state import synapse.state
import synapse.storage
class HomeServer(object): class HomeServer(object):
def get_auth(self) -> synapse.api.auth.Auth: def get_auth(self) -> synapse.api.auth.Auth:
@ -30,8 +34,20 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler: def get_state_handler(self) -> synapse.state.StateHandler:
pass pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
pass pass
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
pass pass
def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass

View File

@ -16,8 +16,6 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -180,10 +178,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0
@ -475,23 +469,53 @@ class SQLBaseStore(object):
txn.executemany(sql, vals) txn.executemany(sql, vals)
@defer.inlineCallbacks
def _simple_upsert(self, table, keyvalues, values, def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True): insertion_values={}, desc="_simple_upsert", lock=True):
""" """
`lock` should generally be set to True (the default), but can be set
to False if either of the following are true:
* there is a UNIQUE INDEX on the key columns. In this case a conflict
will cause an IntegrityError in which case this function will retry
the update.
* we somehow know that we are the only thread which will be updating
this table.
Args: Args:
table (str): The table to upsert into table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values values (dict): The nonunique columns and their new values
insertion_values (dict): key/values to use when inserting insertion_values (dict): additional key/values to use only when
inserting
lock (bool): True to lock the table when doing the upsert.
Returns: Returns:
Deferred(bool): True if a new entry was created, False if an Deferred(bool): True if a new entry was created, False if an
existing one was updated. existing one was updated.
""" """
return self.runInteraction( attempts = 0
desc, while True:
self._simple_upsert_txn, table, keyvalues, values, insertion_values, try:
lock result = yield self.runInteraction(
) desc,
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
lock=lock
)
defer.returnValue(result)
except self.database_engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
# don't retry forever, because things other than races
# can cause IntegrityErrors
raise
# presumably we raced with another transaction: let's retry.
logger.warn(
"IntegrityError when upserting into %s; retrying: %s",
table, e
)
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
lock=True): lock=True):
@ -499,7 +523,7 @@ class SQLBaseStore(object):
if lock: if lock:
self.database_engine.lock_table(txn, table) self.database_engine.lock_table(txn, table)
# Try to update # First try to update.
sql = "UPDATE %s SET %s WHERE %s" % ( sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k,) for k in values), ", ".join("%s = ?" % (k,) for k in values),
@ -508,28 +532,29 @@ class SQLBaseStore(object):
sqlargs = values.values() + keyvalues.values() sqlargs = values.values() + keyvalues.values()
txn.execute(sql, sqlargs) txn.execute(sql, sqlargs)
if txn.rowcount == 0: if txn.rowcount > 0:
# We didn't update and rows so insert a new one # successfully updated at least one row.
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
txn.execute(sql, allvalues.values())
return True
else:
return False return False
# We didn't update any rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
allvalues.update(insertion_values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
txn.execute(sql, allvalues.values())
# successfully inserted
return True
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"): allow_none=False, desc="_simple_select_one"):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it. return a single row, returning multiple columns from it.
Args: Args:
table : string giving the table name table : string giving the table name
@ -582,20 +607,18 @@ class SQLBaseStore(object):
@staticmethod @staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol): def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
else:
where = ""
sql = ( sql = (
"SELECT %(retcol)s FROM %(table)s %(where)s" "SELECT %(retcol)s FROM %(table)s"
) % { ) % {
"retcol": retcol, "retcol": retcol,
"table": table, "table": table,
"where": where,
} }
txn.execute(sql, keyvalues.values()) if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
txn.execute(sql, keyvalues.values())
else:
txn.execute(sql)
return [r[0] for r in txn] return [r[0] for r in txn]
@ -606,7 +629,7 @@ class SQLBaseStore(object):
Args: Args:
table (str): table name table (str): table name
keyvalues (dict): column names and values to select the rows with keyvalues (dict|None): column names and values to select the rows with
retcol (str): column whos value we wish to retrieve. retcol (str): column whos value we wish to retrieve.
Returns: Returns:

View File

@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore):
""" """
content_json = json.dumps(content) content_json = json.dumps(content)
def add_account_data_txn(txn, next_id): with self._account_data_id_gen.get_next() as next_id:
self._simple_upsert_txn( # no need to lock here as room_account_data has a unique constraint
txn, # on (user_id, room_id, account_data_type) so _simple_upsert will
# retry if there is a conflict.
yield self._simple_upsert(
desc="add_room_account_data",
table="room_account_data", table="room_account_data",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore):
values={ values={
"stream_id": next_id, "stream_id": next_id,
"content": content_json, "content": content_json,
} },
lock=False,
) )
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id,
)
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id: # it's theoretically possible for the above to succeed and the
yield self.runInteraction( # below to fail - in which case we might reuse a stream id on
"add_room_account_data", add_account_data_txn, next_id # restart, and the above update might not get propagated. That
) # doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore):
""" """
content_json = json.dumps(content) content_json = json.dumps(content)
def add_account_data_txn(txn, next_id): with self._account_data_id_gen.get_next() as next_id:
self._simple_upsert_txn( # no need to lock here as account_data has a unique constraint on
txn, # (user_id, account_data_type) so _simple_upsert will retry if
# there is a conflict.
yield self._simple_upsert(
desc="add_user_account_data",
table="account_data", table="account_data",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore):
values={ values={
"stream_id": next_id, "stream_id": next_id,
"content": content_json, "content": content_json,
} },
lock=False,
) )
txn.call_after(
self._account_data_stream_cache.entity_has_changed, # it's theoretically possible for the above to succeed and the
# below to fail - in which case we might reuse a stream id on
# restart, and the above update might not get propagated. That
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(
user_id, next_id, user_id, next_id,
) )
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) self.get_account_data_for_user.invalidate((user_id,))
txn.call_after( self.get_global_account_data_by_type_for_user.invalidate(
self.get_global_account_data_by_type_for_user.invalidate,
(account_data_type, user_id,) (account_data_type, user_id,)
) )
self._update_max_stream_id(txn, next_id)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_current_token() result = self._account_data_id_gen.get_current_token()
defer.returnValue(result) defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id): def _update_max_stream_id(self, next_id):
"""Update the max stream_id """Update the max stream_id
Args: Args:
txn: The database cursor
next_id(int): The the revision to advance to. next_id(int): The the revision to advance to.
""" """
update_max_id_sql = ( def _update(txn):
"UPDATE account_data_max_stream_id" update_max_id_sql = (
" SET stream_id = ?" "UPDATE account_data_max_stream_id"
" WHERE stream_id < ?" " SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction(
"update_account_data_max_stream_id",
_update,
) )
txn.execute(update_max_id_sql, (next_id, next_id))
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):

View File

@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_performance = {} self._background_update_performance = {}
self._background_update_queue = [] self._background_update_queue = []
self._background_update_handlers = {} self._background_update_handlers = {}
self._all_done = False
@defer.inlineCallbacks @defer.inlineCallbacks
def start_doing_background_updates(self): def start_doing_background_updates(self):
@ -106,8 +107,40 @@ class BackgroundUpdateStore(SQLBaseStore):
"No more background updates to do." "No more background updates to do."
" Unscheduling background update task." " Unscheduling background update task."
) )
self._all_done = True
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def has_completed_background_updates(self):
"""Check if all the background updates have completed
Returns:
Deferred[bool]: True if all background updates have completed
"""
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
defer.returnValue(True)
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
defer.returnValue(False)
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
updates = yield self._simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
desc="check_background_updates",
)
if not updates:
self._all_done = True
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms): def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update """Does some amount of work on the next queued background update
@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# Sqlite doesn't support concurrent creation of indexes. # Sqlite doesn't support concurrent creation of indexes.
# #
# We don't use partial indices on SQLite as it wasn't introduced # We don't use partial indices on SQLite as it wasn't introduced
# until 3.8, and wheezy has 3.7 # until 3.8, and wheezy and CentOS 7 have 3.7
# #
# We assume that sqlite doesn't give us invalid indices; however # We assume that sqlite doesn't give us invalid indices; however
# we may still end up with the index existing but the # we may still end up with the index existing but the

View File

@ -12,13 +12,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.background_updates import BackgroundUpdateStore
from ._base import SQLBaseStore
class MediaRepositoryStore(SQLBaseStore): class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars""" """Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
self.register_background_index_update(
update_name='local_media_repository_url_idx',
index_name='local_media_repository_url_idx',
table='local_media_repository',
columns=['created_ts'],
where_clause='url_cache IS NOT NULL',
)
def get_default_thumbnails(self, top_level_type, sub_type): def get_default_thumbnails(self, top_level_type, sub_type):
return [] return []

View File

@ -15,6 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.storage.roommember import ProfileInfo
from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
desc="create_profile", desc="create_profile",
) )
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
profile = yield self._simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
defer.returnValue(ProfileInfo(None, None))
return
else:
raise
defer.returnValue(
ProfileInfo(
avatar_url=profile['avatar_url'],
display_name=profile['displayname'],
)
)
def get_profile_displayname(self, user_localpart): def get_profile_displayname(self, user_localpart):
return self._simple_select_one_onecol( return self._simple_select_one_onecol(
table="profiles", table="profiles",

View File

@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore):
pushkey, pushkey_ts, lang, data, last_stream_ordering, pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""): profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
def f(txn): # no need to lock because `pushers` has a unique key on
newly_inserted = self._simple_upsert_txn( # (app_id, pushkey, user_name) so _simple_upsert will retry
txn, newly_inserted = yield self._simple_upsert(
"pushers", table="pushers",
{ keyvalues={
"app_id": app_id, "app_id": app_id,
"pushkey": pushkey, "pushkey": pushkey,
"user_name": user_id, "user_name": user_id,
}, },
{ values={
"access_token": access_token, "access_token": access_token,
"kind": kind, "kind": kind,
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"ts": pushkey_ts, "ts": pushkey_ts,
"lang": lang, "lang": lang,
"data": encode_canonical_json(data), "data": encode_canonical_json(data),
"last_stream_ordering": last_stream_ordering, "last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,
}, },
) desc="add_pusher",
if newly_inserted: lock=False,
# get_if_user_has_pusher only cares if the user has )
# at least *one* pusher.
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
yield self.runInteraction("add_pusher", f) if newly_inserted:
# get_if_user_has_pusher only cares if the user has
# at least *one* pusher.
self.get_if_user_has_pusher.invalidate(user_id,)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore):
"pushers", "pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id} {"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
) )
self._simple_upsert_txn(
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
self._simple_insert_txn(
txn, txn,
"deleted_pushers", table="deleted_pushers",
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, values={
{"stream_id": stream_id}, "stream_id": stream_id,
"app_id": app_id,
"pushkey": pushkey,
"user_id": user_id,
},
) )
with self._pushers_id_gen.get_next() as stream_id: with self._pushers_id_gen.get_next() as stream_id:
@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params): def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so _simple_upsert will retry
yield self._simple_upsert( yield self._simple_upsert(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id, "room_id": room_id}, {"pusher": pusher_id, "room_id": room_id},
params, params,
desc="set_throttle_params" desc="set_throttle_params",
lock=False,
) )

View File

@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
be deleted be deleted
Returns: Returns:
defer.Deferred[list[str, str|None]]: a list of the deleted tokens defer.Deferred[list[str, int, str|None, int]]: a list of
and device IDs (token, token id, device id) for each of the deleted tokens
""" """
def f(txn): def f(txn):
keyvalues = { keyvalues = {
@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
values.append(except_token_id) values.append(except_token_id)
txn.execute( txn.execute(
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause, "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values values
) )
tokens_and_devices = [(r[0], r[1]) for r in txn] tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
for token, _ in tokens_and_devices: for token, _, _ in tokens_and_devices:
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (token,) txn, self.get_user_by_access_token, (token,)
) )

View File

@ -13,7 +13,10 @@
* limitations under the License. * limitations under the License.
*/ */
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; -- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
-- removed and replaced with 46/local_media_repository_url_idx.sql.
--
-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support -- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
-- indices on expressions until 3.9. -- indices on expressions until 3.9.

View File

@ -0,0 +1,35 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- drop the unique constraint on deleted_pushers so that we can just insert
-- into it rather than upserting.
CREATE TABLE deleted_pushers2 (
stream_id BIGINT NOT NULL,
app_id TEXT NOT NULL,
pushkey TEXT NOT NULL,
user_id TEXT NOT NULL
);
INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
DROP TABLE deleted_pushers;
ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
-- create the index after doing the inserts because that's more efficient.
-- it also means we can give it the same name as the old one without renaming.
CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);

View File

@ -0,0 +1,24 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- register a background update which will recreate the
-- local_media_repository_url_idx index.
--
-- We do this as a bg update not because it is a particularly onerous
-- operation, but because we'd like it to be a partial index if possible, and
-- the background_index_update code will understand whether we are on
-- postgres or sqlite and behave accordingly.
INSERT INTO background_updates (update_name, progress_json) VALUES
('local_media_repository_url_idx', '{}');

View File

@ -0,0 +1,35 @@
/* Copyright 2017 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- change the user_directory table to also cover global local user profiles
-- rather than just profiles within specific rooms.
CREATE TABLE user_directory2 (
user_id TEXT NOT NULL,
room_id TEXT,
display_name TEXT,
avatar_url TEXT
);
INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
SELECT user_id, room_id, display_name, avatar_url from user_directory;
DROP TABLE user_directory;
ALTER TABLE user_directory2 RENAME TO user_directory;
-- create indexes after doing the inserts because that's more efficient.
-- it also means we can give it the same name as the old one without renaming.
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);

View File

@ -13,16 +13,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from collections import namedtuple
from synapse.util.caches.descriptors import cached, cachedList import logging
from synapse.util.caches import intern_string
from synapse.util.stringutils import to_ascii
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
import logging from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import PostgresEngine
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.stringutils import to_ascii
from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0 return len(self.delta_ids) if self.delta_ids else 0
class StateStore(SQLBaseStore): class StateGroupReadStore(SQLBaseStore):
""" Keeps track of the state at a given event. """The read-only parts of StateGroupStore
This is done by the concept of `state groups`. Every event is a assigned None of these functions write to the state tables, so are suitable for
a state group (identified by an arbitrary string), which references a including in the SlavedStores.
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
""" """
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@ -64,21 +54,10 @@ class StateStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs) super(StateGroupReadStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, self._state_group_cache = DictionaryCache(
self._background_deduplicate_state, "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
) )
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
@ -190,178 +169,6 @@ class StateStore(SQLBaseStore):
for group, event_id_map in group_to_ids.iteritems() for group, event_id_map in group_to_ids.iteritems()
}) })
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id) """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
@ -742,6 +549,220 @@ class StateStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
for event_id, state_group_id in state_groups.iteritems()
],
)
for event_id, state_group_id in state_groups.iteritems():
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def get_next_state_group(self): def get_next_state_group(self):
return self._state_groups_id_gen.get_next() return self._state_groups_id_gen.get_next()

View File

@ -39,7 +39,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging import logging
@ -234,7 +234,7 @@ class StreamStore(SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
res = yield preserve_context_over_deferred(defer.gatherResults([ res = yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(self.get_room_events_stream_for_room)( preserve_fn(self.get_room_events_stream_for_room)(
room_id, from_key, to_key, limit, order=order, room_id, from_key, to_key, limit, order=order,
) )

View File

@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
) )
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# We weight the loclpart most highly, then display name and finally # We weight the localpart most highly, then display name and finally
# server name # server name
if new_entry: if new_entry:
sql = """ sql = """
@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
rows = yield self._execute("get_all_rooms", None, sql) rows = yield self._execute("get_all_rooms", None, sql)
defer.returnValue([room_id for room_id, in rows]) defer.returnValue([room_id for room_id, in rows])
@defer.inlineCallbacks
def get_all_local_users(self):
"""Get all local users
"""
sql = """
SELECT name FROM users
"""
rows = yield self._execute("get_all_local_users", None, sql)
defer.returnValue([name for name, in rows])
def add_users_who_share_room(self, room_id, share_private, user_id_tuples): def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
"""Insert entries into the users_who_share_rooms table. The first """Insert entries into the users_who_share_rooms table. The first
user should be a local user. user should be a local user.
@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore):
] ]
} }
""" """
if self.hs.config.user_directory_search_all_users:
join_clause = ""
where_clause = "?<>''" # naughty hack to keep the same number of binds
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term) full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url SELECT d.user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id) %s
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL) %s
AND vector @@ to_tsquery('english', ?) AND vector @@ to_tsquery('english', ?)
ORDER BY ORDER BY
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore):
display_name IS NULL, display_name IS NULL,
avatar_url IS NULL avatar_url IS NULL
LIMIT ? LIMIT ?
""" """ % (join_clause, where_clause)
args = (user_id, full_query, exact_query, prefix_query, limit + 1,) args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term) search_query = _parse_query_sqlite(search_term)
@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore):
SELECT d.user_id, display_name, avatar_url SELECT d.user_id, display_name, avatar_url
FROM user_directory_search FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id) INNER JOIN user_directory AS d USING (user_id)
LEFT JOIN users_in_public_rooms AS p USING (user_id) %s
LEFT JOIN (
SELECT other_user_id AS user_id FROM users_who_share_rooms
WHERE user_id = ? AND share_private
) AS s USING (user_id)
WHERE WHERE
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL) %s
AND value MATCH ? AND value MATCH ?
ORDER BY ORDER BY
rank(matchinfo(user_directory_search)) DESC, rank(matchinfo(user_directory_search)) DESC,
display_name IS NULL, display_name IS NULL,
avatar_url IS NULL avatar_url IS NULL
LIMIT ? LIMIT ?
""" """ % (join_clause, where_clause)
args = (user_id, search_query, limit + 1) args = (user_id, search_query, limit + 1)
else: else:
# This should be unreachable. # This should be unreachable.
@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term):
# Pull out the individual words, discarding any non-word characters. # Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
return " & ".join("(%s* | %s)" % (result, result,) for result in results) return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
def _parse_query_postgres(search_term): def _parse_query_postgres(search_term):

View File

@ -17,7 +17,7 @@
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from .logcontext import ( from .logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, PreserveLoggingContext, make_deferred_yieldable, preserve_fn
) )
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
@ -351,7 +351,7 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore # We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers. # any existing readers... as they're readers.
yield curr_writer yield make_deferred_yieldable(curr_writer)
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():
@ -380,7 +380,7 @@ class ReadWriteLock(object):
curr_readers.clear() curr_readers.clear()
self.key_to_current_writer[key] = new_defer self.key_to_current_writer[key] = new_defer
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on)) yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager @contextmanager
def _ctx_manager(): def _ctx_manager():

View File

@ -13,32 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_fn
)
from synapse.util import unwrapFirstError
import logging import logging
from twisted.internet import defer
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_left_room(distributor, user, room_id): def user_left_room(distributor, user, room_id):
return preserve_context_over_fn( with PreserveLoggingContext():
distributor.fire, distributor.fire("user_left_room", user=user, room_id=room_id)
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id): def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn( with PreserveLoggingContext():
distributor.fire, distributor.fire("user_joined_room", user=user, room_id=room_id)
"user_joined_room", user=user, room_id=room_id
)
class Distributor(object): class Distributor(object):

View File

@ -261,67 +261,6 @@ class PreserveLoggingContext(object):
) )
class _PreservingContextDeferred(defer.Deferred):
"""A deferred that ensures that all callbacks and errbacks are called with
the given logging context.
"""
def __init__(self, context):
self._log_context = context
defer.Deferred.__init__(self)
def addCallbacks(self, callback, errback=None,
callbackArgs=None, callbackKeywords=None,
errbackArgs=None, errbackKeywords=None):
callback = self._wrap_callback(callback)
errback = self._wrap_callback(errback)
return defer.Deferred.addCallbacks(
self, callback,
errback=errback,
callbackArgs=callbackArgs,
callbackKeywords=callbackKeywords,
errbackArgs=errbackArgs,
errbackKeywords=errbackKeywords,
)
def _wrap_callback(self, f):
def g(res, *args, **kwargs):
with PreserveLoggingContext(self._log_context):
res = f(res, *args, **kwargs)
return res
return g
def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so.
If the result is a deferred, call preserve_context_over_deferred before
returning it.
"""
with PreserveLoggingContext():
res = fn(*args, **kwargs)
if isinstance(res, defer.Deferred):
return preserve_context_over_deferred(res)
else:
return res
def preserve_context_over_deferred(deferred, context=None):
"""Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context.
Deprecated: this almost certainly doesn't do want you want, ie make
the deferred follow the synapse logcontext rules: try
``make_deferred_yieldable`` instead.
"""
if context is None:
context = LoggingContext.current_context()
d = _PreservingContextDeferred(context)
deferred.chainDeferred(d)
return d
def preserve_fn(f): def preserve_fn(f):
"""Wraps a function, to ensure that the current context is restored after """Wraps a function, to ensure that the current context is restored after
return from the function, and that the sentinel context is set once the return from the function, and that the sentinel context is set once the

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import logging import logging
@ -58,7 +58,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
always_include_ids (set(event_id)): set of event ids to specifically always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored) include (unless sender is ignored)
""" """
forgotten = yield preserve_context_over_deferred(defer.gatherResults([ forgotten = yield make_deferred_yieldable(defer.gatherResults([
defer.maybeDeferred( defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room), preserve_fn(store.who_forgot_in_room),
room_id, room_id,

View File

@ -36,6 +36,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
id="unique_identifier", id="unique_identifier",
url="some_url", url="some_url",
token="some_token", token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
namespaces={ namespaces={
ApplicationService.NS_USERS: [], ApplicationService.NS_USERS: [],
ApplicationService.NS_ROOMS: [], ApplicationService.NS_ROOMS: [],

View File

@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock(spec=["on_new_event"]) mock_notifier = Mock()
self.on_new_event = mock_notifier.on_new_event self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
@ -76,6 +76,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote", "get_devices_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=None,
@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
return set(str(u) for u in self.room_members) return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room self.state_handler.get_current_user_in_room = get_current_user_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
)
self.datastore.get_current_state_deltas.return_value = (
None
)
self.auth.check_joined_room = check_joined_room self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_to_device_stream_token = lambda: 0

View File

@ -1,5 +1,7 @@
from twisted.python import failure
from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock
from tests import unittest from tests import unittest
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
side_effect=lambda x: self.appservice) side_effect=lambda x: self.appservice)
) )
self.auth_result = (False, None, None, None) self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock( self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None) get_session_data=Mock(return_value=None)
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.request.args = { self.request.args = {
"access_token": "i_am_an_app_service" "access_token": "i_am_an_app_service"
} }
self.request_data = json.dumps({ self.request_data = json.dumps({
"username": "kermit" "username": "kermit"
}) })
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"device_id": device_id, "device_id": device_id,
}) })
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, { self.auth_result = (None, {
"username": "kermit", "username": "kermit",
"password": "monkey" "password": "monkey"
}, None) }, None)
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}) })
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, { self.auth_result = (None, {
"username": "kermit", "username": "kermit",
"password": "monkey" "password": "monkey"
}, None) }, None)