Merge branch 'release-v0.8.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2015-03-09 14:25:06 +00:00
commit d5174065af
91 changed files with 2906 additions and 922 deletions

1
.gitignore vendored
View File

@ -41,3 +41,4 @@ media_store/
build/ build/
localhost-800*/ localhost-800*/
static/client/register/register_config.js

View File

@ -1,3 +1,40 @@
Changes in synapse v0.8.0 (2015-03-06)
======================================
General:
* Add support for registration fallback. This is a page hosted on the server
which allows a user to register for an account, regardless of what client
they are using (e.g. mobile devices).
* Added new default push rules and made them configurable by clients:
* Suppress all notice messages.
* Notify when invited to a new room.
* Notify for messages that don't match any rule.
* Notify on incoming call.
Federation:
* Added per host server side rate-limiting of incoming federation requests.
* Added a ``/get_missing_events/`` API to federation to reduce number of
``/events/`` requests.
Configuration:
* Added configuration option to disable registration:
``disable_registration``.
* Added configuration option to change soft limit of number of open file
descriptors: ``soft_file_limit``.
* Make ``tls_private_key_path`` optional when running with ``no_tls``.
Application services:
* Application services can now poll on the CS API ``/events`` for their events,
by providing their application service ``access_token``.
* Added exclusive namespace support to application services API.
Changes in synapse v0.7.1 (2015-02-19) Changes in synapse v0.7.1 (2015-02-19)
====================================== ======================================

View File

@ -1,3 +1,18 @@
Upgrading to v0.8.0
===================
Servers which use captchas will need to add their public key to::
static/client/register/register_config.js
window.matrixRegistrationConfig = {
recaptcha_public_key: "YOUR_PUBLIC_KEY"
};
This is required in order to support registration fallback (typically used on
mobile devices).
Upgrading to v0.7.0 Upgrading to v0.7.0
=================== ===================

489
contrib/vertobot/bridge.pl Executable file
View File

@ -0,0 +1,489 @@
#!/usr/bin/env perl
use strict;
use warnings;
use 5.010; # //
use IO::Socket::SSL qw(SSL_VERIFY_NONE);
use IO::Async::Loop;
use Net::Async::WebSocket::Client;
use Net::Async::HTTP;
use Net::Async::HTTP::Server;
use JSON;
use YAML;
use Data::UUID;
use Getopt::Long;
use Data::Dumper;
use URI::Encode qw(uri_encode uri_decode);
binmode STDOUT, ":encoding(UTF-8)";
binmode STDERR, ":encoding(UTF-8)";
my $msisdn_to_matrix = {
'447417892400' => '@matthew:matrix.org',
};
my $matrix_to_msisdn = {};
foreach (keys %$msisdn_to_matrix) {
$matrix_to_msisdn->{$msisdn_to_matrix->{$_}} = $_;
}
my $loop = IO::Async::Loop->new;
# Net::Async::HTTP + SSL + IO::Poll doesn't play well. See
# https://rt.cpan.org/Ticket/Display.html?id=93107
# ref $loop eq "IO::Async::Loop::Poll" and
# warn "Using SSL with IO::Poll causes known memory-leaks!!\n";
GetOptions(
'C|config=s' => \my $CONFIG,
'eval-from=s' => \my $EVAL_FROM,
) or exit 1;
if( defined $EVAL_FROM ) {
# An emergency 'eval() this file' hack
$SIG{HUP} = sub {
my $code = do {
open my $fh, "<", $EVAL_FROM or warn( "Cannot read - $!" ), return;
local $/; <$fh>
};
eval $code or warn "Cannot eval() - $@";
};
}
defined $CONFIG or die "Must supply --config\n";
my %CONFIG = %{ YAML::LoadFile( $CONFIG ) };
my %MATRIX_CONFIG = %{ $CONFIG{matrix} };
# No harm in always applying this
$MATRIX_CONFIG{SSL_verify_mode} = SSL_VERIFY_NONE;
my $bridgestate = {};
my $roomid_by_callid = {};
my $sessid = lc new Data::UUID->create_str();
my $as_token = $CONFIG{"matrix-bot"}->{as_token};
my $hs_domain = $CONFIG{"matrix-bot"}->{domain};
my $http = Net::Async::HTTP->new();
$loop->add( $http );
sub create_virtual_user
{
my ($localpart) = @_;
my ( $response ) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/register?".
"access_token=$as_token&user_id=$localpart"
),
content_type => "application/json",
content => <<EOT
{
"type": "m.login.application_service",
"user": "$localpart"
}
EOT
)->get;
warn $response->as_string if ($response->code != 200);
}
my $http_server = Net::Async::HTTP::Server->new(
on_request => sub {
my $self = shift;
my ( $req ) = @_;
my $response;
my $path = uri_decode($req->path);
warn("request: $path");
if ($path =~ m#/users/\@(\+.*)#) {
# when queried about virtual users, auto-create them in the HS
my $localpart = $1;
create_virtual_user($localpart);
$response = HTTP::Response->new( 200 );
$response->add_content('{}');
$response->content_type( "application/json" );
}
elsif ($path =~ m#/transactions/(.*)#) {
my $event = JSON->new->decode($req->body);
print Dumper($event);
my $room_id = $event->{room_id};
my %dp = %{$CONFIG{'verto-dialog-params'}};
$dp{callID} = $bridgestate->{$room_id}->{callid};
if ($event->{type} eq 'm.room.membership') {
my $membership = $event->{content}->{membership};
my $state_key = $event->{state_key};
my $room_id = $event->{state_id};
if ($membership eq 'invite') {
# autojoin invites
my ( $response ) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/rooms/$room_id/join?".
"access_token=$as_token&user_id=$state_key"
),
content_type => "application/json",
content => "{}",
)->get;
warn $response->as_string if ($response->code != 200);
}
}
elsif ($event->{type} eq 'm.call.invite') {
my $room_id = $event->{room_id};
$bridgestate->{$room_id}->{matrix_callid} = $event->{content}->{call_id};
$bridgestate->{$room_id}->{callid} = lc new Data::UUID->create_str();
$bridgestate->{$room_id}->{sessid} = $sessid;
# $bridgestate->{$room_id}->{offer} = $event->{content}->{offer}->{sdp};
my $offer = $event->{content}->{offer}->{sdp};
# $bridgestate->{$room_id}->{gathered_candidates} = 0;
$roomid_by_callid->{ $bridgestate->{$room_id}->{callid} } = $room_id;
# no trickle ICE in verto apparently
my $f = send_verto_json_request("verto.invite", {
"sdp" => $offer,
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
# elsif ($event->{type} eq 'm.call.candidates') {
# # XXX: this could fire for both matrix->verto and verto->matrix calls
# # and races as it collects candidates. much better to just turn off
# # candidate gathering in the webclient entirely for now
#
# my $room_id = $event->{room_id};
# # XXX: compare call IDs
# if (!$bridgestate->{$room_id}->{gathered_candidates}) {
# $bridgestate->{$room_id}->{gathered_candidates} = 1;
# my $offer = $bridgestate->{$room_id}->{offer};
# my $candidate_block = "";
# foreach (@{$event->{content}->{candidates}}) {
# $candidate_block .= "a=" . $_->{candidate} . "\r\n";
# }
# # XXX: collate using the right m= line - for now assume audio call
# $offer =~ s/(a=rtcp.*[\r\n]+)/$1$candidate_block/;
#
# my $f = send_verto_json_request("verto.invite", {
# "sdp" => $offer,
# "dialogParams" => \%dp,
# "sessid" => $bridgestate->{$room_id}->{sessid},
# });
# $self->adopt_future($f);
# }
# else {
# # ignore them, as no trickle ICE, although we might as well
# # batch them up
# # foreach (@{$event->{content}->{candidates}}) {
# # push @{$bridgestate->{$room_id}->{candidates}}, $_;
# # }
# }
# }
elsif ($event->{type} eq 'm.call.answer') {
# grab the answer and relay it to verto as a verto.answer
my $room_id = $event->{room_id};
my $answer = $event->{content}->{answer}->{sdp};
my $f = send_verto_json_request("verto.answer", {
"sdp" => $answer,
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
elsif ($event->{type} eq 'm.call.hangup') {
my $room_id = $event->{room_id};
if ($bridgestate->{$room_id}->{matrix_callid} eq $event->{content}->{call_id}) {
my $f = send_verto_json_request("verto.bye", {
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
else {
warn "Ignoring unrecognised callid: ".$event->{content}->{call_id};
}
}
else {
warn "Unhandled event: $event->{type}";
}
$response = HTTP::Response->new( 200 );
$response->add_content('{}');
$response->content_type( "application/json" );
}
else {
warn "Unhandled path: $path";
$response = HTTP::Response->new( 404 );
}
$req->respond( $response );
},
);
$loop->add( $http_server );
$http_server->listen(
addr => { family => "inet", socktype => "stream", port => 8009 },
on_listen_error => sub { die "Cannot listen - $_[-1]\n" },
);
my $bot_verto = Net::Async::WebSocket::Client->new(
on_frame => sub {
my ( $self, $frame ) = @_;
warn "[Verto] receiving $frame";
on_verto_json($frame);
},
);
$loop->add( $bot_verto );
my $verto_connecting = $loop->new_future;
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connected => sub {
warn("[Verto] connected to websocket");
if (not $verto_connecting->is_done) {
$verto_connecting->done($bot_verto);
send_verto_json_request("login", {
'login' => $CONFIG{'verto-dialog-params'}{'login'},
'passwd' => $CONFIG{'verto-config'}{'passwd'},
'sessid' => $sessid,
});
}
},
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
);
# die Dumper($verto_connecting);
my $as_url = $CONFIG{"matrix-bot"}->{as_url};
Future->needs_all(
$http->do_request(
method => "POST",
uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
content_type => "application/json",
content => <<EOT
{
"as_token": "$as_token",
"url": "$as_url",
"namespaces": { "users": ["\@\\\\+.*"] }
}
EOT
),
$verto_connecting,
)->get;
$loop->attach_signal(
PIPE => sub { warn "pipe\n" }
);
$loop->attach_signal(
INT => sub { $loop->stop },
);
$loop->attach_signal(
TERM => sub { $loop->stop },
);
eval {
$loop->run;
} or my $e = $@;
die $e if $e;
exit 0;
{
my $json_id;
my $requests;
sub send_verto_json_request
{
$json_id ||= 1;
my ($method, $params) = @_;
my $json = {
jsonrpc => "2.0",
method => $method,
params => $params,
id => $json_id,
};
my $text = JSON->new->encode( $json );
warn "[Verto] sending $text";
$bot_verto->send_frame ( $text );
my $request = $loop->new_future;
$requests->{$json_id} = $request;
$json_id++;
return $request;
}
sub send_verto_json_response
{
my ($result, $id) = @_;
my $json = {
jsonrpc => "2.0",
result => $result,
id => $id,
};
my $text = JSON->new->encode( $json );
warn "[Verto] sending $text";
$bot_verto->send_frame ( $text );
}
sub on_verto_json
{
my $json = JSON->new->decode( $_[0] );
if ($json->{method}) {
if (($json->{method} eq 'verto.answer' && $json->{params}->{sdp}) ||
$json->{method} eq 'verto.media') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $room_id = $roomid_by_callid->{$json->{params}->{callID}};
if ($json->{params}->{sdp}) {
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.answer?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
answer => {
sdp => $json->{params}->{sdp},
type => "answer",
},
}),
)->then( sub {
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
}
elsif ($json->{method} eq 'verto.invite') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $alias = ($caller lt $callee) ? ($caller.'-'.$callee) : ($callee.'-'.$caller);
my $room_id;
# create a virtual user for the caller if needed.
create_virtual_user($caller);
# create a room of form #peer-peer and invite the callee
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/createRoom?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
room_alias_name => $alias,
invite => [ $callee_user ],
}),
)->then( sub {
my ( $response ) = @_;
my $resp = JSON->new->decode($response->content);
$room_id = $resp->{room_id};
$roomid_by_callid->{$json->{params}->{callID}} = $room_id;
})->get;
# join it
my ($response) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/join/$room_id?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => '{}',
)->get;
$bridgestate->{$room_id}->{matrix_callid} = lc new Data::UUID->create_str();
$bridgestate->{$room_id}->{callid} = $json->{dialogParams}->{callID};
$bridgestate->{$room_id}->{sessid} = $sessid;
# put the m.call.invite in there
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.invite?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
answer => {
sdp => $json->{params}->{sdp},
type => "offer",
},
}),
)->then( sub {
# acknowledge the verto
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
elsif ($json->{method} eq 'verto.bye') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $room_id = $roomid_by_callid->{$json->{params}->{callID}};
# put the m.call.hangup into the room
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.hangup?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
}),
)->then( sub {
# acknowledge the verto
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
else {
warn ("[Verto] unhandled method: " . $json->{method});
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
}
}
elsif ($json->{result}) {
$requests->{$json->{id}}->done($json->{result});
}
elsif ($json->{error}) {
$requests->{$json->{id}}->fail($json->{error}->{message}, $json->{error});
}
}
}

