Merge remote-tracking branch 'origin/develop' into erikj/generate_presice_thumbnails

Mark Haines committed on 2015-08-13 17:23:39 +01:00
commit b16cd18a86
34 changed files with 614 additions and 369 deletions

View file

@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
@ -319,7 +324,7 @@ class Auth(object):
Returns:
tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
ClientInfo object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@ -352,7 +357,7 @@ class Auth(object):
)
return
except KeyError:
pass # normal users won't have this query parameter set
pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
@ -521,23 +526,22 @@ class Auth(object):
# Check state_key
if hasattr(event, "state_key"):
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
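Read linearly, the re-indented block enforces a three-way rule on state keys. A condensed sketch, assuming only the semantics visible in the hunk above (the function name is illustrative):

    def may_set_state_key(sender_user_id, sender_domain, state_key):
        # "_"-prefixed state keys are unrestricted; "@"-prefixed keys must
        # name the sender; any other key must equal the sender's domain.
        if state_key.startswith("_"):
            return True
        if state_key.startswith("@"):
            return state_key == sender_user_id
        return state_key == sender_domain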

View file

@ -657,7 +657,8 @@ def run(hs):
if hs.config.daemonize:
print hs.config.pid_file
if hs.config.print_pidfile:
print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",

View file

@ -138,12 +138,19 @@ class Config(object):
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
@ -151,51 +158,40 @@ class Config(object):
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
if not os.path.exists(config_path):
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
config.update(yaml_config)
print "Generating any missing keys for %r" % (server_name,)
obj.invoke_all("generate_files", config)
sys.exit(0)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
) % (config_path, server_name)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
)
sys.exit(0)
print (
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@ -213,7 +209,7 @@ class Config(object):
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config")
config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)

View file

@ -24,6 +24,7 @@ class ServerConfig(Config):
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
@ -208,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"

View file

@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds:
id_server = creds['id_server']

View file

@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
password (str) : The password to assign to this user so they can
login again.
login again. This can be None, which means they cannot login again
via a password (e.g. the user is an application service user).
Returns:
A tuple of (user_id, access_token).
Raises:

View file

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
)
class MatrixFederationHttpAgent(_AgentBase):
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
def endpointForURI(self, uri):
destination = uri.netloc
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
outgoing_requests_counter.inc(method)
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
)
class MatrixFederationHttpClient(object):
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock()
self.version_string = hs.version_string
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
)
txn_id = "%s-O-%s" % (method, self._next_id)
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place)
retries_left = 5
endpoint = preserve_context_over_fn(
self._getEndpoint, reactor, destination
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
log_result = None
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
producer = body_callback(method, http_url_bytes, headers_dict)
try:
def send_request():
request_deferred = self.agent.request(
destination,
endpoint,
request_deferred = preserve_context_over_fn(
self.agent.request,
method,
path_bytes,
param_bytes,
query_bytes,
url_bytes,
Headers(headers_dict),
producer
)
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
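The deleted MatrixFederationHttpAgent subclassed Twisted's private _AgentBase; its replacement plugs into the public Agent.usingEndpointFactory hook instead (added in Twisted 15.0, hence the requirements bump below). A generic sketch of that hook, assuming only public Twisted APIs; the factory class here is illustrative, not the Matrix one:

    from twisted.internet import reactor
    from twisted.internet.endpoints import TCP4ClientEndpoint
    from twisted.web.client import Agent, HTTPConnectionPool

    class LoggingEndpointFactory(object):
        def endpointForURI(self, uri):
            # Agent calls this once per request with the parsed URI; real
            # code would return an SRV-resolving endpoint keyed on uri.netloc.
            print "connecting to %s:%d" % (uri.host, uri.port)
            return TCP4ClientEndpoint(reactor, uri.host, uri.port)

    pool = HTTPConnectionPool(reactor)
    agent = Agent.usingEndpointFactory(reactor, LoggingEndpointFactory(), pool=pool)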

View file

@ -18,8 +18,12 @@ from __future__ import absolute_import
import logging
from resource import getrusage, getpagesize, RUSAGE_SELF
import functools
import os
import stat
import time
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -144,3 +148,28 @@ def _process_fds():
return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
pending_calls = len(reactor.getDelayedCalls())
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(pending_calls)
return ret
return f
if hasattr(reactor, "runUntilCurrent"):
# runUntilCurrent is called when we have pending calls. It is called once
# per iteration after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
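The wrap-and-replace pattern above generalises to any callable; a minimal sketch (names illustrative, not part of this diff) that also keeps the timing accurate when the wrapped call raises:

    import functools
    import time

    def timed_call(record_ms, func):
        # Report each call's wall-clock duration in milliseconds to
        # record_ms, mirroring what runUntilCurrentTimer feeds into tick_time.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time() * 1000
            try:
                return func(*args, **kwargs)
            finally:
                record_ms(time.time() * 1000 - start)
        return wrapper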

View file

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],

View file

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty
from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
import hmac
@ -55,30 +55,55 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
body = parse_request_allow_empty(request)
# we do basic sanity checks here because the auth
# layer will store these in sessions
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
if ((not isinstance(body['password'], str) and
not isinstance(body['password'], unicode)) or
if (not isinstance(body['password'], basestring) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
if ((not isinstance(body['username'], str) and
not isinstance(body['username'], unicode)) or
if (not isinstance(body['username'], basestring) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False
is_application_server = False
service = None
appservice = None
if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request)
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows from normal users
# == Application Service Registration ==
if appservice:
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha:
flows = [
@ -91,39 +116,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
result = None
if service:
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if not authed:
defer.returnValue((401, result))
return
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password']
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params.get("username", None)
new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
@ -156,18 +162,21 @@ class RegisterRestServlet(RestServlet):
else:
logger.info("bind_email not specified: not binding email")
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
result = self._create_registration_details(user_id, token)
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
def _check_shared_secret_auth(self, username, mac):
@defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue(self._create_registration_details(user_id, token))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@ -183,13 +192,23 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1,
).hexdigest()
if compare_digest(want_mac, got_mac):
return True
else:
if not compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
)
defer.returnValue(self._create_registration_details(user_id, token))
def _create_registration_details(self, user_id, token):
return {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
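For the shared-secret path, the server recomputes an HMAC-SHA1 keyed with registration_shared_secret and compares it against body["mac"] via compare_digest. A client-side sketch producing a matching value, assuming the message is the desired username (the digestmod is visible in the hunk above; the exact message layout is an assumption):

    import hmac
    from hashlib import sha1

    def registration_mac(shared_secret, username):
        # The hex digest goes into the "mac" field of the /register body.
        return hmac.new(
            key=shared_secret, msg=username, digestmod=sha1,
        ).hexdigest()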

View file

@ -298,43 +298,52 @@ class BaseMediaResource(Resource):
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
yield threads.deferToThread(generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,

View file

@ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip)
try:
last_seen = self.client_ip_last_seen.get(*key)
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
@ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,))
self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)

View file

@ -15,6 +15,7 @@
import logging
from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
import inspect
import sys
import time
import threading
@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
return val
cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args):
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the cache's sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
self.cache[key] = value
def invalidate(self, *keyargs):
def invalidate(self, key):
self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None)
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
@ -130,6 +134,9 @@ class Cache(object):
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
@ -141,58 +148,92 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result = cache.get(*keyargs[:self.num_args])
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
sequence = self.cache.sequence
ret = yield self.orig(obj, *keyargs)
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
def onErr(f):
self.cache.invalidate(cache_key)
return f
defer.returnValue(ret)
ret.addErrback(onErr)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=False):
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@ -201,6 +242,16 @@ def cached(max_entries=1000, num_args=1, lru=False):
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
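After this change every cache key is an explicit tuple, which is why call sites throughout the storage modules below switch from invalidate(room_id) to invalidate((room_id,)). A self-contained toy sketch of the contract (not the real Cache class):

    class TupleKeyCache(object):
        _SENTINEL = object()

        def __init__(self):
            self._store = {}

        def get(self, key, default=_SENTINEL):
            val = self._store.get(key, self._SENTINEL)
            if val is not self._SENTINEL:
                return val
            if default is self._SENTINEL:
                raise KeyError(key)
            return default

        def prefill(self, key, value):
            self._store[key] = value

        def invalidate(self, key):
            if not isinstance(key, tuple):
                raise TypeError(
                    "The cache key must be a tuple not %r" % (type(key),)
                )
            self._store.pop(key, None)

    cache = TupleKeyCache()
    cache.prefill(("@alice:example.com",), ["some rule"])
    cache.invalidate(("@alice:example.com",))  # note the 1-tuple key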
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()

View file

@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
},
desc="create_room_alias_association",
)
self.get_aliases_for_room.invalidate(room_id)
self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias,
)
self.get_aliases_for_room.invalidate(room_id)
self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias):

View file

@ -362,7 +362,7 @@ class EventFederationStore(SQLBaseStore):
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_backfill_events(self, room_id, event_list, limit):
@ -505,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View file

@ -162,8 +162,8 @@ class EventsStore(SQLBaseStore):
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
@ -430,13 +430,13 @@ class EventsStore(SQLBaseStore):
if not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key
)
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
event.room_id
(event.room_id,)
)
self._simple_upsert_txn(
@ -567,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
self._get_event_cache.invalidate(
(event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
@ -589,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
(event_id, check_redacted, get_prev_content,)
)
if allow_rejected or not ret.rejected_reason:
@ -822,7 +823,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
(ev.event_id, check_redacted, get_prev_content), ev
)
defer.returnValue(ev)
@ -879,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
(ev.event_id, check_redacted, get_prev_content), ev
)
return ev

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore, cached
from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
@ -132,7 +131,7 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
self.get_all_server_verify_keys.invalidate(server_name)
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):

View file

@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate(observer_localpart)
self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid},
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate(observer_localpart)
self.get_presence_list_accepted.invalidate((observer_localpart,))

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
import logging
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)
self._simple_insert_txn(
@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
self.get_push_rules_for_user.invalidate(user_name)
self.get_push_rules_enabled_for_user.invalidate(user_name)
self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
self.get_push_rules_for_user.invalidate, (user_name,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
)

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
@cached
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""

View file

@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id
)
for r in rows:
self.get_user_by_token.invalidate(r)
self.get_user_by_token.invalidate((r,))
@cached()
def get_user_by_token(self, token):

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
import collections
import logging
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
@cached()
@defer.inlineCallbacks
@cachedInlineCallbacks()
def get_room_name_and_aliases(self, room_id):
def f(txn):
sql = (

View file

@ -54,9 +54,9 @@ class RoomMemberStore(SQLBaseStore):
)
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, event.state_key)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@ -78,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
@cached()
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@ -154,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
@cached()
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):
return self.runInteraction(
"get_joined_hosts_for_room",

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
@ -91,7 +91,6 @@ class StateStore(SQLBaseStore):
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
@ -189,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
@defer.inlineCallbacks
@cachedInlineCallbacks(num_args=3)
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (

View file

@ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
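A sketch of parsing the two token forms the docstring above describes (the helper is illustrative, not part of this diff):

    def parse_room_stream_token(string):
        # "s<stream>" is a live token; "t<topological>-<stream>" is historic.
        if string.startswith("s"):
            return None, int(string[1:])
        if string.startswith("t"):
            topological, _, stream = string[1:].partition("-")
            return int(topological), int(stream)
        raise ValueError("Invalid token %r" % (string,))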

View file

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set())
def callback(r):
self._result = (True, r)
object.__setattr__(self, "_result", (True, r))
while self._observers:
try:
self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r
def errback(f):
self._result = (False, f)
object.__setattr__(self, "_result", (False, f))
while self._observers:
try:
self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value):
setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)
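Both writes go through object.__setattr__ because ObservableDeferred overrides __setattr__ (visible above) to forward attribute assignments to the wrapped deferred; a plain self._result = ... would therefore land on the deferred rather than on the wrapper. A minimal sketch of the pitfall (illustrative class):

    class Proxy(object):
        def __init__(self, target):
            object.__setattr__(self, "_target", target)

        def __setattr__(self, name, value):
            # every ordinary assignment is forwarded to the wrapped object...
            setattr(self._target, name, value)

        def remember(self, value):
            # ...so internal state must bypass the override explicitly
            object.__setattr__(self, "_result", value)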