Merge branch 'develop' of github.com:matrix-org/synapse into batched_get_pdu

Erik Johnston 2015-03-02 13:53:30 +00:00
commit 0a036944bd
32 changed files with 1212 additions and 147 deletions

.gitignore

@@ -41,3 +41,4 @@ media_store/
 build/
 localhost-800*/
+static/client/register/register_config.js


@@ -1,3 +1,12 @@
+Changes in synapse vx.x.x (x-x-x)
+=================================
+
+* Add support for registration fallback. This is a page hosted on the server
+  which allows a user to register for an account, regardless of what client
+  they are using (e.g. mobile devices).
+* Application services can now poll on the CS API ``/events`` for their events,
+  by providing their application service ``access_token``.
+
 Changes in synapse v0.7.1 (2015-02-19)
 ======================================

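The application-service polling described in the changelog entry above uses the ordinary client-server event stream. As a rough sketch, with a placeholder homeserver URL and AS token, and assuming the third-party ``requests`` library (not part of this commit), an application service could long-poll like so:

    # Hedged sketch of an AS polling the CS API /events endpoint; HS_URL and
    # AS_TOKEN are placeholders, not values from this commit.
    import requests

    HS_URL = "https://localhost:8448"
    AS_TOKEN = "YOUR_AS_TOKEN"

    def poll_events():
        from_token = None
        while True:
            params = {"access_token": AS_TOKEN, "timeout": 30000}
            if from_token:
                params["from"] = from_token
            resp = requests.get(HS_URL + "/_matrix/client/api/v1/events",
                                params=params)
            resp.raise_for_status()
            body = resp.json()
            for event in body.get("chunk", []):
                print("%s %s" % (event["type"], event.get("room_id")))
            from_token = body["end"]  # resume token for the next poll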

@@ -1,3 +1,17 @@
+Upgrading to vx.xx
+==================
+
+Servers which use captchas will need to add their public key to::
+
+  static/client/register/register_config.js
+
+    window.matrixRegistrationConfig = {
+        recaptcha_public_key: "YOUR_PUBLIC_KEY"
+    };
+
+This is required in order to support registration fallback (typically used on
+mobile devices).
+
 Upgrading to v0.7.0
 ===================


@@ -81,7 +81,7 @@ Your home server configuration file needs the following extra keys:
 As an example, here is the relevant section of the config file for
 matrix.org::

-    turn_uris: turn:turn.matrix.org:3478?transport=udp,turn:turn.matrix.org:3478?transport=tcp
+    turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
     turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
     turn_user_lifetime: 86400000


@@ -0,0 +1,32 @@
<html>
<head>
<title> Registration </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/recaptcha_ajax.js"></script>
<script src="register_config.js"></script>
<script src="js/register.js"></script>
</head>
<body onload="matrixRegistration.onLoad()">
<form id="registrationForm" onsubmit="matrixRegistration.signUp(); return false;">
<div>
Create account:<br/>
<div style="text-align: center">
<input id="desired_user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="pwd1" size="32" type="password" placeholder="Type a password"/>
<br/>
<input id="pwd2" size="32" type="password" placeholder="Confirm your password"/>
<br/>
<span id="feedback" style="color: #f00"></span>
<br/>
<div id="regcaptcha"></div>
<button type="submit" style="margin: 10px">Sign up</button>
</div>
</div>
</form>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -0,0 +1,117 @@
window.matrixRegistration = {
    endpoint: location.origin + "/_matrix/client/api/v1/register"
};

var setupCaptcha = function() {
    if (!window.matrixRegistrationConfig) {
        return;
    }
    $.get(matrixRegistration.endpoint, function(response) {
        var serverExpectsCaptcha = false;
        for (var i=0; i<response.flows.length; i++) {
            var flow = response.flows[i];
            if ("m.login.recaptcha" === flow.type) {
                serverExpectsCaptcha = true;
                break;
            }
        }
        if (!serverExpectsCaptcha) {
            console.log("This server does not require a captcha.");
            return;
        }
        console.log("Setting up ReCaptcha for "+matrixRegistration.endpoint);
        var public_key = window.matrixRegistrationConfig.recaptcha_public_key;
        if (public_key === undefined) {
            console.error("No public key defined for captcha!");
            setFeedbackString("Misconfigured captcha for server. Contact server admin.");
            return;
        }
        Recaptcha.create(public_key,
            "regcaptcha",
            {
                theme: "red",
                callback: Recaptcha.focus_response_field
            });
        window.matrixRegistration.isUsingRecaptcha = true;
    }).error(errorFunc);
};

var submitCaptcha = function(user, pwd) {
    var challengeToken = Recaptcha.get_challenge();
    var captchaEntry = Recaptcha.get_response();
    var data = {
        type: "m.login.recaptcha",
        challenge: challengeToken,
        response: captchaEntry
    };
    console.log("Submitting captcha");
    $.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
        console.log("Success -> "+JSON.stringify(response));
        submitPassword(user, pwd, response.session);
    }).error(function(err) {
        Recaptcha.reload();
        errorFunc(err);
    });
};

var submitPassword = function(user, pwd, session) {
    console.log("Registering...");
    var data = {
        type: "m.login.password",
        user: user,
        password: pwd,
        session: session
    };
    $.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
        matrixRegistration.onRegistered(
            response.home_server, response.user_id, response.access_token
        );
    }).error(errorFunc);
};

var errorFunc = function(err) {
    if (err.responseJSON && err.responseJSON.error) {
        setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
    }
    else {
        setFeedbackString("Request failed: " + err.status);
    }
};

var setFeedbackString = function(text) {
    $("#feedback").text(text);
};

matrixRegistration.onLoad = function() {
    setupCaptcha();
};

matrixRegistration.signUp = function() {
    var user = $("#desired_user_id").val();
    if (user.length == 0) {
        setFeedbackString("Must specify a username.");
        return;
    }
    var pwd1 = $("#pwd1").val();
    var pwd2 = $("#pwd2").val();
    if (pwd1.length < 6) {
        setFeedbackString("Password: min. 6 characters.");
        return;
    }
    if (pwd1 != pwd2) {
        setFeedbackString("Passwords do not match.");
        return;
    }
    if (window.matrixRegistration.isUsingRecaptcha) {
        submitCaptcha(user, pwd1);
    }
    else {
        submitPassword(user, pwd1);
    }
};

matrixRegistration.onRegistered = function(hs_url, user_id, access_token) {
    // clobber this function
    console.log("onRegistered - This function should be replaced to proceed.");
};


@@ -0,0 +1,3 @@
window.matrixRegistrationConfig = {
    recaptcha_public_key: "YOUR_PUBLIC_KEY"
};


@@ -0,0 +1,56 @@
html {
    height: 100%;
}

body {
    height: 100%;
    font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
    font-size: 12pt;
    margin: 0px;
}

h1 {
    font-size: 20pt;
}

a:link    { color: #666; }
a:visited { color: #666; }
a:hover   { color: #000; }
a:active  { color: #000; }

input {
    width: 100%
}

textarea, input {
    font-family: inherit;
    font-size: inherit;
}

.smallPrint {
    color: #888;
    font-size: 9pt !important;
    font-style: italic !important;
}

#recaptcha_area {
    margin: auto
}

#registrationForm {
    text-align: left;
    padding: 1em;
    margin-bottom: 40px;
    display: inline-block;

    -webkit-border-radius: 10px;
    -moz-border-radius: 10px;
    border-radius: 10px;

    -webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
    -moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
    box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);

    background-color: #f8f8f8;
    border: 1px #ccc solid;
}


@@ -18,6 +18,7 @@
 CLIENT_PREFIX = "/_matrix/client/api/v1"
 CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
 FEDERATION_PREFIX = "/_matrix/federation/v1"
+STATIC_PREFIX = "/_matrix/static"
 WEB_CLIENT_PREFIX = "/_matrix/client"
 CONTENT_REPO_PREFIX = "/_matrix/content"
 SERVER_KEY_PREFIX = "/_matrix/key/v1"


@@ -36,7 +36,8 @@ from synapse.http.server_key_resource import LocalKey
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.api.urls import (
     CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
-    SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX
+    SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
+    STATIC_PREFIX
 )
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory

@@ -52,6 +53,7 @@ import synapse
 import logging
 import os
 import re
+import resource
 import subprocess
 import sqlite3
 import syweb

@@ -81,6 +83,9 @@ class SynapseHomeServer(HomeServer):
         webclient_path = os.path.join(syweb_path, "webclient")
         return File(webclient_path)  # TODO configurable?

+    def build_resource_for_static_content(self):
+        return File("static")
+
     def build_resource_for_content_repo(self):
         return ContentRepoResource(
             self, self.upload_dir, self.auth, self.content_addr

@@ -124,7 +129,9 @@ class SynapseHomeServer(HomeServer):
             (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
             (MEDIA_PREFIX, self.get_resource_for_media_repository()),
             (APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
+            (STATIC_PREFIX, self.get_resource_for_static_content()),
         ]
         if web_client:
             logger.info("Adding the web client.")
             desired_tree.append((WEB_CLIENT_PREFIX,

@@ -140,8 +147,8 @@ class SynapseHomeServer(HomeServer):
         # instead, we'll store a copy of this mapping so we can actually add
         # extra resources to existing nodes. See self._resource_id for the key.
         resource_mappings = {}
-        for (full_path, resource) in desired_tree:
-            logger.info("Attaching %s to path %s", resource, full_path)
+        for full_path, res in desired_tree:
+            logger.info("Attaching %s to path %s", res, full_path)
             last_resource = self.root_resource
             for path_seg in full_path.split('/')[1:-1]:
                 if path_seg not in last_resource.listNames():

@@ -172,12 +179,12 @@ class SynapseHomeServer(HomeServer):
                                 child_name)
                     child_resource = resource_mappings[child_res_id]
                     # steal the children
-                    resource.putChild(child_name, child_resource)
+                    res.putChild(child_name, child_resource)

             # finally, insert the desired resource in the right place
-            last_resource.putChild(last_path_seg, resource)
+            last_resource.putChild(last_path_seg, res)

             res_id = self._resource_id(last_resource, last_path_seg)
-            resource_mappings[res_id] = resource
+            resource_mappings[res_id] = res

         return self.root_resource

@@ -269,6 +276,20 @@ def get_version_string():
     return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")


+def change_resource_limit(soft_file_no):
+    try:
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        if not soft_file_no:
+            soft_file_no = hard
+        resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+        logger.info("Set file limit to: %d", soft_file_no)
+    except (ValueError, resource.error) as e:
+        logger.warn("Failed to set file limit: %s", e)
+
+
 def setup():
     config = HomeServerConfig.load_config(
         "Synapse Homeserver",

@@ -345,10 +366,11 @@ def setup():
     if config.daemonize:
         print config.pid_file
         daemon = Daemonize(
             app="synapse-homeserver",
             pid=config.pid_file,
-            action=run,
+            action=lambda: run(config),
             auto_close_fds=False,
             verbose=True,
             logger=logger,

@@ -356,11 +378,13 @@ def setup():
         daemon.start()
     else:
-        reactor.run()
+        run(config)


-def run():
+def run(config):
     with LoggingContext("run"):
+        change_resource_limit(config.soft_file_limit)
         reactor.run()


@@ -22,6 +22,12 @@ class RatelimitConfig(Config):
         self.rc_messages_per_second = args.rc_messages_per_second
         self.rc_message_burst_count = args.rc_message_burst_count

+        self.federation_rc_window_size = args.federation_rc_window_size
+        self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
+        self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
+        self.federation_rc_reject_limit = args.federation_rc_reject_limit
+        self.federation_rc_concurrent = args.federation_rc_concurrent
+
     @classmethod
     def add_arguments(cls, parser):
         super(RatelimitConfig, cls).add_arguments(parser)

@@ -34,3 +40,33 @@ class RatelimitConfig(Config):
             "--rc-message-burst-count", type=float, default=10,
             help="number of message a client can send before being throttled"
         )
+
+        rc_group.add_argument(
+            "--federation-rc-window-size", type=int, default=10000,
+            help="The federation window size in milliseconds",
+        )
+
+        rc_group.add_argument(
+            "--federation-rc-sleep-limit", type=int, default=10,
+            help="The number of federation requests from a single server"
+                 " in a window before the server will delay processing the"
+                 " request.",
+        )
+
+        rc_group.add_argument(
+            "--federation-rc-sleep-delay", type=int, default=500,
+            help="The duration in milliseconds to delay processing events from"
+                 " remote servers by if they go over the sleep limit.",
+        )
+
+        rc_group.add_argument(
+            "--federation-rc-reject-limit", type=int, default=50,
+            help="The maximum number of concurrent federation requests allowed"
+                 " from a single server",
+        )
+
+        rc_group.add_argument(
+            "--federation-rc-concurrent", type=int, default=3,
+            help="The number of federation requests to concurrently process"
+                 " from a single server",
+        )
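Operationally, the same five settings can be given via the homeserver config file rather than the command line; a minimal fragment using the defaults above (the key names are assumed to follow the usual argparse-to-config mapping, so treat this as illustrative):

    federation_rc_window_size: 10000   # window size in milliseconds
    federation_rc_sleep_limit: 10      # requests per window before delaying
    federation_rc_sleep_delay: 500     # delay in ms once over the sleep limit
    federation_rc_reject_limit: 50     # queue depth before rejecting with a 429
    federation_rc_concurrent: 3        # concurrent requests per remote server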


@@ -31,6 +31,7 @@ class ServerConfig(Config):
         self.webclient = True
         self.manhole = args.manhole
         self.no_tls = args.no_tls
+        self.soft_file_limit = args.soft_file_limit

         if not args.content_addr:
             host = args.server_name

@@ -77,6 +78,12 @@ class ServerConfig(Config):
                                   "content repository")
         server_group.add_argument("--no-tls", action='store_true',
                                   help="Don't bind to the https port.")
+        server_group.add_argument("--soft-file-limit", type=int, default=0,
+                                  help="Set the soft limit on the number of "
+                                       "file descriptors synapse can use. "
+                                       "Zero is used to indicate synapse "
+                                       "should set the soft limit to the "
+                                       "hard limit.")

     def read_signing_key(self, signing_key_path):
         signing_keys = self.read_file(signing_key_path, "signing_key")

@@ -28,7 +28,7 @@ class VoipConfig(Config):
         super(VoipConfig, cls).add_arguments(parser)
         group = parser.add_argument_group("voip")
         group.add_argument(
-            "--turn-uris", type=str, default=None,
+            "--turn-uris", type=str, default=None, action='append',
             help="The public URIs of the TURN server to give to clients"
         )
         group.add_argument(


@@ -112,17 +112,20 @@ class FederationServer(FederationBase):
         logger.debug("[%s] Transaction is new", transaction.transaction_id)

         with PreserveLoggingContext():
-            dl = []
+            results = []
+
             for pdu in pdu_list:
                 d = self._handle_new_pdu(transaction.origin, pdu)

-                def handle_failure(failure):
-                    failure.trap(FederationError)
-                    self.send_failure(failure.value, transaction.origin)
-
-                d.addErrback(handle_failure)
-
-                dl.append(d)
+                try:
+                    yield d
+                    results.append({})
+                except FederationError as e:
+                    self.send_failure(e, transaction.origin)
+                    results.append({"error": str(e)})
+                except Exception as e:
+                    results.append({"error": str(e)})
+                    logger.exception("Failed to handle PDU")

         if hasattr(transaction, "edus"):
             for edu in [Edu(**x) for x in transaction.edus]:

@@ -135,29 +138,11 @@ class FederationServer(FederationBase):
         for failure in getattr(transaction, "pdu_failures", []):
             logger.info("Got failure %r", failure)

-        results = yield defer.DeferredList(dl, consumeErrors=True)
-
-        ret = []
-        for r in results:
-            if r[0]:
-                ret.append({})
-            else:
-                failure = r[1]
-                logger.error(
-                    "Failed to handle PDU",
-                    exc_info=(
-                        failure.type,
-                        failure.value,
-                        failure.getTracebackObject()
-                    )
-                )
-                ret.append({"error": str(r[1].value)})
-
-        logger.debug("Returning: %s", str(ret))
+        logger.debug("Returning: %s", str(results))

         response = {
             "pdus": dict(zip(
-                (p.event_id for p in pdu_list), ret
+                (p.event_id for p in pdu_list), results
             )),
         }


@@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
 from .server import TransportLayerServer
 from .client import TransportLayerClient

+from synapse.util.ratelimitutils import FederationRateLimiter
+

 class TransportLayer(TransportLayerServer, TransportLayerClient):
     """This is a basic implementation of the transport layer that translates

@@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
                 send requests
         """
         self.keyring = homeserver.get_keyring()
+        self.clock = homeserver.get_clock()
         self.server_name = server_name
         self.server = server
         self.client = client
         self.request_handler = None
         self.received_handler = None
+
+        self.ratelimiter = FederationRateLimiter(
+            self.clock,
+            window_size=homeserver.config.federation_rc_window_size,
+            sleep_limit=homeserver.config.federation_rc_sleep_limit,
+            sleep_msec=homeserver.config.federation_rc_sleep_delay,
+            reject_limit=homeserver.config.federation_rc_reject_limit,
+            concurrent_requests=homeserver.config.federation_rc_concurrent,
+        )


@@ -98,15 +98,23 @@ class TransportLayerServer(object):
         def new_handler(request, *args, **kwargs):
             try:
                 (origin, content) = yield self._authenticate_request(request)
-                response = yield handler(
-                    origin, content, request.args, *args, **kwargs
-                )
+                with self.ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield handler(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("_authenticate_request failed")
                 raise
             defer.returnValue(response)
         return new_handler

+    def rate_limit_origin(self, handler):
+        def new_handler(origin, *args, **kwargs):
+            response = yield handler(origin, *args, **kwargs)
+            defer.returnValue(response)
+        return new_handler()
+
     @log_function
     def register_received_handler(self, handler):
         """ Register a handler that will be fired when we receive data.


@@ -160,7 +160,7 @@ class DirectoryHandler(BaseHandler):
         if not room_id:
             raise SynapseError(
                 404,
-                "Room alias %r not found" % (room_alias.to_string(),),
+                "Room alias %s not found" % (room_alias.to_string(),),
                 Codes.NOT_FOUND
             )


@@ -69,9 +69,6 @@ class EventStreamHandler(BaseHandler):
             )
             self._streams_per_user[auth_user] += 1

-            if pagin_config.from_token is None:
-                pagin_config.from_token = None
-
             rm_handler = self.hs.get_handlers().room_member_handler
             room_ids = yield rm_handler.get_rooms_for_user(auth_user)


@@ -510,9 +510,16 @@ class RoomMemberHandler(BaseHandler):
     def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
         """Returns a list of roomids that the user has any of the given
        membership states in."""
-        rooms = yield self.store.get_rooms_for_user_where_membership_is(
-            user_id=user.to_string(), membership_list=membership_list
+
+        app_service = yield self.store.get_app_service_by_user_id(
+            user.to_string()
         )
+        if app_service:
+            rooms = yield self.store.get_app_service_rooms(app_service)
+        else:
+            rooms = yield self.store.get_rooms_for_user_where_membership_is(
+                user_id=user.to_string(), membership_list=membership_list
+            )

         # For some reason the list of events contains duplicates
         # TODO(paul): work out why because I really don't think it should

@@ -559,13 +566,24 @@ class RoomEventSource(object):
         to_key = yield self.get_current_key()

-        events, end_key = yield self.store.get_room_events_stream(
-            user_id=user.to_string(),
-            from_key=from_key,
-            to_key=to_key,
-            room_id=None,
-            limit=limit,
+        app_service = yield self.store.get_app_service_by_user_id(
+            user.to_string()
         )
+        if app_service:
+            events, end_key = yield self.store.get_appservice_room_stream(
+                service=app_service,
+                from_key=from_key,
+                to_key=to_key,
+                limit=limit,
+            )
+        else:
+            events, end_key = yield self.store.get_room_events_stream(
+                user_id=user.to_string(),
+                from_key=from_key,
+                to_key=to_key,
+                room_id=None,
+                limit=limit,
+            )

         defer.returnValue((events, end_key))


@@ -36,8 +36,10 @@ class _NotificationListener(object):
     so that it can remove itself from the indexes in the Notifier class.
     """

-    def __init__(self, user, rooms, from_token, limit, timeout, deferred):
+    def __init__(self, user, rooms, from_token, limit, timeout, deferred,
+                 appservice=None):
         self.user = user
+        self.appservice = appservice
         self.from_token = from_token
         self.limit = limit
         self.timeout = timeout

@@ -65,6 +67,10 @@ class _NotificationListener(object):
             lst.discard(self)

         notifier.user_to_listeners.get(self.user, set()).discard(self)
+
+        if self.appservice:
+            notifier.appservice_to_listeners.get(
+                self.appservice, set()
+            ).discard(self)


 class Notifier(object):

@@ -79,6 +85,7 @@ class Notifier(object):
         self.rooms_to_listeners = {}
         self.user_to_listeners = {}
+        self.appservice_to_listeners = {}

         self.event_sources = hs.get_event_sources()

@@ -114,6 +121,17 @@ class Notifier(object):
         for user in extra_users:
             listeners |= self.user_to_listeners.get(user, set()).copy()

+        for appservice in self.appservice_to_listeners:
+            # TODO (kegan): Redundant appservice listener checks?
+            #   App services will already be in the rooms_to_listeners set, but
+            #   that isn't enough. They need to be checked here in order to
+            #   receive *invites* for users they are interested in. Does this
+            #   make the rooms_to_listeners check somewhat obsolete?
+            if appservice.is_interested(event):
+                listeners |= self.appservice_to_listeners.get(
+                    appservice, set()
+                ).copy()
+
         logger.debug("on_new_room_event listeners %s", listeners)

         # TODO (erikj): Can we make this more efficient by hitting the

@@ -280,6 +298,10 @@ class Notifier(object):
         if not from_token:
             from_token = yield self.event_sources.get_current_token()

+        appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
+            user.to_string()
+        )
+
         listener = _NotificationListener(
             user,
             rooms,

@@ -287,6 +309,7 @@ class Notifier(object):
             limit,
             timeout,
             deferred,
+            appservice=appservice
         )

         def _timeout_listener():

@@ -319,6 +342,11 @@ class Notifier(object):
         self.user_to_listeners.setdefault(listener.user, set()).add(listener)

+        if listener.appservice:
+            self.appservice_to_listeners.setdefault(
+                listener.appservice, set()
+            ).add(listener)
+
     @defer.inlineCallbacks
     @log_function
     def _check_for_updates(self, listener):


@@ -214,7 +214,7 @@ def _rule_spec_from_path(path):
     template = path[0]
     path = path[1:]

-    if len(path) == 0:
+    if len(path) == 0 or len(path[0]) == 0:
         raise UnrecognizedRequestError()

     rule_id = path[0]


@@ -73,6 +73,7 @@ class BaseHomeServer(object):
         'resource_for_client',
         'resource_for_client_v2_alpha',
         'resource_for_federation',
+        'resource_for_static_content',
         'resource_for_web_client',
         'resource_for_content_repo',
         'resource_for_server_key',


@@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache

 from twisted.internet import defer

-import collections
+from collections import namedtuple, OrderedDict
 import simplejson as json
 import sys
 import time

@@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")


+# TODO(paul):
+#  * more generic key management
+#  * export monitoring stats
+#  * consider other eviction strategies - LRU?
+def cached(max_entries=1000):
+    """ A method decorator that applies a memoizing cache around the function.
+
+    The function is presumed to take one additional argument, which is used as
+    the key for the cache. Cache hits are served directly from the cache;
+    misses use the function body to generate the value.
+
+    The wrapped function has an additional member, a callable called
+    "invalidate". This can be used to remove individual entries from the cache.
+
+    The wrapped function has another additional callable, called "prefill",
+    which can be used to insert values into the cache specifically, without
+    calling the calculation function.
+    """
+    def wrap(orig):
+        cache = OrderedDict()
+
+        def prefill(key, value):
+            while len(cache) > max_entries:
+                cache.popitem(last=False)
+
+            cache[key] = value
+
+        @defer.inlineCallbacks
+        def wrapped(self, key):
+            if key in cache:
+                defer.returnValue(cache[key])
+
+            ret = yield orig(self, key)
+            prefill(key, ret)
+            defer.returnValue(ret)
+
+        def invalidate(key):
+            cache.pop(key, None)
+
+        wrapped.invalidate = invalidate
+        wrapped.prefill = prefill
+        return wrapped
+
+    return wrap
+
+
 class LoggingTransaction(object):
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging to the .execute() method."""

@@ -404,7 +450,8 @@ class SQLBaseStore(object):
         Args:
             table : string giving the table name
-            keyvalues : dict of column names and values to select the rows with
+            keyvalues : dict of column names and values to select the rows with,
+                        or None to not apply a WHERE clause.
             retcols : list of strings giving the names of the columns to return
         """
         return self.runInteraction(

@@ -423,13 +470,20 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
-        sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
-        )
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+            )
+            txn.execute(sql, keyvalues.values())
+        else:
+            sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table
+            )
+            txn.execute(sql)

-        txn.execute(sql, keyvalues.values())
         return self.cursor_to_dict(txn)

     def _simple_update_one(self, table, keyvalues, updatevalues,

@@ -586,8 +640,9 @@ class SQLBaseStore(object):
         start_time = time.time() * 1000
         update_counter = self._get_event_counters.update

+        cache = self._get_event_cache.setdefault(event_id, {})
+
         try:
-            cache = self._get_event_cache.setdefault(event_id, {})
             # Separate cache entries for each way to invoke _get_event_txn
             return cache[(check_redacted, get_prev_content, allow_rejected)]
         except KeyError:

@@ -786,7 +841,7 @@ class JoinHelper(object):
         for table in self.tables:
             res += [f for f in table.fields if f not in res]

-        self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+        self.EntryType = namedtuple("JoinHelperEntry", res)

     def get_fields(self, **prefixes):
         """Get a string representing a list of fields for use in SELECT


@@ -15,31 +15,21 @@
 import logging
 from twisted.internet import defer

+from synapse.api.constants import Membership
 from synapse.api.errors import StoreError
 from synapse.appservice import ApplicationService
+from synapse.storage.roommember import RoomsForUser
 from ._base import SQLBaseStore

 logger = logging.getLogger(__name__)


-class ApplicationServiceCache(object):
-    """Caches ApplicationServices and provides utility functions on top.
-
-    This class is designed to be invoked on incoming events in order to avoid
-    hammering the database every time to extract a list of application service
-    regexes.
-    """
-    def __init__(self):
-        self.services = []
-
-
 class ApplicationServiceStore(SQLBaseStore):

     def __init__(self, hs):
         super(ApplicationServiceStore, self).__init__(hs)
-        self.cache = ApplicationServiceCache()
+        self.services_cache = []
         self.cache_defer = self._populate_cache()

     @defer.inlineCallbacks

@@ -56,7 +46,7 @@ class ApplicationServiceStore(SQLBaseStore):
             token,
         )
         # update cache TODO: Should this be in the txn?
-        for service in self.cache.services:
+        for service in self.services_cache:
             if service.token == token:
                 service.url = None
                 service.namespaces = None

@@ -110,13 +100,13 @@ class ApplicationServiceStore(SQLBaseStore):
         )

         # update cache TODO: Should this be in the txn?
-        for (index, cache_service) in enumerate(self.cache.services):
+        for (index, cache_service) in enumerate(self.services_cache):
             if service.token == cache_service.token:
-                self.cache.services[index] = service
+                self.services_cache[index] = service
                 logger.info("Updated: %s", service)
                 return
         # new entry
-        self.cache.services.append(service)
+        self.services_cache.append(service)
         logger.info("Updated(new): %s", service)

     def _update_app_service_txn(self, txn, service):

@@ -160,11 +150,34 @@ class ApplicationServiceStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_app_services(self):
         yield self.cache_defer  # make sure the cache is ready
-        defer.returnValue(self.cache.services)
+        defer.returnValue(self.services_cache)
+
+    @defer.inlineCallbacks
+    def get_app_service_by_user_id(self, user_id):
+        """Retrieve an application service from their user ID.
+
+        All application services have associated with them a particular user ID.
+        There is no distinguishing feature on the user ID which indicates it
+        represents an application service. This function allows you to map from
+        a user ID to an application service.
+
+        Args:
+            user_id(str): The user ID to see if it is an application service.
+        Returns:
+            synapse.appservice.ApplicationService or None.
+        """
+        yield self.cache_defer  # make sure the cache is ready
+        for service in self.services_cache:
+            if service.sender == user_id:
+                defer.returnValue(service)
+                return
+        defer.returnValue(None)

     @defer.inlineCallbacks
     def get_app_service_by_token(self, token, from_cache=True):
-        """Get the application service with the given token.
+        """Get the application service with the given appservice token.

         Args:
             token (str): The application service token.

@@ -176,7 +189,7 @@ class ApplicationServiceStore(SQLBaseStore):
         yield self.cache_defer  # make sure the cache is ready

         if from_cache:
-            for service in self.cache.services:
+            for service in self.services_cache:
                 if service.token == token:
                     defer.returnValue(service)
                     return

@@ -185,6 +198,77 @@ class ApplicationServiceStore(SQLBaseStore):
         # TODO: The from_cache=False impl
         # TODO: This should be JOINed with the application_services_regex table.

+    def get_app_service_rooms(self, service):
+        """Get a list of RoomsForUser for this application service.
+
+        Application services may be "interested" in lots of rooms depending on
+        the room ID, the room aliases, or the members in the room. This function
+        takes all of these into account and returns a list of RoomsForUser which
+        represent the entire list of room IDs that this application service
+        wants to know about.
+
+        Args:
+            service: The application service to get a room list for.
+        Returns:
+            A list of RoomsForUser.
+        """
+        return self.runInteraction(
+            "get_app_service_rooms",
+            self._get_app_service_rooms_txn,
+            service,
+        )
+
+    def _get_app_service_rooms_txn(self, txn, service):
+        # get all rooms matching the room ID regex.
+        room_entries = self._simple_select_list_txn(
+            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
+        )
+        matching_room_list = set([
+            r["room_id"] for r in room_entries if
+            service.is_interested_in_room(r["room_id"])
+        ])
+
+        # resolve room IDs for matching room alias regex.
+        room_alias_mappings = self._simple_select_list_txn(
+            txn=txn, table="room_aliases", keyvalues=None,
+            retcols=["room_id", "room_alias"]
+        )
+        matching_room_list |= set([
+            r["room_id"] for r in room_alias_mappings if
+            service.is_interested_in_alias(r["room_alias"])
+        ])
+
+        # get all rooms for every user for this AS. This is scoped to users on
+        # this HS only.
+        user_list = self._simple_select_list_txn(
+            txn=txn, table="users", keyvalues=None, retcols=["name"]
+        )
+        user_list = [
+            u["name"] for u in user_list if
+            service.is_interested_in_user(u["name"])
+        ]
+        rooms_for_user_matching_user_id = set()  # RoomsForUser list
+        for user_id in user_list:
+            # FIXME: This assumes this store is linked with RoomMemberStore :(
+            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
+                txn=txn,
+                user_id=user_id,
+                membership_list=[Membership.JOIN]
+            )
+            rooms_for_user_matching_user_id |= set(rooms_for_user)
+
+        # make RoomsForUser tuples for room ids and aliases which are not in the
+        # main rooms_for_user_list - e.g. they are rooms which do not have AS
+        # registered users in it.
+        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
+        missing_rooms_for_user = [
+            RoomsForUser(r, service.sender, "join") for r in
+            matching_room_list if r not in known_room_ids
+        ]
+        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
+
+        return rooms_for_user_matching_user_id
+
     @defer.inlineCallbacks
     def _populate_cache(self):
         """Populates the ApplicationServiceCache from the database."""

@@ -235,7 +319,7 @@ class ApplicationServiceStore(SQLBaseStore):
             # TODO get last successful txn id f.e. service
             for service in services.values():
                 logger.info("Found application service: %s", service)
-                self.cache.services.append(ApplicationService(
+                self.services_cache.append(ApplicationService(
                     token=service["token"],
                     url=service["url"],
                     namespaces=service["namespaces"],


@@ -17,7 +17,7 @@ from twisted.internet import defer

 from collections import namedtuple

-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 from synapse.api.constants import Membership
 from synapse.types import UserID

@@ -35,11 +35,6 @@ RoomsForUser = namedtuple(

 class RoomMemberStore(SQLBaseStore):

-    def __init__(self, *args, **kw):
-        super(RoomMemberStore, self).__init__(*args, **kw)
-
-        self._user_rooms_cache = {}
-
     def _store_room_member_txn(self, txn, event):
         """Store a room member in the database.
         """

@@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):

                 txn.execute(sql, (event.room_id, domain))

-        self.invalidate_rooms_for_user(target_user_id)
+        self.get_rooms_for_user.invalidate(target_user_id)

     @defer.inlineCallbacks
     def get_room_member(self, user_id, room_id):

@@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
         if not membership_list:
             return defer.succeed(None)

+        return self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            self._get_rooms_for_user_where_membership_is_txn,
+            user_id, membership_list
+        )
+
+    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+                                                    membership_list):
         where_clause = "user_id = ? AND (%s)" % (
             " OR ".join(["membership = ?" for _ in membership_list]),
         )

@@ -192,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
         args = [user_id]
         args.extend(membership_list)

-        def f(txn):
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership"
-                " FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id"
-                " WHERE %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
-            return [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+        sql = (
+            "SELECT m.room_id, m.sender, m.membership"
+            " FROM room_memberships as m"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE %s"
+        ) % (where_clause,)

-        return self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            f
-        )
+        txn.execute(sql, args)
+        return [
+            RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+        ]

     def get_joined_hosts_for_room(self, room_id):
         return self._simple_select_onecol(

@@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
             results = self._parse_events_txn(txn, rows)
         return results

-    # TODO(paul): Create a nice @cached decorator to do this
-    #    @cached
-    #    def get_foo(...)
-    #        ...
-    #    invalidate_foo = get_foo.invalidator
-
-    @defer.inlineCallbacks
+    @cached()
     def get_rooms_for_user(self, user_id):
-        # TODO(paul): put some performance counters in here so we can easily
-        #   track what impact this cache is having
-        if user_id in self._user_rooms_cache:
-            defer.returnValue(self._user_rooms_cache[user_id])
-
-        rooms = yield self.get_rooms_for_user_where_membership_is(
+        return self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
-
-        # TODO(paul): Consider applying a maximum size; just evict things at
-        #   random, or consider LRU?
-
-        self._user_rooms_cache[user_id] = rooms
-        defer.returnValue(rooms)
-
-    def invalidate_rooms_for_user(self, user_id):
-        if user_id in self._user_rooms_cache:
-            del self._user_rooms_cache[user_id]

     @defer.inlineCallbacks
     def user_rooms_intersect(self, user_id_list):
         """ Checks whether all the users whose IDs are given in a list share a


@@ -36,6 +36,7 @@ what sort order was used:
 from twisted.internet import defer

 from ._base import SQLBaseStore
+from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.util.logutils import log_function

@@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):

 class StreamStore(SQLBaseStore):

+    @defer.inlineCallbacks
+    def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
+        # NB this lives here instead of appservice.py so we can reuse the
+        # 'private' StreamToken class in this file.
+        if limit:
+            limit = min(limit, MAX_STREAM_SIZE)
+        else:
+            limit = MAX_STREAM_SIZE
+
+        # From and to keys should be integers from ordering.
+        from_id = _StreamToken.parse_stream_token(from_key)
+        to_id = _StreamToken.parse_stream_token(to_key)
+
+        if from_key == to_key:
+            defer.returnValue(([], to_key))
+            return
+
+        # select all the events between from/to with a sensible limit
+        sql = (
+            "SELECT e.event_id, e.room_id, e.type, s.state_key, "
+            "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
+            "e.event_id = s.event_id "
+            "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
+            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
+        ) % {
+            "limit": limit
+        }
+
+        def f(txn):
+            # pull out all the events between the tokens
+            txn.execute(sql, (from_id.stream, to_id.stream,))
+            rows = self.cursor_to_dict(txn)
+
+            # Logic:
+            #  - We want ALL events which match the AS room_id regex
+            #  - We want ALL events which match the rooms represented by the
+            #    AS room_alias regex
+            #  - We want ALL events for rooms that AS users have joined.
+            # This is currently supported via get_app_service_rooms (which is
+            # used for the Notifier listener rooms). We can't reasonably make a
+            # SQL query for these room IDs, so we'll pull all the events between
+            # from/to and filter in python.
+            rooms_for_as = self._get_app_service_rooms_txn(txn, service)
+            room_ids_for_as = [r.room_id for r in rooms_for_as]
+
+            def app_service_interested(row):
+                if row["room_id"] in room_ids_for_as:
+                    return True
+
+                if row["type"] == EventTypes.Member:
+                    if service.is_interested_in_user(row.get("state_key")):
+                        return True
+                return False
+
+            ret = self._get_events_txn(
+                txn,
+                # apply the filter on the room id list
+                [
+                    r["event_id"] for r in rows
+                    if app_service_interested(r)
+                ],
+                get_prev_content=True
+            )
+
+            self._set_before_and_after(ret, rows)
+
+            if rows:
+                key = "s%d" % max(r["stream_ordering"] for r in rows)
+            else:
+                # Assume we didn't get anything because there was nothing to
+                # get.
+                key = to_key
+
+            return ret, key
+
+        results = yield self.runInteraction("get_appservice_room_stream", f)
+        defer.returnValue(results)
+
     @log_function
     def get_room_events_stream(self, user_id, from_key, to_key, room_id,
                                limit=0, with_feedback=False):

@@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
         self._set_before_and_after(ret, rows)

         if rows:
-            key = "s%d" % max([r["stream_ordering"] for r in rows])
+            key = "s%d" % max(r["stream_ordering"] for r in rows)
         else:
             # Assume we didn't get anything because there was nothing to
             # get.

@@ -13,12 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore, Table, cached

 from collections import namedtuple

-from twisted.internet import defer
-
 import logging

 logger = logging.getLogger(__name__)

@@ -28,10 +26,6 @@ class TransactionStore(SQLBaseStore):
     """A collection of queries for handling PDUs.
     """

-    # a write-through cache of DestinationsTable.EntryType indexed by
-    # destination string
-    destination_retry_cache = {}
-
     def get_received_txn_response(self, transaction_id, origin):
         """For an incoming transaction from a given origin, check if we have
         already responded to it. If so, return the response code and response

@@ -211,6 +205,7 @@ class TransactionStore(SQLBaseStore):

         return ReceivedTransactionsTable.decode_results(txn.fetchall())

+    @cached()
     def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.

@@ -221,9 +216,6 @@ class TransactionStore(SQLBaseStore):
             None if not retrying
             Otherwise a DestinationsTable.EntryType for the retry scheme
         """
-        if destination in self.destination_retry_cache:
-            return defer.succeed(self.destination_retry_cache[destination])
-
         return self.runInteraction(
             "get_destination_retry_timings",
             self._get_destination_retry_timings, destination)

@@ -250,7 +242,9 @@ class TransactionStore(SQLBaseStore):
             retry_interval (int) - how long until next retry in ms
         """

-        self.destination_retry_cache[destination] = (
+        # As this is the new value, we might as well prefill the cache
+        self.get_destination_retry_timings.prefill(
+            destination,
             DestinationsTable.EntryType(
                 destination,
                 retry_last_ts,


@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
import collections
import contextlib
import logging
logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
"""
Args:
clock (Clock)
window_size (int): The window size in milliseconds.
sleep_limit (int): The number of requests received in the last
`window_size` milliseconds before we artificially start
delaying processing of requests.
sleep_msec (int): The number of milliseconds to delay processing
of incoming requests by.
reject_limit (int): The maximum number of requests that are can be
queued for processing before we start rejecting requests with
a 429 Too Many Requests response.
concurrent_requests (int): The number of concurrent requests to
process.
"""
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.ratelimiters = {}
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
Example usage:
with rate_limiter.ratelimit(origin) as wait_deferred:
yield wait_deferred
# Handle request ...
Args:
host (str): Origin of incoming request.
Returns:
_PerHostRatelimiter
"""
return self.ratelimiters.setdefault(
host,
_PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
)
).ratelimit()
class _PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.sleeping_requests = set()
self.ready_request_queue = collections.OrderedDict()
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
request_id = object()
ret = self._on_enter(request_id)
try:
yield ret
finally:
self._on_exit(request_id)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
)
self.request_times.append(time_now)
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
return queue_defer
else:
return defer.succeed(None)
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
id(request_id), len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
logger.debug(
"Ratelimit [%s]: sleeping req",
id(request_id),
)
ret_defer = sleep(self.sleep_msec/1000.0)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug(
"Ratelimit [%s]: Finished sleeping",
id(request_id),
)
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
ret_defer.addBoth(on_wait_finished)
else:
ret_defer = queue_request()
def on_start(r):
logger.debug(
"Ratelimit [%s]: Processing req",
id(request_id),
)
self.current_processing.add(request_id)
return r
def on_err(r):
self.current_processing.discard(request_id)
return r
def on_both(r):
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
return r
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
return ret_defer
def _on_exit(self, request_id):
logger.debug(
"Ratelimit [%s]: Processed req",
id(request_id),
)
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
self.current_processing.add(request_id)
deferred.callback(None)
except KeyError:
pass
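To see how the window_size, sleep_limit and reject_limit thresholds above interact, here is a self-contained sliding-window sketch in plain Python; the class name and its "accept"/"sleep"/"reject" return values are hypothetical, not synapse API:

    # Standalone illustration of the sliding-window policy used above.
    import time

    class WindowPolicy(object):
        def __init__(self, window_size_ms=10000, sleep_limit=10,
                     reject_limit=50):
            self.window_size_ms = window_size_ms
            self.sleep_limit = sleep_limit
            self.reject_limit = reject_limit
            self.request_times = []  # ms timestamps of recent requests
            self.queued = 0          # stand-in for queued + sleeping requests

        def decide(self, now_ms):
            # Drop timestamps that have fallen outside the window.
            self.request_times = [
                t for t in self.request_times
                if now_ms - t < self.window_size_ms
            ]
            if self.queued > self.reject_limit:
                return "reject"   # maps to LimitExceededError (429) above
            self.request_times.append(now_ms)
            if len(self.request_times) > self.sleep_limit:
                return "sleep"    # maps to the sleep_msec delay above
            return "accept"

    policy = WindowPolicy()
    now = int(time.time() * 1000)
    for i in range(12):
        # the 11th and 12th calls exceed sleep_limit and return "sleep"
        print("%d %s" % (i, policy.decide(now + i)))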


@@ -295,6 +295,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
         self.mock_datastore = hs.get_datastore()
         self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
+        self.mock_datastore.get_app_service_by_user_id = Mock(
+            return_value=defer.succeed(None)
+        )

         def get_profile_displayname(user_id):
             return defer.succeed("Frank")

tests/storage/test__base.py

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from tests import unittest
from twisted.internet import defer

from synapse.storage._base import cached


class CacheDecoratorTestCase(unittest.TestCase):

    @defer.inlineCallbacks
    def test_passthrough(self):
        @cached()
        def func(self, key):
            return key

        self.assertEquals((yield func(self, "foo")), "foo")
        self.assertEquals((yield func(self, "bar")), "bar")

    @defer.inlineCallbacks
    def test_hit(self):
        callcount = [0]

        @cached()
        def func(self, key):
            callcount[0] += 1
            return key

        yield func(self, "foo")

        self.assertEquals(callcount[0], 1)

        self.assertEquals((yield func(self, "foo")), "foo")
        self.assertEquals(callcount[0], 1)

    @defer.inlineCallbacks
    def test_invalidate(self):
        callcount = [0]

        @cached()
        def func(self, key):
            callcount[0] += 1
            return key

        yield func(self, "foo")

        self.assertEquals(callcount[0], 1)

        func.invalidate("foo")

        yield func(self, "foo")

        self.assertEquals(callcount[0], 2)

    def test_invalidate_missing(self):
        @cached()
        def func(self, key):
            return key

        func.invalidate("what")

    @defer.inlineCallbacks
    def test_max_entries(self):
        callcount = [0]

        @cached(max_entries=10)
        def func(self, key):
            callcount[0] += 1
            return key

        for k in range(0, 12):
            yield func(self, k)

        self.assertEquals(callcount[0], 12)

        # There must have been at least 2 evictions, meaning if we calculate
        # all 12 values again, we must get called at least 2 more times
        for k in range(0, 12):
            yield func(self, k)

        self.assertTrue(callcount[0] >= 14,
                        msg="Expected callcount >= 14, got %d" % (callcount[0]))

    @defer.inlineCallbacks
    def test_prefill(self):
        callcount = [0]

        @cached()
        def func(self, key):
            callcount[0] += 1
            return key

        func.prefill("foo", 123)

        self.assertEquals((yield func(self, "foo")), 123)
        self.assertEquals(callcount[0], 0)