View File

@ -81,7 +81,7 @@ Your home server configuration file needs the following extra keys:
As an example, here is the relevant section of the config file for As an example, here is the relevant section of the config file for
matrix.org:: matrix.org::
turn_uris: turn:turn.matrix.org:3478?transport=udp,turn:turn.matrix.org:3478?transport=tcp turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
turn_user_lifetime: 86400000 turn_user_lifetime: 86400000

View File

@ -0,0 +1,32 @@
<html>
<head>
<title> Registration </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/recaptcha_ajax.js"></script>
<script src="register_config.js"></script>
<script src="js/register.js"></script>
</head>
<body onload="matrixRegistration.onLoad()">
<form id="registrationForm" onsubmit="matrixRegistration.signUp(); return false;">
<div>
Create account:<br/>
<div style="text-align: center">
<input id="desired_user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="pwd1" size="32" type="password" placeholder="Type a password"/>
<br/>
<input id="pwd2" size="32" type="password" placeholder="Confirm your password"/>
<br/>
<span id="feedback" style="color: #f00"></span>
<br/>
<div id="regcaptcha"></div>
<button type="submit" style="margin: 10px">Sign up</button>
</div>
</div>
</form>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,117 @@
window.matrixRegistration = {
endpoint: location.origin + "/_matrix/client/api/v1/register"
};
var setupCaptcha = function() {
if (!window.matrixRegistrationConfig) {
return;
}
$.get(matrixRegistration.endpoint, function(response) {
var serverExpectsCaptcha = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.recaptcha" === flow.type) {
serverExpectsCaptcha = true;
break;
}
}
if (!serverExpectsCaptcha) {
console.log("This server does not require a captcha.");
return;
}
console.log("Setting up ReCaptcha for "+matrixRegistration.endpoint);
var public_key = window.matrixRegistrationConfig.recaptcha_public_key;
if (public_key === undefined) {
console.error("No public key defined for captcha!");
setFeedbackString("Misconfigured captcha for server. Contact server admin.");
return;
}
Recaptcha.create(public_key,
"regcaptcha",
{
theme: "red",
callback: Recaptcha.focus_response_field
});
window.matrixRegistration.isUsingRecaptcha = true;
}).error(errorFunc);
};
var submitCaptcha = function(user, pwd) {
var challengeToken = Recaptcha.get_challenge();
var captchaEntry = Recaptcha.get_response();
var data = {
type: "m.login.recaptcha",
challenge: challengeToken,
response: captchaEntry
};
console.log("Submitting captcha");
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
console.log("Success -> "+JSON.stringify(response));
submitPassword(user, pwd, response.session);
}).error(function(err) {
Recaptcha.reload();
errorFunc(err);
});
};
var submitPassword = function(user, pwd, session) {
console.log("Registering...");
var data = {
type: "m.login.password",
user: user,
password: pwd,
session: session
};
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
matrixRegistration.onRegistered(
response.home_server, response.user_id, response.access_token
);
}).error(errorFunc);
};
var errorFunc = function(err) {
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
}
else {
setFeedbackString("Request failed: " + err.status);
}
};
var setFeedbackString = function(text) {
$("#feedback").text(text);
};
matrixRegistration.onLoad = function() {
setupCaptcha();
};
matrixRegistration.signUp = function() {
var user = $("#desired_user_id").val();
if (user.length == 0) {
setFeedbackString("Must specify a username.");
return;
}
var pwd1 = $("#pwd1").val();
var pwd2 = $("#pwd2").val();
if (pwd1.length < 6) {
setFeedbackString("Password: min. 6 characters.");
return;
}
if (pwd1 != pwd2) {
setFeedbackString("Passwords do not match.");
return;
}
if (window.matrixRegistration.isUsingRecaptcha) {
submitCaptcha(user, pwd1);
}
else {
submitPassword(user, pwd1);
}
};
matrixRegistration.onRegistered = function(hs_url, user_id, access_token) {
// clobber this function
console.log("onRegistered - This function should be replaced to proceed.");
};

View File

@ -0,0 +1,3 @@
window.matrixRegistrationConfig = {
recaptcha_public_key: "YOUR_PUBLIC_KEY"
};

View File

@ -0,0 +1,56 @@
html {
height: 100%;
}
body {
height: 100%;
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
font-size: 12pt;
margin: 0px;
}
h1 {
font-size: 20pt;
}
a:link { color: #666; }
a:visited { color: #666; }
a:hover { color: #000; }
a:active { color: #000; }
input {
width: 100%
}
textarea, input {
font-family: inherit;
font-size: inherit;
}
.smallPrint {
color: #888;
font-size: 9pt ! important;
font-style: italic ! important;
}
#recaptcha_area {
margin: auto
}
#registrationForm {
text-align: left;
padding: 1em;
margin-bottom: 40px;
display: inline-block;
-webkit-border-radius: 10px;
-moz-border-radius: 10px;
border-radius: 10px;
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.7.1-r4" __version__ = "0.8.0"

View File

@ -18,6 +18,7 @@
CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha" CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1" FEDERATION_PREFIX = "/_matrix/federation/v1"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client" WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"

View File

@ -17,7 +17,9 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.storage import prepare_database, UpgradeDatabaseException from synapse.storage import (
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
)
from synapse.server import HomeServer from synapse.server import HomeServer
@ -36,7 +38,8 @@ from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
STATIC_PREFIX
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -52,6 +55,7 @@ import synapse
import logging import logging
import os import os
import re import re
import resource
import subprocess import subprocess
import sqlite3 import sqlite3
import syweb import syweb
@ -81,6 +85,9 @@ class SynapseHomeServer(HomeServer):
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
return File("static")
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(
self, self.upload_dir, self.auth, self.content_addr self, self.upload_dir, self.auth, self.content_addr
@ -124,7 +131,9 @@ class SynapseHomeServer(HomeServer):
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()), (MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()), (APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
] ]
if web_client: if web_client:
logger.info("Adding the web client.") logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX, desired_tree.append((WEB_CLIENT_PREFIX,
@ -140,8 +149,8 @@ class SynapseHomeServer(HomeServer):
# instead, we'll store a copy of this mapping so we can actually add # instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key. # extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {} resource_mappings = {}
for (full_path, resource) in desired_tree: for full_path, res in desired_tree:
logger.info("Attaching %s to path %s", resource, full_path) logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]: for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames(): if path_seg not in last_resource.listNames():
@ -172,12 +181,12 @@ class SynapseHomeServer(HomeServer):
child_name) child_name)
child_resource = resource_mappings[child_res_id] child_resource = resource_mappings[child_res_id]
# steal the children # steal the children
resource.putChild(child_name, child_resource) res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place # finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, resource) last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg) res_id = self._resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = resource resource_mappings[res_id] = res
return self.root_resource return self.root_resource
@ -272,6 +281,20 @@ def get_version_string():
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def change_resource_limit(soft_file_no):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if not soft_file_no:
soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no)
except (ValueError, resource.error) as e:
logger.warn("Failed to set file limit: %s", e)
def setup(): def setup():
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -317,6 +340,7 @@ def setup():
try: try:
with sqlite3.connect(db_name) as db_conn: with sqlite3.connect(db_name) as db_conn:
prepare_sqlite3_database(db_conn)
prepare_database(db_conn) prepare_database(db_conn)
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(
@ -348,10 +372,11 @@ def setup():
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",
pid=config.pid_file, pid=config.pid_file,
action=run, action=lambda: run(config),
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
@ -359,11 +384,13 @@ def setup():
daemon.start() daemon.start()
else: else:
reactor.run() run(config)
def run(): def run(config):
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -46,22 +46,34 @@ class ApplicationService(object):
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
# { # {
# users: ["regex",...], # users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: ["regex",...], # aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: ["regex",...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# } # }
if not namespaces: if not namespaces:
return None return None
for ns in ApplicationService.NS_LIST: for ns in ApplicationService.NS_LIST:
if ns not in namespaces:
namespaces[ns] = []
continue
if type(namespaces[ns]) != list: if type(namespaces[ns]) != list:
raise ValueError("Bad namespace value for '%s'", ns) raise ValueError("Bad namespace value for '%s'" % ns)
for regex in namespaces[ns]: for regex_obj in namespaces[ns]:
if not isinstance(regex, basestring): if not isinstance(regex_obj, dict):
raise ValueError("Expected string regex for ns '%s'", ns) raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
return namespaces return namespaces
def _matches_regex(self, test_string, namespace_key): def _matches_regex(self, test_string, namespace_key, return_obj=False):
if not isinstance(test_string, basestring): if not isinstance(test_string, basestring):
logger.error( logger.error(
"Expected a string to test regex against, but got %s", "Expected a string to test regex against, but got %s",
@ -69,11 +81,19 @@ class ApplicationService(object):
) )
return False return False
for regex in self.namespaces[namespace_key]: for regex_obj in self.namespaces[namespace_key]:
if re.match(regex, test_string): if re.match(regex_obj["regex"], test_string):
if return_obj:
return regex_obj
return True return True
return False return False
def _is_exclusive(self, ns_key, test_string):
regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
if regex_obj:
return regex_obj["exclusive"]
return False
def _matches_user(self, event, member_list): def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)): self.is_interested_in_user(event.sender)):
@ -143,5 +163,14 @@ class ApplicationService(object):
def is_interested_in_room(self, room_id): def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS) return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id):
return self._is_exclusive(ApplicationService.NS_USERS, user_id)
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def __str__(self): def __str__(self):
return "ApplicationService: %s" % (self.__dict__,) return "ApplicationService: %s" % (self.__dict__,)

View File

@ -22,11 +22,12 @@ from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .email import EmailConfig from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig): EmailConfig, VoipConfig, RegistrationConfig,):
pass pass

View File

@ -22,6 +22,12 @@ class RatelimitConfig(Config):
self.rc_messages_per_second = args.rc_messages_per_second self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser) super(RatelimitConfig, cls).add_arguments(parser)
@ -34,3 +40,33 @@ class RatelimitConfig(Config):
"--rc-message-burst-count", type=float, default=10, "--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled" help="number of message a client can send before being throttled"
) )
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--disable-registration",
action='store_true',
help="Disable registration of new users."
)

View File

