mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-18 16:44:17 -05:00
Merge branch 'develop' of github.com:matrix-org/synapse into mysql
This commit is contained in:
commit
f6583796fe
@ -1,3 +1,12 @@
|
||||
Changes in synapse v0.8.1 (2015-03-18)
|
||||
======================================
|
||||
|
||||
* Disable registration by default. New users can be added using the command
|
||||
``register_new_matrix_user`` or by enabling registration in the config.
|
||||
* Add metrics to synapse. To enable metrics use config options
|
||||
``enable_metrics`` and ``metrics_port``.
|
||||
* Fix bug where banning only kicked the user.
|
||||
|
||||
Changes in synapse v0.8.0 (2015-03-06)
|
||||
======================================
|
||||
|
||||
|
11
README.rst
11
README.rst
@ -128,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
|
||||
|
||||
Substituting your host and domain name as appropriate.
|
||||
|
||||
By default, registration of new users is disabled. You can either enable
|
||||
registration in the config (it is then recommended to also set up CAPTCHA), or
|
||||
you can use the command line to register new users::
|
||||
|
||||
$ source ~/.synapse/bin/activate
|
||||
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
|
||||
New user localpart: erikj
|
||||
Password:
|
||||
Confirm password:
|
||||
Success!
|
||||
|
||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||
a TURN server. See docs/turn-howto.rst for details.
|
||||
|
||||
|
149
register_new_matrix_user
Executable file
149
register_new_matrix_user
Executable file
@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import getpass
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
import urllib2
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret):
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
msg=user,
|
||||
digestmod=hashlib.sha1,
|
||||
).hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
|
||||
print "Sending registration request..."
|
||||
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||
data=json.dumps(data),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
try:
|
||||
f = urllib2.urlopen(req)
|
||||
f.read()
|
||||
f.close()
|
||||
print "Success."
|
||||
except urllib2.HTTPError as e:
|
||||
print "ERROR! Received %d %s" % (e.code, e.reason,)
|
||||
if 400 <= e.code < 500:
|
||||
if e.info().type == "application/json":
|
||||
resp = json.load(e)
|
||||
if "error" in resp:
|
||||
print resp["error"]
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def register_new_user(user, password, server_location, shared_secret):
|
||||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
except:
|
||||
default_user = None
|
||||
|
||||
if default_user:
|
||||
user = raw_input("New user localpart [%s]: " % (default_user,))
|
||||
if not user:
|
||||
user = default_user
|
||||
else:
|
||||
user = raw_input("New user localpart: ")
|
||||
|
||||
if not user:
|
||||
print "Invalid user name"
|
||||
sys.exit(1)
|
||||
|
||||
if not password:
|
||||
password = getpass.getpass("Password: ")
|
||||
|
||||
if not password:
|
||||
print "Password cannot be blank."
|
||||
sys.exit(1)
|
||||
|
||||
confirm_password = getpass.getpass("Confirm password: ")
|
||||
|
||||
if password != confirm_password:
|
||||
print "Passwords do not match"
|
||||
sys.exit(1)
|
||||
|
||||
request_registration(user, password, server_location, shared_secret)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Used to register new users with a given home server when"
|
||||
" registration has been disabled. The home server must be"
|
||||
" configured with the 'registration_shared_secret' option"
|
||||
" set.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-u", "--user",
|
||||
default=None,
|
||||
help="Local part of the new user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--password",
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"-c", "--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in shared secret.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"-k", "--shared-secret",
|
||||
help="Shared secret as defined in server config file.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"server_url",
|
||||
default="https://localhost:8448",
|
||||
nargs='?',
|
||||
help="URL to use to talk to the home server. Defaults to "
|
||||
" 'https://localhost:8448'.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if "config" in args and args.config:
|
||||
config = yaml.safe_load(args.config)
|
||||
secret = config.get("registration_shared_secret", None)
|
||||
if not secret:
|
||||
print "No 'registration_shared_secret' defined in config."
|
||||
sys.exit(1)
|
||||
else:
|
||||
secret = args.shared_secret
|
||||
|
||||
register_new_user(args.user, args.password, args.server_url, secret)
|
4
setup.py
4
setup.py
@ -45,7 +45,7 @@ setup(
|
||||
version=version,
|
||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||
description="Reference Synapse Home Server",
|
||||
install_requires=dependencies["REQUIREMENTS"].keys(),
|
||||
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||
setup_requires=[
|
||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||
"setuptools_trial",
|
||||
@ -55,5 +55,5 @@ setup(
|
||||
include_package_data=True,
|
||||
zip_safe=False,
|
||||
long_description=long_description,
|
||||
scripts=["synctl"],
|
||||
scripts=["synctl", "register_new_matrix_user"],
|
||||
)
|
||||
|
@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.8.0"
|
||||
__version__ = "0.8.1-r2"
|
||||
|
@ -388,7 +388,7 @@ class Auth(object):
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.store.get_user_by_token(token=token)
|
||||
ret = yield self.store.get_user_by_token(token)
|
||||
if not ret:
|
||||
raise StoreError(400, "Unknown token")
|
||||
user_info = {
|
||||
|
@ -60,6 +60,7 @@ class LoginType(object):
|
||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
APPLICATION_SERVICE = u"m.login.application_service"
|
||||
SHARED_SECRET = u"org.matrix.login.shared_secret"
|
||||
|
||||
|
||||
class EventTypes(object):
|
||||
|
@ -60,9 +60,9 @@ import re
|
||||
import resource
|
||||
import subprocess
|
||||
import sqlite3
|
||||
import syweb
|
||||
import yaml
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -84,6 +84,7 @@ class SynapseHomeServer(HomeServer):
|
||||
return AppServiceRestResource(self)
|
||||
|
||||
def build_resource_for_web_client(self):
|
||||
import syweb
|
||||
syweb_path = os.path.dirname(syweb.__file__)
|
||||
webclient_path = os.path.join(syweb_path, "webclient")
|
||||
return File(webclient_path) # TODO configurable?
|
||||
@ -131,7 +132,7 @@ class SynapseHomeServer(HomeServer):
|
||||
True.
|
||||
"""
|
||||
config = self.get_config()
|
||||
web_client = config.webclient
|
||||
web_client = config.web_client
|
||||
|
||||
# list containing (path_str, Resource) e.g:
|
||||
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
|
||||
@ -344,7 +345,8 @@ def setup(config_options):
|
||||
|
||||
config.setup_logging()
|
||||
|
||||
check_requirements()
|
||||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
||||
version_string = get_version_string()
|
||||
|
||||
@ -472,6 +474,7 @@ def run(hs):
|
||||
|
||||
def main():
|
||||
with LoggingContext("main"):
|
||||
# check base requirements
|
||||
check_requirements()
|
||||
hs = setup(sys.argv[1:])
|
||||
run(hs)
|
||||
|
@ -15,19 +15,46 @@
|
||||
|
||||
from ._base import Config
|
||||
|
||||
from synapse.util.stringutils import random_string_with_symbols
|
||||
|
||||
import distutils.util
|
||||
|
||||
|
||||
class RegistrationConfig(Config):
|
||||
|
||||
def __init__(self, args):
|
||||
super(RegistrationConfig, self).__init__(args)
|
||||
self.disable_registration = args.disable_registration
|
||||
|
||||
# `args.disable_registration` may either be a bool or a string depending
|
||||
# on if the option was given a value (e.g. --disable-registration=false
|
||||
# would set `args.disable_registration` to "false" not False.)
|
||||
self.disable_registration = bool(
|
||||
distutils.util.strtobool(str(args.disable_registration))
|
||||
)
|
||||
self.registration_shared_secret = args.registration_shared_secret
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(RegistrationConfig, cls).add_arguments(parser)
|
||||
reg_group = parser.add_argument_group("registration")
|
||||
|
||||
reg_group.add_argument(
|
||||
"--disable-registration",
|
||||
action='store_true',
|
||||
help="Disable registration of new users."
|
||||
const=True,
|
||||
default=True,
|
||||
nargs='?',
|
||||
help="Disable registration of new users.",
|
||||
)
|
||||
reg_group.add_argument(
|
||||
"--registration-shared-secret", type=str,
|
||||
help="If set, allows registration by anyone who also has the shared"
|
||||
" secret, even if registration is otherwise disabled.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
if args.disable_registration is None:
|
||||
args.disable_registration = True
|
||||
|
||||
if args.registration_shared_secret is None:
|
||||
args.registration_shared_secret = random_string_with_symbols(50)
|
||||
|
@ -28,7 +28,7 @@ class ServerConfig(Config):
|
||||
self.unsecure_port = args.unsecure_port
|
||||
self.daemonize = args.daemonize
|
||||
self.pid_file = self.abspath(args.pid_file)
|
||||
self.webclient = True
|
||||
self.web_client = args.web_client
|
||||
self.manhole = args.manhole
|
||||
self.soft_file_limit = args.soft_file_limit
|
||||
|
||||
@ -68,6 +68,8 @@ class ServerConfig(Config):
|
||||
server_group.add_argument('--pid-file', default="homeserver.pid",
|
||||
help="When running as a daemon, the file to"
|
||||
" store the pid in")
|
||||
server_group.add_argument('--web_client', default=True, type=bool,
|
||||
help="Whether or not to serve a web client")
|
||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
|
@ -361,4 +361,5 @@ SERVLET_CLASSES = (
|
||||
FederationInviteServlet,
|
||||
FederationQueryAuthServlet,
|
||||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
)
|
||||
|
@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
logger.debug("Joining %s to %s", joinee, room_id)
|
||||
|
||||
yield self.store.clean_room_for_join(room_id)
|
||||
|
||||
origin, pdu = yield self.replication_layer.make_join(
|
||||
target_hosts,
|
||||
room_id,
|
||||
|
@ -31,6 +31,7 @@ import base64
|
||||
import bcrypt
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
|
||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||
|
||||
if localpart:
|
||||
if localpart and urllib.quote(localpart) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
|
@ -51,8 +51,8 @@ class RestServlet(object):
|
||||
pattern = self.PATTERN
|
||||
|
||||
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
|
||||
if hasattr(self, "on_%s" % (method)):
|
||||
method_handler = getattr(self, "on_%s" % (method))
|
||||
if hasattr(self, "on_%s" % (method,)):
|
||||
method_handler = getattr(self, "on_%s" % (method,))
|
||||
http_server.register_path(method, pattern, method_handler)
|
||||
else:
|
||||
raise NotImplementedError("RestServlet must register something.")
|
||||
|
@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
REQUIREMENTS = {
|
||||
"syutil>=0.0.3": ["syutil"],
|
||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||
@ -18,6 +17,19 @@ REQUIREMENTS = {
|
||||
"pillow": ["PIL"],
|
||||
"pydenticon": ["pydenticon"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def requirements(config=None, include_conditional=False):
|
||||
reqs = REQUIREMENTS.copy()
|
||||
for key, req in CONDITIONAL_REQUIREMENTS.items():
|
||||
if (config and getattr(config, key)) or include_conditional:
|
||||
reqs.update(req)
|
||||
return reqs
|
||||
|
||||
|
||||
def github_link(project, version, egg):
|
||||
@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def check_requirements():
|
||||
def check_requirements(config=None):
|
||||
"""Checks that all the modules needed by synapse have been correctly
|
||||
installed and are at the correct version"""
|
||||
for dependency, module_requirements in REQUIREMENTS.items():
|
||||
for dependency, module_requirements in (
|
||||
requirements(config, include_conditional=False).items()):
|
||||
for module_requirement in module_requirements:
|
||||
if ">=" in module_requirement:
|
||||
module_name, required_version = module_requirement.split(">=")
|
||||
@ -110,7 +123,7 @@ def list_requirements():
|
||||
egg = link.split("#egg=")[1]
|
||||
linked.append(egg.split('-')[0])
|
||||
result.append(link)
|
||||
for requirement in REQUIREMENTS:
|
||||
for requirement in requirements(include_conditional=True):
|
||||
is_linked = False
|
||||
for link in linked:
|
||||
if requirement.replace('-', '_').startswith(link):
|
||||
|
@ -27,7 +27,6 @@ from hashlib import sha1
|
||||
import hmac
|
||||
import simplejson as json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
login_type = register_json["type"]
|
||||
|
||||
is_application_server = login_type == LoginType.APPLICATION_SERVICE
|
||||
if self.disable_registration and not is_application_server:
|
||||
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
|
||||
|
||||
can_register = (
|
||||
not self.disable_registration
|
||||
or is_application_server
|
||||
or is_using_shared_secret
|
||||
)
|
||||
if not can_register:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
|
||||
stages = {
|
||||
LoginType.RECAPTCHA: self._do_recaptcha,
|
||||
LoginType.PASSWORD: self._do_password,
|
||||
LoginType.EMAIL_IDENTITY: self._do_email_identity,
|
||||
LoginType.APPLICATION_SERVICE: self._do_app_service
|
||||
LoginType.APPLICATION_SERVICE: self._do_app_service,
|
||||
LoginType.SHARED_SECRET: self._do_shared_secret,
|
||||
}
|
||||
|
||||
session_info = self._get_session_info(request, session)
|
||||
@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
)
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
desired_user_id = (register_json["user"].encode("utf-8")
|
||||
if "user" in register_json else None)
|
||||
if (desired_user_id
|
||||
and urllib.quote(desired_user_id) != desired_user_id):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not " +
|
||||
"require URL encoding.")
|
||||
desired_user_id = (
|
||||
register_json["user"].encode("utf-8")
|
||||
if "user" in register_json else None
|
||||
)
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.register(
|
||||
localpart=desired_user_id,
|
||||
@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret(self, request, register_json, session):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not isinstance(register_json.get("mac", None), basestring):
|
||||
raise SynapseError(400, "Expected mac.")
|
||||
if not isinstance(register_json.get("user", None), basestring):
|
||||
raise SynapseError(400, "Expected 'user' key.")
|
||||
if not isinstance(register_json.get("password", None), basestring):
|
||||
raise SynapseError(400, "Expected 'password' key.")
|
||||
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = register_json["user"].encode("utf-8")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(register_json["mac"])
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
password=password,
|
||||
)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
})
|
||||
else:
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
|
@ -91,7 +91,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": int(self._clock.time_msec()),
|
||||
},
|
||||
or_replace=True,
|
||||
desc="insert_client_ip",
|
||||
)
|
||||
|
||||
def get_user_ip_and_agents(self, user):
|
||||
@ -101,6 +101,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
retcols=[
|
||||
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
||||
],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,6 +25,7 @@ import synapse.metrics
|
||||
from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
import functools
|
||||
import simplejson as json
|
||||
import sys
|
||||
import time
|
||||
@ -53,13 +54,12 @@ cache_counter = metrics.register_cache(
|
||||
|
||||
|
||||
# TODO(paul):
|
||||
# * more generic key management
|
||||
# * consider other eviction strategies - LRU?
|
||||
def cached(max_entries=1000):
|
||||
def cached(max_entries=1000, num_args=1):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
The function is presumed to take one additional argument, which is used as
|
||||
the key for the cache. Cache hits are served directly from the cache;
|
||||
The function is presumed to take zero or more arguments, which are used in
|
||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||
misses use the function body to generate the value.
|
||||
|
||||
The wrapped function has an additional member, a callable called
|
||||
@ -75,25 +75,41 @@ def cached(max_entries=1000):
|
||||
|
||||
caches_by_name[name] = cache
|
||||
|
||||
def prefill(key, value):
|
||||
def prefill(*args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
|
||||
while len(cache) > max_entries:
|
||||
cache.popitem(last=False)
|
||||
|
||||
cache[key] = value
|
||||
cache[keyargs] = value
|
||||
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, key):
|
||||
if key in cache:
|
||||
def wrapped(self, *keyargs):
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
|
||||
if keyargs in cache:
|
||||
cache_counter.inc_hits(name)
|
||||
defer.returnValue(cache[key])
|
||||
defer.returnValue(cache[keyargs])
|
||||
|
||||
cache_counter.inc_misses(name)
|
||||
ret = yield orig(self, key)
|
||||
prefill(key, ret)
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
prefill_args = keyargs + (ret,)
|
||||
prefill(*prefill_args)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
def invalidate(key):
|
||||
cache.pop(key, None)
|
||||
def invalidate(*keyargs):
|
||||
if len(keyargs) != num_args:
|
||||
raise ValueError("Expected a call to have %d arguments", num_args)
|
||||
|
||||
cache.pop(keyargs, None)
|
||||
|
||||
wrapped.invalidate = invalidate
|
||||
wrapped.prefill = prefill
|
||||
@ -325,7 +341,8 @@ class SQLBaseStore(object):
|
||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
@ -334,7 +351,7 @@ class SQLBaseStore(object):
|
||||
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_insert",
|
||||
desc,
|
||||
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||
or_ignore=or_ignore,
|
||||
)
|
||||
@ -357,7 +374,7 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql, values.values())
|
||||
return txn.lastrowid
|
||||
|
||||
def _simple_upsert(self, table, keyvalues, values):
|
||||
def _simple_upsert(self, table, keyvalues, values, desc="_simple_upsert"):
|
||||
"""
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
@ -366,7 +383,7 @@ class SQLBaseStore(object):
|
||||
Returns: A deferred
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_upsert",
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values
|
||||
)
|
||||
|
||||
@ -402,7 +419,7 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql, allvalues.values())
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@ -414,12 +431,15 @@ class SQLBaseStore(object):
|
||||
allow_none : If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
return self._simple_selectupdate_one(
|
||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it."
|
||||
|
||||
@ -429,7 +449,7 @@ class SQLBaseStore(object):
|
||||
retcol : string giving the name of the column to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_one_onecol",
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
)
|
||||
@ -464,7 +484,8 @@ class SQLBaseStore(object):
|
||||
|
||||
return [r[0] for r in txn.fetchall()]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
@ -477,12 +498,13 @@ class SQLBaseStore(object):
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_onecol",
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols):
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@ -493,7 +515,7 @@ class SQLBaseStore(object):
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_list",
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
)
|
||||
@ -525,7 +547,7 @@ class SQLBaseStore(object):
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
retcols=None):
|
||||
desc="_simple_update_one"):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
@ -543,45 +565,70 @@ class SQLBaseStore(object):
|
||||
get-and-set. This can be used to implement compare-and-set by putting
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
||||
retcols=retcols)
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
|
||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||
retcols=None, allow_none=False):
|
||||
retcols=None, allow_none=False,
|
||||
desc="_simple_selectupdate_one"):
|
||||
""" Combined SELECT then UPDATE."""
|
||||
if retcols:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
if updatevalues:
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
ret = None
|
||||
if retcols:
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
ret = dict(zip(retcols, row))
|
||||
ret = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcols=retcols,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
if updatevalues:
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
|
||||
# if txn.rowcount == 0:
|
||||
@ -590,9 +637,9 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return ret
|
||||
return self.runInteraction("_simple_selectupdate_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete_one(self, table, keyvalues):
|
||||
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
@ -611,9 +658,9 @@ class SQLBaseStore(object):
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
return self.runInteraction("_simple_delete_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete(self, table, keyvalues):
|
||||
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Args:
|
||||
@ -621,7 +668,7 @@ class SQLBaseStore(object):
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
|
||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
||||
return self.runInteraction(desc, self._simple_delete_txn)
|
||||
|
||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
|
@ -48,6 +48,7 @@ class DirectoryStore(SQLBaseStore):
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"room_id",
|
||||
allow_none=True,
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not room_id:
|
||||
@ -58,6 +59,7 @@ class DirectoryStore(SQLBaseStore):
|
||||
"room_alias_servers",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"server",
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not servers:
|
||||
@ -87,6 +89,7 @@ class DirectoryStore(SQLBaseStore):
|
||||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
raise SynapseError(
|
||||
@ -100,7 +103,8 @@ class DirectoryStore(SQLBaseStore):
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
}
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
|
||||
def delete_room_alias(self, room_alias):
|
||||
@ -139,4 +143,5 @@ class DirectoryStore(SQLBaseStore):
|
||||
"room_aliases",
|
||||
{"room_id": room_id},
|
||||
"room_alias",
|
||||
desc="get_aliases_for_room",
|
||||
)
|
||||
|
@ -426,3 +426,15 @@ class EventFederationStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
return events[:limit]
|
||||
|
||||
def clean_room_for_join(self, room_id):
|
||||
return self.runInteraction(
|
||||
"clean_room_for_join",
|
||||
self._clean_room_for_join_txn,
|
||||
room_id,
|
||||
)
|
||||
|
||||
def _clean_room_for_join_txn(self, txn, room_id):
|
||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
|
@ -52,6 +52,7 @@ class EventsStore(SQLBaseStore):
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
self.get_room_events_max_id.invalidate()
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
@ -242,7 +243,6 @@ class EventsStore(SQLBaseStore):
|
||||
if stream_ordering is None:
|
||||
stream_ordering = self.get_next_stream_id()
|
||||
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
|
@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
|
||||
},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
|
||||
defer.returnValue(json.loads(def_json))
|
||||
|
@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
{"media_id": media_id},
|
||||
("media_type", "media_length", "upload_name", "created_ts"),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
)
|
||||
|
||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||
@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"user_id": user_id.to_string(),
|
||||
}
|
||||
},
|
||||
desc="store_local_media",
|
||||
)
|
||||
|
||||
def get_local_media_thumbnails(self, media_id):
|
||||
@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length",
|
||||
)
|
||||
),
|
||||
desc="get_local_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||
@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"thumbnail_method": thumbnail_method,
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
}
|
||||
},
|
||||
desc="store_local_thumbnail",
|
||||
)
|
||||
|
||||
def get_cached_remote_media(self, origin, media_id):
|
||||
@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
)
|
||||
|
||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||
@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def get_remote_media_thumbnails(self, origin, media_id):
|
||||
@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||
)
|
||||
),
|
||||
desc="get_remote_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||
@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_remote_media_thumbnail",
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ class PresenceStore(SQLBaseStore):
|
||||
return self._simple_insert(
|
||||
table="presence",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_presence",
|
||||
)
|
||||
|
||||
def has_presence_state(self, user_localpart):
|
||||
@ -29,6 +30,7 @@ class PresenceStore(SQLBaseStore):
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["user_id"],
|
||||
allow_none=True,
|
||||
desc="has_presence_state",
|
||||
)
|
||||
|
||||
def get_presence_state(self, user_localpart):
|
||||
@ -36,6 +38,7 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcols=["state", "status_msg", "mtime"],
|
||||
desc="get_presence_state",
|
||||
)
|
||||
|
||||
def set_presence_state(self, user_localpart, new_state):
|
||||
@ -45,6 +48,7 @@ class PresenceStore(SQLBaseStore):
|
||||
updatevalues={"state": new_state["state"],
|
||||
"status_msg": new_state["status_msg"],
|
||||
"mtime": self._clock.time_msec()},
|
||||
desc="set_presence_state",
|
||||
)
|
||||
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
@ -52,6 +56,7 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence_allow_inbound",
|
||||
values={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
desc="allow_presence_visible",
|
||||
)
|
||||
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
@ -59,6 +64,7 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid},
|
||||
desc="disallow_presence_visible",
|
||||
)
|
||||
|
||||
def is_presence_visible(self, observed_localpart, observer_userid):
|
||||
@ -68,6 +74,7 @@ class PresenceStore(SQLBaseStore):
|
||||
"observer_user_id": observer_userid},
|
||||
retcols=["observed_user_id"],
|
||||
allow_none=True,
|
||||
desc="is_presence_visible",
|
||||
)
|
||||
|
||||
def add_presence_list_pending(self, observer_localpart, observed_userid):
|
||||
@ -76,6 +83,7 @@ class PresenceStore(SQLBaseStore):
|
||||
values={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid,
|
||||
"accepted": False},
|
||||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
@ -84,6 +92,7 @@ class PresenceStore(SQLBaseStore):
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
@ -95,6 +104,7 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence_list",
|
||||
keyvalues=keyvalues,
|
||||
retcols=["observed_user_id", "accepted"],
|
||||
desc="get_presence_list",
|
||||
)
|
||||
|
||||
def del_presence_list(self, observer_localpart, observed_userid):
|
||||
@ -102,4 +112,5 @@ class PresenceStore(SQLBaseStore):
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
desc="del_presence_list",
|
||||
)
|
||||
|
@ -21,6 +21,7 @@ class ProfileStore(SQLBaseStore):
|
||||
return self._simple_insert(
|
||||
table="profiles",
|
||||
values={"user_id": user_localpart},
|
||||
desc="create_profile",
|
||||
)
|
||||
|
||||
def get_profile_displayname(self, user_localpart):
|
||||
@ -28,6 +29,7 @@ class ProfileStore(SQLBaseStore):
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
def set_profile_displayname(self, user_localpart, new_displayname):
|
||||
@ -35,6 +37,7 @@ class ProfileStore(SQLBaseStore):
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"displayname": new_displayname},
|
||||
desc="set_profile_displayname",
|
||||
)
|
||||
|
||||
def get_profile_avatar_url(self, user_localpart):
|
||||
@ -42,6 +45,7 @@ class ProfileStore(SQLBaseStore):
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
|
||||
@ -49,4 +53,5 @@ class ProfileStore(SQLBaseStore):
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
@ -50,7 +50,8 @@ class PushRuleStore(SQLBaseStore):
|
||||
results = yield self._simple_select_list(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name},
|
||||
PushRuleEnableTable.fields
|
||||
PushRuleEnableTable.fields,
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
)
|
||||
defer.returnValue(
|
||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||
@ -201,7 +202,8 @@ class PushRuleStore(SQLBaseStore):
|
||||
"""
|
||||
yield self._simple_delete_one(
|
||||
PushRuleTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id}
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
desc="delete_push_rule",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -209,7 +211,8 @@ class PushRuleStore(SQLBaseStore):
|
||||
yield self._simple_upsert(
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name, 'rule_id': rule_id},
|
||||
{'enabled': enabled}
|
||||
{'enabled': enabled},
|
||||
desc="set_push_rule_enabled",
|
||||
)
|
||||
|
||||
|
||||
|
@ -114,7 +114,9 @@ class PusherStore(SQLBaseStore):
|
||||
ts=pushkey_ts,
|
||||
lang=lang,
|
||||
data=data
|
||||
))
|
||||
),
|
||||
desc="add_pusher",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("create_pusher with failed: %s", e)
|
||||
raise StoreError(500, "Problem creating pusher.")
|
||||
@ -123,7 +125,8 @@ class PusherStore(SQLBaseStore):
|
||||
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
|
||||
yield self._simple_delete_one(
|
||||
PushersTable.table_name,
|
||||
dict(app_id=app_id, pushkey=pushkey)
|
||||
{"app_id": app_id, "pushkey": pushkey},
|
||||
desc="delete_pusher_by_app_id_pushkey",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -131,7 +134,8 @@ class PusherStore(SQLBaseStore):
|
||||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'last_token': last_token}
|
||||
{'last_token': last_token},
|
||||
desc="update_pusher_last_token",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -140,7 +144,8 @@ class PusherStore(SQLBaseStore):
|
||||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'last_token': last_token, 'last_success': last_success}
|
||||
{'last_token': last_token, 'last_success': last_success},
|
||||
desc="update_pusher_last_token_and_success",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -148,7 +153,8 @@ class PusherStore(SQLBaseStore):
|
||||
yield self._simple_update_one(
|
||||
PushersTable.table_name,
|
||||
{'app_id': app_id, 'pushkey': pushkey},
|
||||
{'failing_since': failing_since}
|
||||
{'failing_since': failing_since},
|
||||
desc="update_pusher_failing_since",
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
|
||||
|
||||
from synapse.api.errors import StoreError, Codes
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
|
||||
class RegistrationStore(SQLBaseStore):
|
||||
@ -44,7 +44,8 @@ class RegistrationStore(SQLBaseStore):
|
||||
{
|
||||
"user_id": user_id,
|
||||
"token": token
|
||||
}
|
||||
},
|
||||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -87,6 +88,11 @@ class RegistrationStore(SQLBaseStore):
|
||||
"get_user_by_id", self.cursor_to_dict, query, user_id
|
||||
)
|
||||
|
||||
@cached()
|
||||
# TODO(paul): Currently there's no code to invalidate this cache. That
|
||||
# means if/when we ever add internal ways to invalidate access tokens or
|
||||
# change whether a user is a server admin, those will need to invoke
|
||||
# store.get_user_by_token.invalidate(token)
|
||||
def get_user_by_token(self, token):
|
||||
"""Get a user from the given access token.
|
||||
|
||||
@ -111,6 +117,7 @@ class RegistrationStore(SQLBaseStore):
|
||||
keyvalues={"name": user.to_string()},
|
||||
retcol="admin",
|
||||
allow_none=True,
|
||||
desc="is_server_admin",
|
||||
)
|
||||
|
||||
defer.returnValue(res if res else False)
|
||||
|
@ -29,7 +29,7 @@ class RejectionsStore(SQLBaseStore):
|
||||
"event_id": event_id,
|
||||
"reason": reason,
|
||||
"last_check": self._clock.time_msec(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_rejection_reason(self, event_id):
|
||||
@ -40,4 +40,5 @@ class RejectionsStore(SQLBaseStore):
|
||||
"event_id": event_id,
|
||||
},
|
||||
allow_none=True,
|
||||
desc="get_rejection_reason",
|
||||
)
|
||||
|
@ -15,11 +15,9 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from sqlite3 import IntegrityError
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
import collections
|
||||
import logging
|
||||
@ -27,8 +25,9 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple("OpsLevel", (
|
||||
"ban_level", "kick_level", "redact_level")
|
||||
OpsLevel = collections.namedtuple(
|
||||
"OpsLevel",
|
||||
("ban_level", "kick_level", "redact_level",)
|
||||
)
|
||||
|
||||
|
||||
@ -47,13 +46,15 @@ class RoomStore(SQLBaseStore):
|
||||
StoreError if the room could not be stored.
|
||||
"""
|
||||
try:
|
||||
yield self._simple_insert(RoomsTable.table_name, dict(
|
||||
room_id=room_id,
|
||||
creator=room_creator_user_id,
|
||||
is_public=is_public
|
||||
))
|
||||
except IntegrityError:
|
||||
raise StoreError(409, "Room ID in use.")
|
||||
yield self._simple_insert(
|
||||
RoomsTable.table_name,
|
||||
{
|
||||
"room_id": room_id,
|
||||
"creator": room_creator_user_id,
|
||||
"is_public": is_public,
|
||||
},
|
||||
desc="store_room",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("store_room with room_id=%s failed: %s", room_id, e)
|
||||
raise StoreError(500, "Problem creating room.")
|
||||
@ -66,12 +67,11 @@ class RoomStore(SQLBaseStore):
|
||||
Returns:
|
||||
A namedtuple containing the room information, or an empty list.
|
||||
"""
|
||||
query = RoomsTable.select_statement("room_id=?")
|
||||
return self._execute(
|
||||
"get_room",
|
||||
lambda txn: RoomsTable.decode_single_result(txn.fetchall()),
|
||||
query,
|
||||
room_id,
|
||||
return self._simple_select_one(
|
||||
table=RoomsTable.table_name,
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=RoomsTable.fields,
|
||||
desc="get_room",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -146,7 +146,7 @@ class RoomStore(SQLBaseStore):
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"topic": event.content["topic"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn, event):
|
||||
@ -199,7 +199,7 @@ class RoomStore(SQLBaseStore):
|
||||
defer.returnValue((name, aliases))
|
||||
|
||||
|
||||
class RoomsTable(Table):
|
||||
class RoomsTable(object):
|
||||
table_name = "rooms"
|
||||
|
||||
fields = [
|
||||
@ -207,5 +207,3 @@ class RoomsTable(Table):
|
||||
"is_public",
|
||||
"creator"
|
||||
]
|
||||
|
||||
EntryType = collections.namedtuple("RoomEntry", fields)
|
||||
|
@ -212,7 +212,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||
return self._simple_select_onecol(
|
||||
"room_hosts",
|
||||
{"room_id": room_id},
|
||||
"host"
|
||||
"host",
|
||||
desc="get_joined_hosts_for_room",
|
||||
)
|
||||
|
||||
def _get_members_by_dict(self, where_dict):
|
||||
|
@ -160,3 +160,4 @@ class StateStore(SQLBaseStore):
|
||||
|
||||
def _make_group_id(clock):
|
||||
return str(int(clock.time_msec())) + random_string(5)
|
||||
|
||||
|
@ -35,7 +35,7 @@ what sort order was used:
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.util.logutils import log_function
|
||||
@ -413,6 +413,7 @@ class StreamStore(SQLBaseStore):
|
||||
"get_recent_events_for_room", get_recent_events_for_room_txn
|
||||
)
|
||||
|
||||
@cached(num_args=0)
|
||||
def get_room_events_max_id(self):
|
||||
return self.runInteraction(
|
||||
"get_room_events_max_id",
|
||||
|
@ -46,15 +46,19 @@ class TransactionStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def _get_received_txn_response(self, txn, transaction_id, origin):
|
||||
where_clause = "transaction_id = ? AND origin = ?"
|
||||
query = ReceivedTransactionsTable.select_statement(where_clause)
|
||||
result = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=ReceivedTransactionsTable.table_name,
|
||||
keyvalues={
|
||||
"transaction_id": transaction_id,
|
||||
"origin": origin,
|
||||
},
|
||||
retcols=ReceivedTransactionsTable.fields,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
txn.execute(query, (transaction_id, origin))
|
||||
|
||||
results = ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
if results and results[0].response_code:
|
||||
return (results[0].response_code, results[0].response_json)
|
||||
if result and result.response_code:
|
||||
return result["response_code"], result["response_json"]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
@ -16,6 +16,10 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
_string_with_symbols = (
|
||||
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
|
||||
)
|
||||
|
||||
|
||||
def origin_from_ucid(ucid):
|
||||
return ucid.split("@", 1)[1]
|
||||
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
|
||||
|
||||
def random_string(length):
|
||||
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
|
||||
|
||||
|
||||
def random_string_with_symbols(length):
|
||||
return ''.join(
|
||||
random.choice(_string_with_symbols) for _ in xrange(length)
|
||||
)
|
||||
|
@ -180,7 +180,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
self.mock_txn.rowcount = 1
|
||||
self.mock_txn.fetchone.return_value = ("Old Value",)
|
||||
|
||||
ret = yield self.datastore._simple_update_one(
|
||||
ret = yield self.datastore._simple_selectupdate_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "TheKey"},
|
||||
updatevalues={"columname": "New Value"},
|
||||
|
@ -44,7 +44,7 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_room(self):
|
||||
self.assertObjectHasAttributes(
|
||||
self.assertDictContainsSubset(
|
||||
{"room_id": self.room.to_string(),
|
||||
"creator": self.u_creator.to_string(),
|
||||
"is_public": True},
|
||||
|
Loading…
Reference in New Issue
Block a user