@ -30,7 +30,7 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file) self.pid_file = self.abspath(args.pid_file)
self.webclient = True self.webclient = True
self.manhole = args.manhole self.manhole = args.manhole
self.no_tls = args.no_tls self.soft_file_limit = args.soft_file_limit
if not args.content_addr: if not args.content_addr:
host = args.server_name host = args.server_name
@ -75,8 +75,12 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None, server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the " help="The host and scheme to use for the "
"content repository") "content repository")
server_group.add_argument("--no-tls", action='store_true', server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Don't bind to the https port.") help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")
def read_signing_key(self, signing_key_path): def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key") signing_keys = self.read_file(signing_key_path, "signing_key")

View File

@ -28,9 +28,16 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
args.tls_certificate_path args.tls_certificate_path
) )
self.no_tls = args.no_tls
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key( self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path args.tls_private_key_path
) )
self.tls_dh_params_path = self.check_file( self.tls_dh_params_path = self.check_file(
args.tls_dh_params_path, "tls_dh_params" args.tls_dh_params_path, "tls_dh_params"
) )
@ -45,6 +52,8 @@ class TlsConfig(Config):
help="PEM encoded private key for TLS") help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path", tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys") help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate") cert_pem = self.read_file(cert_path, "tls_certificate")

View File

@ -28,7 +28,7 @@ class VoipConfig(Config):
super(VoipConfig, cls).add_arguments(parser) super(VoipConfig, cls).add_arguments(parser)
group = parser.add_argument_group("voip") group = parser.add_argument_group("voip")
group.add_argument( group.add_argument(
"--turn-uris", type=str, default=None, "--turn-uris", type=str, default=None, action='append',
help="The public URIs of the TURN server to give to clients" help="The public URIs of the TURN server to give to clients"
) )
group.add_argument( group.add_argument(

View File

@ -38,7 +38,10 @@ class ServerContextFactory(ssl.ContextFactory):
logger.exception("Failed to enable eliptic curve for TLS") logger.exception("Failed to enable eliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(config.tls_certificate) context.use_certificate(config.tls_certificate)
if not config.no_tls:
context.use_privatekey(config.tls_private_key) context.use_privatekey(config.tls_private_key)
context.load_tmp_dh(config.tls_dh_params_path) context.load_tmp_dh(config.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")

View File

@ -50,18 +50,27 @@ class Keyring(object):
) )
try: try:
verify_key = yield self.get_server_verify_key(server_name, key_ids) verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError: except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
502, 502,
"Error downloading keys for %s" % (server_name,), "Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
except: except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
401, 401,
"No key for %s with id %s" % (server_name, key_ids), "No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except: except:

View File

@ -19,14 +19,18 @@ from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Edu from .units import Edu
from synapse.api.errors import CodeMessageException, SynapseError from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util.expiringcache import ExpiringCache from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import itertools
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -439,6 +443,116 @@ class FederationClient(FederationBase):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
Args:
destination (str)
room_id (str)
earliest_events_ids (list): List of event ids. Effectively the
events we expected to receive, but haven't. `get_missing_events`
should only return events that didn't happen before these.
latest_events (list): List of events we have received that we don't
have all previous events for.
limit (int): Maximum number of events to return.
min_depth (int): Minimum depth of events tor return.
"""
try:
content = yield self.transport_layer.get_missing_events(
destination=destination,
room_id=room_id,
earliest_events=earliest_events_ids,
latest_events=[e.event_id for e in latest_events],
limit=limit,
min_depth=min_depth,
)
events = [
self.event_from_pdu_json(e)
for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True
)
have_gotten_all_from_destination = True
except HttpResponseException as e:
if not e.code == 400:
raise
# We are probably hitting an old server that doesn't support
# get_missing_events
signed_events = []
have_gotten_all_from_destination = False
if len(signed_events) >= limit:
defer.returnValue(signed_events)
servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(servers)
servers.discard(self.server_name)
failed_to_fetch = set()
while len(signed_events) < limit:
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
if e.depth > min_depth:
missing_events.update({
e_id: e.depth for e_id, _ in e.prev_events
if e_id not in seen_events
and e_id not in failed_to_fetch
})
if not missing_events:
break
have_seen = yield self.store.have_events(missing_events)
for k in have_seen:
missing_events.pop(k, None)
if not missing_events:
break
# Okay, we haven't gotten everything yet. Lets get them.
ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])
if have_gotten_all_from_destination:
servers.discard(destination)
def random_server_list():
srvs = list(servers)
random.shuffle(srvs)
return srvs
deferreds = [
self.get_pdu(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)
defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False): def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent( event = FrozenEvent(
pdu_json pdu_json

View File

@ -112,17 +112,20 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): with PreserveLoggingContext():
dl = [] results = []
for pdu in pdu_list: for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu) d = self._handle_new_pdu(transaction.origin, pdu)
def handle_failure(failure): try:
failure.trap(FederationError) yield d
self.send_failure(failure.value, transaction.origin) results.append({})
except FederationError as e:
d.addErrback(handle_failure) self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
dl.append(d) except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"): if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]: for edu in [Edu(**x) for x in transaction.edus]:
@ -135,21 +138,11 @@ class FederationServer(FederationBase):
for failure in getattr(transaction, "pdu_failures", []): for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure) logger.info("Got failure %r", failure)
results = yield defer.DeferredList(dl, consumeErrors=True) logger.debug("Returning: %s", str(results))
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1].value)})
logger.debug("Returning: %s", str(ret))
response = { response = {
"pdus": dict(zip( "pdus": dict(zip(
(p.event_id for p in pdu_list), ret (p.event_id for p in pdu_list), results
)), )),
} }
@ -305,6 +298,20 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.
@ -331,7 +338,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_pdu(self, origin, pdu, max_recursion=10): def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu( existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False origin, pdu.event_id, do_auth=False
@ -383,48 +390,54 @@ class FederationServer(FederationBase):
pdu.room_id, min_depth pdu.room_id, min_depth
) )
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if min_depth and pdu.depth < min_depth: if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this # This is so that we don't notify the user about this
# message, to work around the fact that some events will # message, to work around the fact that some events will
# reference really really old events we really don't want to # reference really really old events we really don't want to
# send to the clients. # send to the clients.
pdu.internal_metadata.outlier = True pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth and max_recursion > 0: elif min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events: if get_missing and prevs - seen:
if event_id not in have_seen: latest_tuples = yield self.store.get_latest_events_in_room(
logger.debug( pdu.room_id
"_handle_new_pdu requesting pdu %s",
event_id
) )
try: # We add the prev events that we have seen to the latest
new_pdu = yield self.federation_client.get_pdu( # list to ensure the remote server doesn't give them to us
[origin, pdu.origin], latest = set(e_id for e_id, _, _ in latest_tuples)
event_id=event_id, latest |= seen
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
earliest_events_ids=list(latest),
latest_events=[pdu],
limit=10,
min_depth=min_depth,
) )
if new_pdu: # We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for e in missing_events:
yield self._handle_new_pdu( yield self._handle_new_pdu(
origin, origin,
new_pdu, e,
max_recursion=max_recursion-1 get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
) )
logger.debug("Processed pdu %s", event_id)
else:
logger.warn("Failed to get PDU %s", event_id)
fetch_state = True
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
fetch_state = True
else:
prevs = {e_id for e_id, _ in pdu.prev_events} prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys()) seen = set(have_seen.keys())
if prevs - seen: if prevs - seen:
fetch_state = True fetch_state = True
else:
fetch_state = True
if fetch_state: if fetch_state:
# We need to get the state at this event, since we haven't # We need to get the state at this event, since we haven't

View File

@ -224,6 +224,8 @@ class TransactionQueue(object):
] ]
try: try:
self.pending_transactions[destination] = 1
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self._clock, self._clock,
@ -239,8 +241,6 @@ class TransactionQueue(object):
len(pending_failures) len(pending_failures)
) )
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new( transaction = Transaction.create_new(
@ -287,7 +287,7 @@ class TransactionQueue(object):
code = 200 code = 200
if response: if response:
for e_id, r in getattr(response, "pdus", {}).items(): for e_id, r in response.get("pdus", {}).items():
if "error" in r: if "error" in r:
logger.warn( logger.warn(
"Transaction returned error for %s: %s", "Transaction returned error for %s: %s",

View File

@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
from .server import TransportLayerServer from .server import TransportLayerServer
from .client import TransportLayerClient from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient): class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates """This is a basic implementation of the transport layer that translates
@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests send requests
""" """
self.keyring = homeserver.get_keyring() self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name self.server_name = server_name
self.server = server self.server = server
self.client = client self.client = client
self.request_handler = None self.request_handler = None
self.received_handler = None self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -219,3 +219,22 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
content = yield self.client.post_json(
destination=destination,
path=path,
data={
"limit": int(limit),
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
}
)
defer.returnValue(content)

View File

@ -98,6 +98,8 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs): def new_handler(request, *args, **kwargs):
try: try:
(origin, content) = yield self._authenticate_request(request) (origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler( response = yield handler(
origin, content, request.args, *args, **kwargs origin, content, request.args, *args, **kwargs
) )
@ -107,6 +109,12 @@ class TransportLayerServer(object):
defer.returnValue(response) defer.returnValue(response)
return new_handler return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.
@ -234,6 +242,7 @@ class TransportLayerServer(object):
) )
) )
) )
self.server.register_path( self.server.register_path(
"POST", "POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
@ -245,6 +254,17 @@ class TransportLayerServer(object):
) )
) )
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
self._with_authentication(
lambda origin, content, query, room_id:
self._get_missing_events(
origin, content, room_id,
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _on_send_request(self, origin, content, query, transaction_id): def _on_send_request(self, origin, content, query, transaction_id):
@ -344,3 +364,22 @@ class TransportLayerServer(object):
) )
defer.returnValue((200, new_content)) defer.returnValue((200, new_content))
@defer.inlineCallbacks
@log_function
def _get_missing_events(self, origin, content, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
content = yield self.request_handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
min_depth=min_depth,
limit=limit,
)
defer.returnValue((200, content))

View File

@ -160,7 +160,7 @@ class DirectoryHandler(BaseHandler):
if not room_id: if not room_id:
raise SynapseError( raise SynapseError(
404, 404,
"Room alias %r not found" % (room_alias.to_string(),), "Room alias %s not found" % (room_alias.to_string(),),
Codes.NOT_FOUND Codes.NOT_FOUND
) )
@ -232,13 +232,23 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None): def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
# non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services() services = yield self.store.get_app_services()
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string()) s for s in services if s.is_interested_in_alias(alias.to_string())
] ]
for service in interested_services: for service in interested_services:
if user_id == service.sender: if user_id == service.sender:
# this user IS the app service # this user IS the app service so they can do whatever they like
defer.returnValue(True) defer.returnValue(True)
return return
defer.returnValue(len(interested_services) == 0) elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
defer.returnValue(False)
return
# either no interested services, or no service with an exclusive lock
defer.returnValue(True)

View File

@ -23,6 +23,7 @@ from synapse.events.utils import serialize_event
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -69,12 +70,17 @@ class EventStreamHandler(BaseHandler):
) )
self._streams_per_user[auth_user] += 1 self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user) room_ids = yield rm_handler.get_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
timeout = max(timeout, 500)
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout

View File

@ -581,9 +581,10 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()
if do_auth:
in_room = yield self.auth.check_host_in_room(room_id, origin) in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
@ -788,6 +789,29 @@ class FederationHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
in_room = yield self.auth.check_host_in_room(
room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
limit = min(limit, 20)
min_depth = max(min_depth, 0)
missing_events = yield self.store.get_missing_events(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
min_depth=min_depth,
)
defer.returnValue(missing_events)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def do_auth(self, origin, event, context, auth_events): def do_auth(self, origin, event, context, auth_events):

View File

@ -212,6 +212,7 @@ class ProfileHandler(BaseHandler):
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
try:
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_event({
"type": EventTypes.Member, "type": EventTypes.Member,
"room_id": j.room_id, "room_id": j.room_id,
@ -219,3 +220,8 @@ class ProfileHandler(BaseHandler):
"content": content, "content": content,
"sender": user.to_string() "sender": user.to_string()
}, ratelimit=False) }, ratelimit=False)
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
j.room_id, str(e.message)
)

View File

@ -201,7 +201,8 @@ class RegistrationHandler(BaseHandler):
interested_services = [ interested_services = [
s for s in services if s.is_interested_in_user(user_id) s for s in services if s.is_interested_in_user(user_id)
] ]
if len(interested_services) > 0: for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError( raise SynapseError(
400, "This user ID is reserved by an application service.", 400, "This user ID is reserved by an application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE

View File

@ -510,6 +510,13 @@ class RoomMemberHandler(BaseHandler):
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
membership states in.""" membership states in."""
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user_where_membership_is( rooms = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user.to_string(), membership_list=membership_list user_id=user.to_string(), membership_list=membership_list
) )
@ -559,6 +566,17 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
events, end_key = yield self.store.get_appservice_room_stream(
service=app_service,
from_key=from_key,
to_key=to_key,
limit=limit,
)
else:
events, end_key = yield self.store.get_room_events_stream( events, end_key = yield self.store.get_room_events_stream(
user_id=user.to_string(), user_id=user.to_string(),
from_key=from_key, from_key=from_key,

View File

@ -143,7 +143,7 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
json_str = json.dumps(json_body) json_str = encode_canonical_json(json_body)
response = yield self.agent.request( response = yield self.agent.request(
"PUT", "PUT",

View File

@ -124,7 +124,9 @@ class JsonResource(HttpServer, resource.Resource):
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path)
if m: if not m:
continue
# We found a match! Trigger callback and then return the # We found a match! Trigger callback and then return the
# returned response. We pass both the request and any # returned response. We pass both the request and any
# matched groups from the regex to the callback. # matched groups from the regex to the callback.

View File

@ -36,8 +36,10 @@ class _NotificationListener(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, from_token, limit, timeout, deferred): def __init__(self, user, rooms, from_token, limit, timeout, deferred,
appservice=None):
self.user = user self.user = user
self.appservice = appservice
self.from_token = from_token self.from_token = from_token
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -61,10 +63,14 @@ class _NotificationListener(object):
pass pass
for room in self.rooms: for room in self.rooms:
lst = notifier.rooms_to_listeners.get(room, set()) lst = notifier.room_to_listeners.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice:
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
class Notifier(object): class Notifier(object):
@ -77,8 +83,9 @@ class Notifier(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.rooms_to_listeners = {} self.room_to_listeners = {}
self.user_to_listeners = {} self.user_to_listeners = {}
self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
@ -109,11 +116,22 @@ class Notifier(object):
room_source = self.event_sources.sources["room"] room_source = self.event_sources.sources["room"]
listeners = self.rooms_to_listeners.get(room_id, set()).copy() listeners = self.room_to_listeners.get(room_id, set()).copy()
for user in extra_users: for user in extra_users:
listeners |= self.user_to_listeners.get(user, set()).copy() listeners |= self.user_to_listeners.get(user, set()).copy()
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_listeners set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event):
listeners |= self.appservice_to_listeners.get(
appservice, set()
).copy()
logger.debug("on_new_room_event listeners %s", listeners) logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the # TODO (erikj): Can we make this more efficient by hitting the
@ -166,7 +184,7 @@ class Notifier(object):
listeners |= self.user_to_listeners.get(user, set()).copy() listeners |= self.user_to_listeners.get(user, set()).copy()
for room in rooms: for room in rooms:
listeners |= self.rooms_to_listeners.get(room, set()).copy() listeners |= self.room_to_listeners.get(room, set()).copy()
@defer.inlineCallbacks @defer.inlineCallbacks
def notify(listener): def notify(listener):
@ -280,6 +298,10 @@ class Notifier(object):
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = yield self.event_sources.get_current_token()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
user.to_string()
)
listener = _NotificationListener( listener = _NotificationListener(
user, user,
rooms, rooms,
@ -287,6 +309,7 @@ class Notifier(object):
limit, limit,
timeout, timeout,
deferred, deferred,
appservice=appservice
) )
def _timeout_listener(): def _timeout_listener():
@ -314,11 +337,16 @@ class Notifier(object):
@log_function @log_function
def _register_with_keys(self, listener): def _register_with_keys(self, listener):
for room in listener.rooms: for room in listener.rooms:
s = self.rooms_to_listeners.setdefault(room, set()) s = self.room_to_listeners.setdefault(room, set())
s.add(listener) s.add(listener)
self.user_to_listeners.setdefault(listener.user, set()).add(listener) self.user_to_listeners.setdefault(listener.user, set()).add(listener)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _check_for_updates(self, listener): def _check_for_updates(self, listener):
@ -352,5 +380,5 @@ class Notifier(object):
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user, room_id):
new_listeners = self.user_to_listeners.get(user, set()) new_listeners = self.user_to_listeners.get(user, set())
listeners = self.rooms_to_listeners.setdefault(room_id, set()) listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners listeners |= new_listeners

View File

@ -32,7 +32,7 @@ class Pusher(object):
INITIAL_BACKOFF = 1000 INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['notify'] DEFAULT_ACTIONS = ['dont-notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@ -72,16 +72,14 @@ class Pusher(object):
# let's assume you probably know about messages you sent yourself # let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify']) defer.returnValue(['dont_notify'])
if ev['type'] == 'm.room.member': rawrules = yield self.store.get_push_rules_for_user(self.user_name)
if ev['state_key'] != self.user_name:
defer.returnValue(['dont_notify'])
rawrules = yield self.store.get_push_rules_for_user_name(self.user_name)
for r in rawrules: for r in rawrules:
r['conditions'] = json.loads(r['conditions']) r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions']) r['actions'] = json.loads(r['actions'])
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name) user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rawrules, user) rules = baserules.list_with_base_rules(rawrules, user)
@ -107,6 +105,8 @@ class Pusher(object):
room_member_count += 1 room_member_count += 1
for r in rules: for r in rules:
if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]:
continue
matches = True matches = True
conditions = r['conditions'] conditions = r['conditions']
@ -117,7 +117,11 @@ class Pusher(object):
ev, c, display_name=my_display_name, ev, c, display_name=my_display_name,
room_member_count=room_member_count room_member_count=room_member_count
) )
# ignore rules with no actions (we have an explict 'dont_notify' logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0: if len(actions) == 0:
logger.warn( logger.warn(
"Ignoring rule id %s with no actions for user %s" % "Ignoring rule id %s with no actions for user %s" %

View File

@ -32,12 +32,14 @@ def make_base_rules(user, kind):
if kind == 'override': if kind == 'override':
rules = make_base_override_rules() rules = make_base_override_rules()
elif kind == 'underride':
rules = make_base_underride_rules(user)
elif kind == 'content': elif kind == 'content':
rules = make_base_content_rules(user) rules = make_base_content_rules(user)
for r in rules: for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind] r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True r['default'] = True # Deprecated, left for backwards compat
return rules return rules
@ -45,6 +47,7 @@ def make_base_rules(user, kind):
def make_base_content_rules(user): def make_base_content_rules(user):
return [ return [
{ {
'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',
@ -57,6 +60,8 @@ def make_base_content_rules(user):
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default', 'value': 'default',
}, {
'set_tweak': 'highlight'
} }
] ]
}, },
@ -66,6 +71,40 @@ def make_base_content_rules(user):
def make_base_override_rules(): def make_base_override_rules():
return [ return [
{ {
'rule_id': 'global/override/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.call.invite',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'ring'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont_notify',
]
},
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
{ {
'kind': 'contains_display_name' 'kind': 'contains_display_name'
@ -76,10 +115,13 @@ def make_base_override_rules():
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default' 'value': 'default'
}, {
'set_tweak': 'highlight'
} }
] ]
}, },
{ {
'rule_id': 'global/override/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
{ {
'kind': 'room_member_count', 'kind': 'room_member_count',
@ -91,6 +133,76 @@ def make_base_override_rules():
{ {
'set_tweak': 'sound', 'set_tweak': 'sound',
'value': 'default' 'value': 'default'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
}
]
def make_base_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern': user.to_string(),
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/underride/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': 'false'
} }
] ]
} }

View File

@ -88,6 +88,7 @@ class HttpPusher(Pusher):
} }
if event['type'] == 'm.room.member': if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership'] d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_name
if 'content' in event: if 'content' in event:
d['notification']['content'] = event['content'] d['notification']['content'] = event['content']
@ -108,7 +109,7 @@ class HttpPusher(Pusher):
try: try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
except: except:
logger.exception("Failed to push %s ", self.url) logger.warn("Failed to push %s ", self.url)
defer.returnValue(False) defer.returnValue(False)
rejected = [] rejected = []
if 'rejected' in resp: if 'rejected' in resp:

View File

@ -48,18 +48,12 @@ class RegisterRestServlet(AppServiceRestServlet):
400, "Missed required keys: as_token(str) / url(str)." 400, "Missed required keys: as_token(str) / url(str)."
) )
namespaces = { try:
"users": [], app_service = ApplicationService(
"rooms": [], as_token, as_url, params["namespaces"]
"aliases": [] )
} except ValueError as e:
raise SynapseError(400, e.message)
if "namespaces" in params:
self._parse_namespace(namespaces, params["namespaces"], "users")
self._parse_namespace(namespaces, params["namespaces"], "rooms")
self._parse_namespace(namespaces, params["namespaces"], "aliases")
app_service = ApplicationService(as_token, as_url, namespaces)
app_service = yield self.handler.register(app_service) app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token hs_token = app_service.hs_token
@ -68,23 +62,6 @@ class RegisterRestServlet(AppServiceRestServlet):
"hs_token": hs_token "hs_token": hs_token
})) }))
def _parse_namespace(self, target_ns, origin_ns, ns):
if ns not in target_ns or ns not in origin_ns:
return # nothing to parse / map through to.
possible_regex_list = origin_ns[ns]
if not type(possible_regex_list) == list:
raise SynapseError(400, "Namespace %s isn't an array." % ns)
for regex in possible_regex_list:
if not isinstance(regex, basestring):
raise SynapseError(
400, "Regex '%s' isn't a string in namespace %s" %
(regex, ns)
)
target_ns[ns] = origin_ns[ns]
class UnregisterRestServlet(AppServiceRestServlet): class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server. """Handles AS registration with the home server.

View File

@ -50,6 +50,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
if 'attr' in spec:
self.set_rule_attr(user.to_string(), spec, content)
defer.returnValue((200, {}))
try: try:
(conditions, actions) = _rule_tuple_from_request_object( (conditions, actions) = _rule_tuple_from_request_object(
spec['template'], spec['template'],
@ -110,7 +114,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name( rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
user.to_string() user.to_string()
) )
@ -124,6 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules['global'] = _add_empty_priority_class_arrays(rules['global']) rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist: for r in ruleslist:
rulearray = None rulearray = None
@ -149,6 +156,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
template_rule['enabled'] = True
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
rulearray.append(template_rule) rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -189,6 +199,25 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def set_rule_attr(self, user_name, spec, val):
if spec['attr'] == 'enabled':
if not isinstance(val, bool):
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled(
user_name, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
user_name, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path): def _rule_spec_from_path(path):
if len(path) < 2: if len(path) < 2:
@ -214,7 +243,7 @@ def _rule_spec_from_path(path):
template = path[0] template = path[0]
path = path[1:] path = path[1:]
if len(path) == 0: if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
rule_id = path[0] rule_id = path[0]
@ -226,6 +255,12 @@ def _rule_spec_from_path(path):
} }
if device: if device:
spec['profile_tag'] = device spec['profile_tag'] = device
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0]
return spec return spec
@ -275,7 +310,7 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
for a in actions: for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']: if a in ['notify', 'dont_notify', 'coalesce']:
pass pass
elif isinstance(a, dict) and 'set_sound' in a: elif isinstance(a, dict) and 'set_tweak' in a:
pass pass
else: else:
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
@ -319,11 +354,24 @@ def _filter_ruleset_with_path(ruleset, path):
if path[0] == '': if path[0] == '':
return ruleset[template_kind] return ruleset[template_kind]
rule_id = path[0] rule_id = path[0]
the_rule = None
for r in ruleset[template_kind]: for r in ruleset[template_kind]:
if r['rule_id'] == rule_id: if r['rule_id'] == rule_id:
return r the_rule = r
if the_rule is None:
raise NotFoundError raise NotFoundError
path = path[1:]
if len(path) == 0:
return the_rule
attr = path[0]
if attr in the_rule:
return the_rule[attr]
else:
raise UnrecognizedRequestError()
def _priority_class_from_spec(spec): def _priority_class_from_spec(spec):
if spec['template'] not in PRIORITY_CLASS_MAP.keys(): if spec['template'] not in PRIORITY_CLASS_MAP.keys():
@ -339,7 +387,7 @@ def _priority_class_from_spec(spec):
def _priority_class_to_template_name(pc): def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']: if pc > PRIORITY_CLASS_MAP['override']:
# per-device # per-device
prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index] return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else: else:
return PRIORITY_CLASS_INVERSE_MAP[pc] return PRIORITY_CLASS_INVERSE_MAP[pc]
@ -399,9 +447,6 @@ class InvalidRuleException(Exception):
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content return content
except ValueError: except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@ -59,6 +59,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# } # }
# TODO: persistent storage # TODO: persistent storage
self.sessions = {} self.sessions = {}
self.disable_registration = hs.config.disable_registration
def on_GET(self, request): def on_GET(self, request):
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
@ -107,6 +108,11 @@ class RegisterRestServlet(ClientV1RestServlet):
try: try:
login_type = register_json["type"] login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
if self.disable_registration and not is_application_server:
raise SynapseError(403, "Registration has been disabled")
stages = { stages = {
LoginType.RECAPTCHA: self._do_recaptcha, LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password, LoginType.PASSWORD: self._do_password,

View File

@ -73,6 +73,7 @@ class BaseHomeServer(object):
'resource_for_client', 'resource_for_client',
'resource_for_client_v2_alpha', 'resource_for_client_v2_alpha',
'resource_for_federation', 'resource_for_federation',
'resource_for_static_content',
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
'resource_for_server_key', 'resource_for_server_key',

View File

@ -45,36 +45,19 @@ from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
import fnmatch
import imp
import logging import logging
import os import os
import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SCHEMAS = [ # Remember to update this number every time a change is made to database
"transactions", # schema files, so the users will be informed on server restarts.
"users", SCHEMA_VERSION = 14
"profiles",
"presence",
"im",
"room_aliases",
"keys",
"redactions",
"state",
"event_edges",
"event_signatures",
"pusher",
"media_repository",
"application_services",
"filtering",
"rejections",
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 13
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -576,28 +559,15 @@ class DataStore(RoomMemberStore, RoomStore,
) )
def schema_path(schema): def read_schema(path):
""" Get a filesystem path for the named database schema
Args:
schema: Name of the database schema.
Returns:
A filesystem path pointing at a ".sql" file.
"""
schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
return schemaPath
def read_schema(schema):
""" Read the named database schema. """ Read the named database schema.
Args: Args:
schema: Name of the datbase schema. path: Path of the database schema.
Returns: Returns:
A string containing the database schema. A string containing the database schema.
""" """
with open(schema_path(schema)) as schema_file: with open(path) as schema_file:
return schema_file.read() return schema_file.read()
@ -610,49 +580,275 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn): def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we """Prepares a database for usage. Will either create all necessary tables
don't have to worry about overwriting existing content. or upgrade from an older schema version.
""" """
c = db_conn.cursor() try:
c.execute("PRAGMA user_version") cur = db_conn.cursor()
row = c.fetchone() version_info = _get_or_create_schema_state(cur)
if row and row[0]: if version_info:
user_version = row[0] user_version, delta_files, upgraded = version_info
_upgrade_existing_database(cur, user_version, delta_files, upgraded)
else:
_setup_new_database(cur)
if user_version > SCHEMA_VERSION: cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
except:
db_conn.rollback()
raise
def _setup_new_database(cur):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
`SCHEMA_VERSION` and executes all .sql files in that directory.
The function will then apply all deltas for all versions after the base
version.
Example directory structure:
schema/
delta/
...
full_schemas/
3/
test.sql
...
11/
foo.sql
bar.sql
...
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
"""
current_dir = os.path.join(dir_path, "schema", "full_schemas")
directory_entries = os.listdir(current_dir)
valid_dirs = []
pattern = re.compile(r"^\d+(\.sql)?$")
for filename in directory_entries:
match = pattern.match(filename)
abs_path = os.path.join(current_dir, filename)
if match and os.path.isdir(abs_path):
ver = int(match.group(0))
if ver <= SCHEMA_VERSION:
valid_dirs.append((ver, abs_path))
else:
logger.warn("Unexpected entry in 'full_schemas': %s", filename)
if not valid_dirs:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
logger.debug("Initialising schema v%d", max_current_ver)
directory_entries = os.listdir(sql_dir)
sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
sql_script += read_schema(sql_loc)
sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
cur.executescript(sql_script)
cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(max_current_ver, False)
)
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
upgraded=False
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
in *.py.
There can be multiple delta files per version. Synapse will keep track of
which delta files have been applied, and will apply any that haven't been
even if there has been no version bump. This is useful for development
where orthogonal schema changes may happen on separate branches.
Different delta files for the same version *must* be orthogonal and give
the same result when applied in any order. No guarantees are made on the
order of execution of these scripts.
This is a no-op of current_version == SCHEMA_VERSION.
Example directory structure:
schema/
delta/
11/
foo.sql
...
12/
foo.sql
bar.py
...
full_schemas/
...
In the example, if current_version is 11, then foo.sql will be run if and
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
Args:
cur (Cursor)
current_version (int): The current version of the schema.
applied_delta_files (list): A list of deltas that have already been
applied.
upgraded (bool): Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
"""
if current_version > SCHEMA_VERSION:
raise ValueError( raise ValueError(
"Cannot use this database as it is too " + "Cannot use this database as it is too " +
"new for the server to understand" "new for the server to understand"
) )
elif user_version < SCHEMA_VERSION:
logger.info(
"Upgrading database from version %d",
user_version
)
# Run every version since after the current version. start_ver = current_version
for v in range(user_version + 1, SCHEMA_VERSION + 1): if not upgraded:
if v == 10: start_ver += 1
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
try:
directory_entries = os.listdir(delta_dir)
except OSError:
logger.exception("Could not open delta dir for version %d", v)
raise UpgradeDatabaseException( raise UpgradeDatabaseException(
"No delta for version 10" "Could not open delta dir for version %d" % (v,)
) )
sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script)
db_conn.commit() directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
if relative_path in applied_delta_files:
continue
absolute_path = os.path.join(
dir_path, "schema", "delta", relative_path,
)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
module_name = "synapse.storage.v%d_%s" % (
v, root_name
)
with open(absolute_path) as python_file:
module = imp.load_source(
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path)
cur.executescript(delta_schema)
else: else:
logger.info("Database is at version %r", user_version) # Not a valid delta file.
logger.warn(
"Found directory entry that did not end in .py or"
" .sql: %s",
relative_path,
)
continue
else: # Mark as done.
sql_script = "BEGIN TRANSACTION;\n" cur.execute(
for sql_loc in SCHEMAS: "INSERT INTO applied_schema_deltas (version, file)"
logger.debug("Applying schema %r", sql_loc) " VALUES (?,?)",
sql_script += read_schema(sql_loc) (v, relative_path)
sql_script += "\n" )
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
cur.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(v, True)
)
def _get_or_create_schema_state(txn):
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
txn.executescript(create_schema)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
(current_version,)
)
return current_version, txn.fetchall(), upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close() c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View File

@ -23,7 +23,7 @@ from synapse.util.lrucache import LruCache
from twisted.internet import defer from twisted.internet import defer
import collections from collections import namedtuple, OrderedDict
import simplejson as json import simplejson as json
import sys import sys
import time import time
@ -35,6 +35,52 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn") transaction_logger = logging.getLogger("synapse.storage.txn")
# TODO(paul):
# * more generic key management
# * export monitoring stats
# * consider other eviction strategies - LRU?
def cached(max_entries=1000):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take one additional argument, which is used as
the key for the cache. Cache hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def wrap(orig):
cache = OrderedDict()
def prefill(key, value):
while len(cache) > max_entries:
cache.popitem(last=False)
cache[key] = value
@defer.inlineCallbacks
def wrapped(self, key):
if key in cache:
defer.returnValue(cache[key])
ret = yield orig(self, key)
prefill(key, ret)
defer.returnValue(ret)
def invalidate(key):
cache.pop(key, None)
wrapped.invalidate = invalidate
wrapped.prefill = prefill
return wrapped
return wrap
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method.""" passed to the constructor. Adds logging to the .execute() method."""
@ -404,7 +450,8 @@ class SQLBaseStore(object):
Args: Args:
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with,
or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
return self.runInteraction( return self.runInteraction(
@ -423,13 +470,20 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % ( sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
else:
sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
", ".join(retcols),
table
)
txn.execute(sql)
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
@ -586,8 +640,9 @@ class SQLBaseStore(object):
start_time = time.time() * 1000 start_time = time.time() * 1000
update_counter = self._get_event_counters.update update_counter = self._get_event_counters.update
try:
cache = self._get_event_cache.setdefault(event_id, {}) cache = self._get_event_cache.setdefault(event_id, {})
try:
# Separate cache entries for each way to invoke _get_event_txn # Separate cache entries for each way to invoke _get_event_txn
return cache[(check_redacted, get_prev_content, allow_rejected)] return cache[(check_redacted, get_prev_content, allow_rejected)]
except KeyError: except KeyError:
@ -786,7 +841,7 @@ class JoinHelper(object):
for table in self.tables: for table in self.tables:
res += [f for f in table.fields if f not in res] res += [f for f in table.fields if f not in res]
self.EntryType = collections.namedtuple("JoinHelperEntry", res) self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes): def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT """Get a string representing a list of fields for use in SELECT

View File

@ -13,34 +13,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import simplejson
from simplejson import JSONDecodeError
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.storage.roommember import RoomsForUser
from ._base import SQLBaseStore from ._base import SQLBaseStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ApplicationServiceCache(object): def log_failure(failure):
"""Caches ApplicationServices and provides utility functions on top. logger.error("Failed to detect application services: %s", failure.value)
logger.error(failure.getTraceback())
This class is designed to be invoked on incoming events in order to avoid
hammering the database every time to extract a list of application service
regexes.
"""
def __init__(self):
self.services = []
class ApplicationServiceStore(SQLBaseStore): class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs) super(ApplicationServiceStore, self).__init__(hs)
self.cache = ApplicationServiceCache() self.services_cache = []
self.cache_defer = self._populate_cache() self.cache_defer = self._populate_cache()
self.cache_defer.addErrback(log_failure)
@defer.inlineCallbacks @defer.inlineCallbacks
def unregister_app_service(self, token): def unregister_app_service(self, token):
@ -56,7 +54,7 @@ class ApplicationServiceStore(SQLBaseStore):
token, token,
) )
# update cache TODO: Should this be in the txn? # update cache TODO: Should this be in the txn?
for service in self.cache.services: for service in self.services_cache:
if service.token == token: if service.token == token:
service.url = None service.url = None
service.namespaces = None service.namespaces = None
@ -110,13 +108,13 @@ class ApplicationServiceStore(SQLBaseStore):
) )
# update cache TODO: Should this be in the txn? # update cache TODO: Should this be in the txn?
for (index, cache_service) in enumerate(self.cache.services): for (index, cache_service) in enumerate(self.services_cache):
if service.token == cache_service.token: if service.token == cache_service.token:
self.cache.services[index] = service self.services_cache[index] = service
logger.info("Updated: %s", service) logger.info("Updated: %s", service)
return return
# new entry # new entry
self.cache.services.append(service) self.services_cache.append(service)
logger.info("Updated(new): %s", service) logger.info("Updated(new): %s", service)
def _update_app_service_txn(self, txn, service): def _update_app_service_txn(self, txn, service):
@ -140,11 +138,11 @@ class ApplicationServiceStore(SQLBaseStore):
) )
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST): for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
if ns_str in service.namespaces: if ns_str in service.namespaces:
for regex in service.namespaces[ns_str]: for regex_obj in service.namespaces[ns_str]:
txn.execute( txn.execute(
"INSERT INTO application_services_regex(" "INSERT INTO application_services_regex("
"as_id, namespace, regex) values(?,?,?)", "as_id, namespace, regex) values(?,?,?)",
(as_id, ns_int, regex) (as_id, ns_int, simplejson.dumps(regex_obj))
) )
return True return True
@ -160,11 +158,34 @@ class ApplicationServiceStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_app_services(self): def get_app_services(self):
yield self.cache_defer # make sure the cache is ready yield self.cache_defer # make sure the cache is ready
defer.returnValue(self.cache.services) defer.returnValue(self.services_cache)
@defer.inlineCallbacks
def get_app_service_by_user_id(self, user_id):
"""Retrieve an application service from their user ID.
All application services have associated with them a particular user ID.
There is no distinguishing feature on the user ID which indicates it
represents an application service. This function allows you to map from
a user ID to an application service.
Args:
user_id(str): The user ID to see if it is an application service.
Returns:
synapse.appservice.ApplicationService or None.
"""
yield self.cache_defer # make sure the cache is ready
for service in self.services_cache:
if service.sender == user_id:
defer.returnValue(service)
return
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_app_service_by_token(self, token, from_cache=True): def get_app_service_by_token(self, token, from_cache=True):
"""Get the application service with the given token. """Get the application service with the given appservice token.
Args: Args:
token (str): The application service token. token (str): The application service token.
@ -176,7 +197,7 @@ class ApplicationServiceStore(SQLBaseStore):
yield self.cache_defer # make sure the cache is ready yield self.cache_defer # make sure the cache is ready
if from_cache: if from_cache:
for service in self.cache.services: for service in self.services_cache:
if service.token == token: if service.token == token:
defer.returnValue(service) defer.returnValue(service)
return return
@ -185,6 +206,77 @@ class ApplicationServiceStore(SQLBaseStore):
# TODO: The from_cache=False impl # TODO: The from_cache=False impl
# TODO: This should be JOINed with the application_services_regex table. # TODO: This should be JOINed with the application_services_regex table.
def get_app_service_rooms(self, service):
"""Get a list of RoomsForUser for this application service.
Application services may be "interested" in lots of rooms depending on
the room ID, the room aliases, or the members in the room. This function
takes all of these into account and returns a list of RoomsForUser which
represent the entire list of room IDs that this application service
wants to know about.
Args:
service: The application service to get a room list for.
Returns:
A list of RoomsForUser.
"""
return self.runInteraction(
"get_app_service_rooms",
self._get_app_service_rooms_txn,
service,
)
def _get_app_service_rooms_txn(self, txn, service):
# get all rooms matching the room ID regex.
room_entries = self._simple_select_list_txn(
txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
)
matching_room_list = set([
r["room_id"] for r in room_entries if
service.is_interested_in_room(r["room_id"])
])
# resolve room IDs for matching room alias regex.
room_alias_mappings = self._simple_select_list_txn(
txn=txn, table="room_aliases", keyvalues=None,
retcols=["room_id", "room_alias"]
)
matching_room_list |= set([
r["room_id"] for r in room_alias_mappings if
service.is_interested_in_alias(r["room_alias"])
])
# get all rooms for every user for this AS. This is scoped to users on
# this HS only.
user_list = self._simple_select_list_txn(
txn=txn, table="users", keyvalues=None, retcols=["name"]
)
user_list = [
u["name"] for u in user_list if
service.is_interested_in_user(u["name"])
]
rooms_for_user_matching_user_id = set() # RoomsForUser list
for user_id in user_list:
# FIXME: This assumes this store is linked with RoomMemberStore :(
rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
txn=txn,
user_id=user_id,
membership_list=[Membership.JOIN]
)
rooms_for_user_matching_user_id |= set(rooms_for_user)
# make RoomsForUser tuples for room ids and aliases which are not in the
# main rooms_for_user_list - e.g. they are rooms which do not have AS
# registered users in it.
known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
missing_rooms_for_user = [
RoomsForUser(r, service.sender, "join") for r in
matching_room_list if r not in known_room_ids
]
rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
return rooms_for_user_matching_user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def _populate_cache(self): def _populate_cache(self):
"""Populates the ApplicationServiceCache from the database.""" """Populates the ApplicationServiceCache from the database."""
@ -227,15 +319,17 @@ class ApplicationServiceStore(SQLBaseStore):
try: try:
services[as_token]["namespaces"][ services[as_token]["namespaces"][
ApplicationService.NS_LIST[ns_int]].append( ApplicationService.NS_LIST[ns_int]].append(
res["regex"] simplejson.loads(res["regex"])
) )
except IndexError: except IndexError:
logger.error("Bad namespace enum '%s'. %s", ns_int, res) logger.error("Bad namespace enum '%s'. %s", ns_int, res)
except JSONDecodeError:
logger.error("Bad regex object '%s'", res["regex"])
# TODO get last successful txn id f.e. service # TODO get last successful txn id f.e. service
for service in services.values(): for service in services.values():
logger.info("Found application service: %s", service) logger.info("Found application service: %s", service)
self.cache.services.append(ApplicationService( self.services_cache.append(ApplicationService(
token=service["token"], token=service["token"],
url=service["url"], url=service["url"],
namespaces=service["namespaces"], namespaces=service["namespaces"],

View File

@ -64,6 +64,9 @@ class EventFederationStore(SQLBaseStore):
for f in front: for f in front:
txn.execute(base_sql, (f,)) txn.execute(base_sql, (f,))
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn.fetchall()])
new_front -= results
front = new_front front = new_front
results.update(front) results.update(front)
@ -378,3 +381,51 @@ class EventFederationStore(SQLBaseStore):
event_results += new_front event_results += new_front
return self._get_events_txn(txn, event_results) return self._get_events_txn(txn, event_results)
def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth):
return self.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth
)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth):
earliest_events = set(earliest_events)
front = set(latest_events) - earliest_events
event_results = set()
query = (
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
"LIMIT ?"
)
while front and len(event_results) < limit:
new_front = set()
for event_id in front:
txn.execute(
query,
(room_id, event_id, limit - len(event_results))
)
for e_id, in txn.fetchall():
new_front.add(e_id)
new_front -= earliest_events
new_front -= event_results
front = new_front
event_results |= new_front
events = self._get_events_txn(txn, event_results)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]

View File

@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name): def get_push_rules_for_user(self, user_name):
sql = ( sql = (
"SELECT "+",".join(PushRuleTable.fields)+" " "SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" " "FROM "+PushRuleTable.table_name+" "
@ -45,6 +45,28 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(dicts) defer.returnValue(dicts)
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name},
PushRuleEnableTable.fields
)
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
)
@defer.inlineCallbacks
def get_push_rule_enabled_by_user_rule_id(self, user_name, rule_id):
results = yield self._simple_select_list(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
['enabled']
)
if not results:
defer.returnValue(True)
defer.returnValue(results[0])
@defer.inlineCallbacks @defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs): def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs) vals = copy.copy(kwargs)
@ -193,6 +215,20 @@ class PushRuleStore(SQLBaseStore):
{'user_name': user_name, 'rule_id': rule_id} {'user_name': user_name, 'rule_id': rule_id}
) )
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
if enabled:
yield self._simple_delete_one(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id}
)
else:
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
{'enabled': False}
)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):
pass pass
@ -216,3 +252,13 @@ class PushRuleTable(Table):
] ]
EntryType = collections.namedtuple("PushRuleEntry", fields) EntryType = collections.namedtuple("PushRuleEntry", fields)
class PushRuleEnableTable(Table):
table_name = "push_rules_enable"
fields = [
"user_name",
"rule_id",
"enabled"
]

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore from ._base import SQLBaseStore, cached
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import UserID from synapse.types import UserID
@ -35,11 +35,6 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def __init__(self, *args, **kw):
super(RoomMemberStore, self).__init__(*args, **kw)
self._user_rooms_cache = {}
def _store_room_member_txn(self, txn, event): def _store_room_member_txn(self, txn, event):
"""Store a room member in the database. """Store a room member in the database.
""" """
@ -103,7 +98,7 @@ class RoomMemberStore(SQLBaseStore):
txn.execute(sql, (event.room_id, domain)) txn.execute(sql, (event.room_id, domain))
self.invalidate_rooms_for_user(target_user_id) self.get_rooms_for_user.invalidate(target_user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
@ -185,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
if not membership_list: if not membership_list:
return defer.succeed(None) return defer.succeed(None)
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id, membership_list
)
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
where_clause = "user_id = ? AND (%s)" % ( where_clause = "user_id = ? AND (%s)" % (
" OR ".join(["membership = ?" for _ in membership_list]), " OR ".join(["membership = ?" for _ in membership_list]),
) )
@ -192,7 +195,6 @@ class RoomMemberStore(SQLBaseStore):
args = [user_id] args = [user_id]
args.extend(membership_list) args.extend(membership_list)
def f(txn):
sql = ( sql = (
"SELECT m.room_id, m.sender, m.membership" "SELECT m.room_id, m.sender, m.membership"
" FROM room_memberships as m" " FROM room_memberships as m"
@ -206,11 +208,6 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
f
)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self._simple_select_onecol( return self._simple_select_onecol(
"room_hosts", "room_hosts",
@ -247,33 +244,12 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows) results = self._parse_events_txn(txn, rows)
return results return results
# TODO(paul): Create a nice @cached decorator to do this @cached()
# @cached
# def get_foo(...)
# ...
# invalidate_foo = get_foo.invalidator
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
# TODO(paul): put some performance counters in here so we can easily return self.get_rooms_for_user_where_membership_is(
# track what impact this cache is having
if user_id in self._user_rooms_cache:
defer.returnValue(self._user_rooms_cache[user_id])
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],
) )
# TODO(paul): Consider applying a maximum size; just evict things at
# random, or consider LRU?
self._user_rooms_cache[user_id] = rooms
defer.returnValue(rooms)
def invalidate_rooms_for_user(self, user_id):
if user_id in self._user_rooms_cache:
del self._user_rooms_cache[user_id]
@defer.inlineCallbacks @defer.inlineCallbacks
def user_rooms_intersect(self, user_id_list): def user_rooms_intersect(self, user_id_list):
""" Checks whether all the users whose IDs are given in a list share a """ Checks whether all the users whose IDs are given in a list share a

View File

@ -1,34 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS application_services(
id INTEGER PRIMARY KEY AUTOINCREMENT,
url TEXT,
token TEXT,
hs_token TEXT,
sender TEXT,
UNIQUE(token) ON CONFLICT ROLLBACK
);
CREATE TABLE IF NOT EXISTS application_services_regex(
id INTEGER PRIMARY KEY AUTOINCREMENT,
as_id INTEGER NOT NULL,
namespace INTEGER, /* enum[room_id|room_alias|user_id] */
regex TEXT,
FOREIGN KEY(as_id) REFERENCES application_services(id)
);

View File

@ -0,0 +1,23 @@
import json
import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
logger.debug("Checking %s..." % row[0])
json.loads(row[1])
except ValueError:
# row isn't in json, make it so.
string_regex = row[1]
new_regex = json.dumps({
"regex": string_regex,
"exclusive": True
})
cur.execute(
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)

View File

@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS push_rules_enable (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
enabled TINYINT,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_enable_user_name on push_rules_enable (user_name);

View File

@ -1,168 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS events(
stream_ordering INTEGER PRIMARY KEY AUTOINCREMENT,
topological_ordering INTEGER NOT NULL,
event_id TEXT NOT NULL,
type TEXT NOT NULL,
room_id TEXT NOT NULL,
content TEXT NOT NULL,
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS events_event_id ON events (event_id);
CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
CREATE TABLE IF NOT EXISTS state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
prev_state TEXT
);
CREATE UNIQUE INDEX IF NOT EXISTS state_events_event_id ON state_events (event_id);
CREATE INDEX IF NOT EXISTS state_events_room_id ON state_events (room_id);
CREATE INDEX IF NOT EXISTS state_events_type ON state_events (type);
CREATE INDEX IF NOT EXISTS state_events_state_key ON state_events (state_key);
CREATE TABLE IF NOT EXISTS current_state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
CONSTRAINT curr_uniq UNIQUE (room_id, type, state_key) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS curr_events_event_id ON current_state_events (event_id);
CREATE INDEX IF NOT EXISTS current_state_events_room_id ON current_state_events (room_id);
CREATE INDEX IF NOT EXISTS current_state_events_type ON current_state_events (type);
CREATE INDEX IF NOT EXISTS current_state_events_state_key ON current_state_events (state_key);
CREATE TABLE IF NOT EXISTS room_memberships(
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
sender TEXT NOT NULL,
room_id TEXT NOT NULL,
membership TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS room_memberships_event_id ON room_memberships (event_id);
CREATE INDEX IF NOT EXISTS room_memberships_room_id ON room_memberships (room_id);
CREATE INDEX IF NOT EXISTS room_memberships_user_id ON room_memberships (user_id);
CREATE TABLE IF NOT EXISTS feedback(
event_id TEXT NOT NULL,
feedback_type TEXT,
target_event_id TEXT,
sender TEXT,
room_id TEXT
);
CREATE TABLE IF NOT EXISTS topics(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
topic TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
name TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public INTEGER,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_join_rules(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
join_rule TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id);
CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id);
CREATE TABLE IF NOT EXISTS room_power_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id);
CREATE TABLE IF NOT EXISTS room_default_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id);
CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id);
CREATE TABLE IF NOT EXISTS room_add_state_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id);
CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id);
CREATE TABLE IF NOT EXISTS room_send_event_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id);
CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id);
CREATE TABLE IF NOT EXISTS room_ops_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
ban_level INTEGER,
kick_level INTEGER
);
CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id);
CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id);
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,
CONSTRAINT room_hosts_uniq UNIQUE (room_id, host) ON CONFLICT IGNORE
);
CREATE INDEX IF NOT EXISTS room_hosts_room_id ON room_hosts (room_id);
PRAGMA user_version = 2;

View File

@ -1,27 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX IF NOT EXISTS room_aliases_alias ON room_aliases(room_alias);
CREATE INDEX IF NOT EXISTS room_aliases_id ON room_aliases(room_id);
CREATE INDEX IF NOT EXISTS room_alias_servers_alias ON room_alias_servers(room_alias);
DELETE FROM room_aliases WHERE rowid NOT IN (SELECT max(rowid) FROM room_aliases GROUP BY room_alias, room_id);
CREATE UNIQUE INDEX IF NOT EXISTS room_aliases_uniq ON room_aliases(room_alias, room_id);
PRAGMA user_version = 3;

View File

@ -1,26 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS redactions (
event_id TEXT NOT NULL,
redacts TEXT NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS redactions_event_id ON redactions (event_id);
CREATE INDEX IF NOT EXISTS redactions_redacts ON redactions (redacts);
ALTER TABLE room_ops_levels ADD COLUMN redact_level INTEGER;
PRAGMA user_version = 4;

View File

@ -1,30 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_ips (
user TEXT NOT NULL,
access_token TEXT NOT NULL,
device_id TEXT,
ip TEXT NOT NULL,
user_agent TEXT NOT NULL,
last_seen INTEGER NOT NULL,
CONSTRAINT user_ip UNIQUE (user, access_token, ip, user_agent) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS user_ips_user ON user_ips(user);
ALTER TABLE users ADD COLUMN admin BOOL DEFAULT 0 NOT NULL;
PRAGMA user_version = 5;

View File

@ -1,31 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS server_tls_certificates(
server_name TEXT, -- Server name.
fingerprint TEXT, -- Certificate fingerprint.
from_server TEXT, -- Which key server the certificate was fetched from.
ts_added_ms INTEGER, -- When the certifcate was added.
tls_certificate BLOB, -- DER encoded x509 certificate.
CONSTRAINT uniqueness UNIQUE (server_name, fingerprint)
);
CREATE TABLE IF NOT EXISTS server_signature_keys(
server_name TEXT, -- Server name.
key_id TEXT, -- Key version.
from_server TEXT, -- Which key server the key was fetched form.
ts_added_ms INTEGER, -- When the key was added.
verify_key BLOB, -- NACL verification key.
CONSTRAINT uniqueness UNIQUE (server_name, key_id)
);

View File

@ -1,34 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_signatures_2 (
event_id TEXT,
signature_name TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, signature_name, key_id)
);
INSERT INTO event_signatures_2 (event_id, signature_name, key_id, signature)
SELECT event_id, signature_name, key_id, signature FROM event_signatures;
DROP TABLE event_signatures;
ALTER TABLE event_signatures_2 RENAME TO event_signatures;
CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
event_id
);
PRAGMA user_version = 8;

View File

@ -1,79 +0,0 @@
/* Copyright 2014, 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
retry_last_ts INTEGER,
retry_interval INTEGER
);
CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
CONSTRAINT uniqueness UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
CONSTRAINT uniqueness UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_method TEXT, -- The method used to make the thumbnail
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
PRAGMA user_version = 9;

View File

@ -1,24 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
filter_json TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View File

@ -1,46 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
profile_tag varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class TINYINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);

View File

@ -1,21 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
);

View File

@ -0,0 +1,30 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS schema_version(
Lock char(1) NOT NULL DEFAULT 'X', -- Makes sure this table only has one row.
version INTEGER NOT NULL,
upgraded BOOL NOT NULL, -- Whether we reached this version from an upgrade or an initial schema.
CONSTRAINT schema_version_lock_x CHECK (Lock='X')
CONSTRAINT schema_version_lock_uniq UNIQUE (Lock)
);
CREATE TABLE IF NOT EXISTS applied_schema_deltas(
version INTEGER NOT NULL,
file TEXT NOT NULL,
CONSTRAINT schema_deltas_ver_file UNIQUE (version, file) ON CONFLICT IGNORE
);
CREATE INDEX IF NOT EXISTS schema_deltas_ver ON applied_schema_deltas(version);

View File

@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
class StreamStore(SQLBaseStore): class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
# NB this lives here instead of appservice.py so we can reuse the
# 'private' StreamToken class in this file.
if limit:
limit = max(limit, MAX_STREAM_SIZE)
else:
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)
if from_key == to_key:
defer.returnValue(([], to_key))
return
# select all the events between from/to with a sensible limit
sql = (
"SELECT e.event_id, e.room_id, e.type, s.state_key, "
"e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
"e.event_id = s.event_id "
"WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"limit": limit
}
def f(txn):
# pull out all the events between the tokens
txn.execute(sql, (from_id.stream, to_id.stream,))
rows = self.cursor_to_dict(txn)
# Logic:
# - We want ALL events which match the AS room_id regex
# - We want ALL events which match the rooms represented by the AS
# room_alias regex
# - We want ALL events for rooms that AS users have joined.
# This is currently supported via get_app_service_rooms (which is
# used for the Notifier listener rooms). We can't reasonably make a
# SQL query for these room IDs, so we'll pull all the events between
# from/to and filter in python.
rooms_for_as = self._get_app_service_rooms_txn(txn, service)
room_ids_for_as = [r.room_id for r in rooms_for_as]
def app_service_interested(row):
if row["room_id"] in room_ids_for_as:
return True
if row["type"] == EventTypes.Member:
if service.is_interested_in_user(row.get("state_key")):
return True
return False
ret = self._get_events_txn(
txn,
# apply the filter on the room id list
[
r["event_id"] for r in rows
if app_service_interested(r)
],
get_prev_content=True
)
self._set_before_and_after(ret, rows)
if rows:
key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = to_key
return ret, key
results = yield self.runInteraction("get_appservice_room_stream", f)
defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False): limit=0, with_feedback=False):
@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
self._set_before_and_after(ret, rows) self._set_before_and_after(ret, rows)
if rows: if rows:
key = "s%d" % max([r["stream_ordering"] for r in rows]) key = "s%d" % max(r["stream_ordering"] for r in rows)
else: else:
# Assume we didn't get anything because there was nothing to # Assume we didn't get anything because there was nothing to
# get. # get.

View File

@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore, Table, cached
from collections import namedtuple from collections import namedtuple
from twisted.internet import defer
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,10 +26,6 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs. """A collection of queries for handling PDUs.
""" """
# a write-through cache of DestinationsTable.EntryType indexed by
# destination string
destination_retry_cache = {}
def get_received_txn_response(self, transaction_id, origin): def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have """For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response already responded to it. If so, return the response code and response
@ -211,6 +205,7 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall()) return ReceivedTransactionsTable.decode_results(txn.fetchall())
@cached()
def get_destination_retry_timings(self, destination): def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination. """Gets the current retry timings (if any) for a given destination.
@ -221,9 +216,6 @@ class TransactionStore(SQLBaseStore):
None if not retrying None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme Otherwise a DestinationsTable.EntryType for the retry scheme
""" """
if destination in self.destination_retry_cache:
return defer.succeed(self.destination_retry_cache[destination])
return self.runInteraction( return self.runInteraction(
"get_destination_retry_timings", "get_destination_retry_timings",
self._get_destination_retry_timings, destination) self._get_destination_retry_timings, destination)
@ -250,7 +242,9 @@ class TransactionStore(SQLBaseStore):
retry_interval (int) - how long until next retry in ms retry_interval (int) - how long until next retry in ms
""" """
self.destination_retry_cache[destination] = ( # As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill(
destination,
DestinationsTable.EntryType( DestinationsTable.EntryType(
destination, destination,
retry_last_ts, retry_last_ts,

View File

@ -88,11 +88,15 @@ class LruCache(object):
else: else:
return default return default
def cache_len():
return len(cache)
self.sentinel = object() self.sentinel = object()
self.get = cache_get self.get = cache_get
self.set = cache_set self.set = cache_set
self.setdefault = cache_set_default self.setdefault = cache_set_default
self.pop = cache_pop self.pop = cache_pop
self.len = cache_len
def __getitem__(self, key): def __getitem__(self, key):
result = self.get(key, self.sentinel) result = self.get(key, self.sentinel)
@ -108,3 +112,6 @@ class LruCache(object):
result = self.pop(key, self.sentinel) result = self.pop(key, self.sentinel)
if result is self.sentinel: if result is self.sentinel:
raise KeyError() raise KeyError()
def __len__(self):
return self.len()

View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
import collections
import contextlib
import logging
logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
"""
Args:
clock (Clock)
window_size (int): The window size in milliseconds.
sleep_limit (int): The number of requests received in the last
`window_size` milliseconds before we artificially start
delaying processing of requests.
sleep_msec (int): The number of milliseconds to delay processing
of incoming requests by.
reject_limit (int): The maximum number of requests that are can be
queued for processing before we start rejecting requests with
a 429 Too Many Requests response.
concurrent_requests (int): The number of concurrent requests to
process.
"""
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.ratelimiters = {}
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
Example usage:
with rate_limiter.ratelimit(origin) as wait_deferred:
yield wait_deferred
# Handle request ...
Args:
host (str): Origin of incoming request.
Returns:
_PerHostRatelimiter
"""
return self.ratelimiters.setdefault(
host,
_PerHostRatelimiter(
clock=self.clock,
window_size=self.window_size,
sleep_limit=self.sleep_limit,
sleep_msec=self.sleep_msec,
reject_limit=self.reject_limit,
concurrent_requests=self.concurrent_requests,
)
).ratelimit()
class _PerHostRatelimiter(object):
def __init__(self, clock, window_size, sleep_limit, sleep_msec,
reject_limit, concurrent_requests):
self.clock = clock
self.window_size = window_size
self.sleep_limit = sleep_limit
self.sleep_msec = sleep_msec
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
self.sleeping_requests = set()
self.ready_request_queue = collections.OrderedDict()
self.current_processing = set()
self.request_times = []
def is_empty(self):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
return not (
self.ready_request_queue
or self.sleeping_requests
or self.current_processing
or self.request_times
)
@contextlib.contextmanager
def ratelimit(self):
# `contextlib.contextmanager` takes a generator and turns it into a
# context manager. The generator should only yield once with a value
# to be returned by manager.
# Exceptions will be reraised at the yield.
request_id = object()
ret = self._on_enter(request_id)
try:
yield ret
finally:
self._on_exit(request_id)
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
retry_after_ms=int(
self.window_size / self.sleep_limit
),
)
self.request_times.append(time_now)
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
return queue_defer
else:
return defer.succeed(None)
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
id(request_id), len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
logger.debug(
"Ratelimit [%s]: sleeping req",
id(request_id),
)
ret_defer = sleep(self.sleep_msec/1000.0)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
logger.debug(
"Ratelimit [%s]: Finished sleeping",
id(request_id),
)
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
ret_defer.addBoth(on_wait_finished)
else:
ret_defer = queue_request()
def on_start(r):
logger.debug(
"Ratelimit [%s]: Processing req",
id(request_id),
)
self.current_processing.add(request_id)
return r
def on_err(r):
self.current_processing.discard(request_id)
return r
def on_both(r):
# Ensure that we've properly cleaned up.
self.sleeping_requests.discard(request_id)
self.ready_request_queue.pop(request_id, None)
return r
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
return ret_defer
def _on_exit(self, request_id):
logger.debug(
"Ratelimit [%s]: Processed req",
id(request_id),
)
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
self.current_processing.add(request_id)
deferred.callback(None)
except KeyError:
pass

View File

@ -18,6 +18,13 @@ from mock import Mock, PropertyMock
from tests import unittest from tests import unittest
def _regex(regex, exclusive=True):
return {
"regex": regex,
"exclusive": exclusive
}
class ApplicationServiceTestCase(unittest.TestCase): class ApplicationServiceTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
@ -36,21 +43,21 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_user_id_prefix_match(self): def test_regex_user_id_prefix_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue(self.service.is_interested(self.event))
def test_regex_user_id_prefix_no_match(self): def test_regex_user_id_prefix_no_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse(self.service.is_interested(self.event))
def test_regex_room_member_is_checked(self): def test_regex_room_member_is_checked(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@someone_else:matrix.org" self.event.sender = "@someone_else:matrix.org"
self.event.type = "m.room.member" self.event.type = "m.room.member"
@ -59,30 +66,78 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_room_id_match(self): def test_regex_room_id_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org" _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org"
self.assertTrue(self.service.is_interested(self.event)) self.assertTrue(self.service.is_interested(self.event))
def test_regex_room_id_no_match(self): def test_regex_room_id_no_match(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!some_prefix.*some_suffix:matrix.org" _regex("!some_prefix.*some_suffix:matrix.org")
) )
self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org"
self.assertFalse(self.service.is_interested(self.event)) self.assertFalse(self.service.is_interested(self.event))
def test_regex_alias_match(self): def test_regex_alias_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.assertTrue(self.service.is_interested( self.assertTrue(self.service.is_interested(
self.event, self.event,
aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"] aliases_for_event=["#irc_foobar:matrix.org", "#athing:matrix.org"]
)) ))
def test_non_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_alias(
"#irc_foobar:matrix.org"
))
def test_non_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_room(
"!irc_foobar:matrix.org"
))
def test_non_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=False)
)
self.assertFalse(self.service.is_exclusive_user(
"@irc_foobar:matrix.org"
))
def test_exclusive_alias(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_alias(
"#irc_foobar:matrix.org"
))
def test_exclusive_user(self):
self.service.namespaces[ApplicationService.NS_USERS].append(
_regex("@irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_user(
"@irc_foobar:matrix.org"
))
def test_exclusive_room(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append(
_regex("!irc_.*:matrix.org", exclusive=True)
)
self.assertTrue(self.service.is_exclusive_room(
"!irc_foobar:matrix.org"
))
def test_regex_alias_no_match(self): def test_regex_alias_no_match(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
self.event, self.event,
@ -91,10 +146,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_regex_multiple_matches(self): def test_regex_multiple_matches(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#irc_.*:matrix.org" _regex("#irc_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertTrue(self.service.is_interested( self.assertTrue(self.service.is_interested(
@ -104,10 +159,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_rooms(self): def test_restrict_to_rooms(self):
self.service.namespaces[ApplicationService.NS_ROOMS].append( self.service.namespaces[ApplicationService.NS_ROOMS].append(
"!flibble_.*:matrix.org" _regex("!flibble_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.event.room_id = "!wibblewoo:matrix.org" self.event.room_id = "!wibblewoo:matrix.org"
@ -118,10 +173,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_aliases(self): def test_restrict_to_aliases(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org" _regex("#xmpp_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
@ -132,10 +187,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_restrict_to_senders(self): def test_restrict_to_senders(self):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
"#xmpp_.*:matrix.org" _regex("#xmpp_.*:matrix.org")
) )
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertFalse(self.service.is_interested( self.assertFalse(self.service.is_interested(
@ -146,7 +201,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self): def test_member_list_match(self):
self.service.namespaces[ApplicationService.NS_USERS].append( self.service.namespaces[ApplicationService.NS_USERS].append(
"@irc_.*" _regex("@irc_.*")
) )
join_list = [ join_list = [
Mock( Mock(

View File

@ -389,14 +389,18 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_remote(self): def test_invite_remote(self):
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:there")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("there",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_invite", data=_expect_edu("there", "m.presence_invite",
content={ content={
"observer_user": "@apple:test", "observer_user": "@apple:test",
"observed_user": "@cabbage:elsewhere", "observed_user": "@rocket:there",
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
@ -405,10 +409,10 @@ class PresenceInvitesTestCase(PresenceTestCase):
) )
yield self.handler.send_invite( yield self.handler.send_invite(
observer_user=self.u_apple, observed_user=self.u_cabbage) observer_user=self.u_apple, observed_user=u_rocket)
self.assertEquals( self.assertEquals(
[{"observed_user_id": "@cabbage:elsewhere", "accepted": 0}], [{"observed_user_id": "@rocket:there", "accepted": 0}],
(yield self.datastore.get_presence_list(self.u_apple.localpart)) (yield self.datastore.get_presence_list(self.u_apple.localpart))
) )
@ -418,13 +422,18 @@ class PresenceInvitesTestCase(PresenceTestCase):
def test_accept_remote(self): def test_accept_remote(self):
# TODO(paul): This test will likely break if/when real auth permissions # TODO(paul): This test will likely break if/when real auth permissions
# are added; for now the HS will always accept any invite # are added; for now the HS will always accept any invite
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:moon")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("moon",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_accept", data=_expect_edu("moon", "m.presence_accept",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:moon",
"observed_user": "@apple:test", "observed_user": "@apple:test",
} }
), ),
@ -437,7 +446,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_invite", _make_edu_json("elsewhere", "m.presence_invite",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:moon",
"observed_user": "@apple:test", "observed_user": "@apple:test",
} }
) )
@ -446,7 +455,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
self.assertTrue( self.assertTrue(
(yield self.datastore.is_presence_visible( (yield self.datastore.is_presence_visible(
observed_localpart=self.u_apple.localpart, observed_localpart=self.u_apple.localpart,
observer_userid=self.u_cabbage.to_string(), observer_userid=u_rocket.to_string(),
)) ))
) )
@ -454,13 +463,17 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invited_remote_nonexistant(self): def test_invited_remote_nonexistant(self):
# Use a different destination, otherwise retry logic might fail the
# request
u_rocket = UserID.from_string("@rocket:sun")
put_json = self.mock_http_client.put_json put_json = self.mock_http_client.put_json
put_json.expect_call_and_return( put_json.expect_call_and_return(
call("elsewhere", call("sun",
path="/_matrix/federation/v1/send/1000000/", path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu("elsewhere", "m.presence_deny", data=_expect_edu("sun", "m.presence_deny",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:sun",
"observed_user": "@durian:test", "observed_user": "@durian:test",
} }
), ),
@ -471,9 +484,9 @@ class PresenceInvitesTestCase(PresenceTestCase):
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("elsewhere", "m.presence_invite", _make_edu_json("sun", "m.presence_invite",
content={ content={
"observer_user": "@cabbage:elsewhere", "observer_user": "@rocket:sun",
"observed_user": "@durian:test", "observed_user": "@durian:test",
} }
) )

View File

@ -128,6 +128,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
hs.config.enable_registration_captcha = False hs.config.enable_registration_captcha = False
hs.config.disable_registration = False
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()

View File

@ -295,6 +295,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
self.mock_datastore = hs.get_datastore() self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None) self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
self.mock_datastore.get_app_service_by_user_id = Mock(
return_value=defer.succeed(None)
)
def get_profile_displayname(user_id): def get_profile_displayname(user_id):
return defer.succeed("Frank") return defer.succeed("Frank")

110
tests/storage/test__base.py Normal file
View File

@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.storage._base import cached
class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
@cached()
def func(self, key):
return key
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield func(self, "bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
self.assertEquals(callcount[0], 1)
func.invalidate("foo")
yield func(self, "foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
@cached()
def func(self, key):
return key
func.invalidate("what")
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
for k in range(0,12):
yield func(self, k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
yield func(self, k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
func.prefill("foo", 123)
self.assertEquals((yield func(self, "foo")), 123)
self.assertEquals(callcount[0], 0)

View File

@ -50,9 +50,15 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
def test_update_and_retrieval_of_service(self): def test_update_and_retrieval_of_service(self):
url = "https://matrix.org/appservices/foobar" url = "https://matrix.org/appservices/foobar"
hs_token = "hstok" hs_token = "hstok"
user_regex = ["@foobar_.*:matrix.org"] user_regex = [
alias_regex = ["#foobar_.*:matrix.org"] {"regex": "@foobar_.*:matrix.org", "exclusive": True}
room_regex = [] ]
alias_regex = [
{"regex": "#foobar_.*:matrix.org", "exclusive": False}
]
room_regex = [
]
service = ApplicationService( service = ApplicationService(
url=url, hs_token=hs_token, token=self.as_token, namespaces={ url=url, hs_token=hs_token, token=self.as_token, namespaces={
ApplicationService.NS_USERS: user_regex, ApplicationService.NS_USERS: user_regex,

View File

@ -42,6 +42,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config = Mock() config = Mock()
config.signing_key = [MockKey()] config.signing_key = [MockKey()]
config.event_cache_size = 1 config.event_cache_size = 1
config.disable_registration = False
if datastore is None: if datastore is None:
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()