Merge branch 'develop' of github.com:matrix-org/synapse into erikj/unfederatable

This commit is contained in:
Erik Johnston 2015-10-02 10:33:49 +01:00
commit d5e081c7ae
62 changed files with 1425 additions and 949 deletions

View File

@ -1,3 +1,17 @@
Changes in synapse v0.10.0-r2 (2015-09-16)
==========================================
* Fix bug where we always fetched remote server signing keys instead of using
ones in our cache.
* Fix adding threepids to an existing account.
* Fix bug with invinting over federation where remote server was already in
the room. (PR #281, SYN-392)
Changes in synapse v0.10.0-r1 (2015-09-08)
==========================================
* Fix bug with python packaging
Changes in synapse v0.10.0 (2015-09-03) Changes in synapse v0.10.0 (2015-09-03)
======================================= =======================================

View File

@ -25,6 +25,7 @@ for port in 8080 8081 8082; do
--generate-config \ --generate-config \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
--report-stats no
# Check script parameters # Check script parameters
if [ $# -eq 1 ]; then if [ $# -eq 1 ]; then

142
scripts-dev/definitions.py Executable file
View File

@ -0,0 +1,142 @@
#! /usr/bin/python
import ast
import yaml
class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
super(DefinitionVisitor, self).__init__()
self.functions = {}
self.classes = {}
self.names = {}
self.attrs = set()
self.definitions = {
'def': self.functions,
'class': self.classes,
'names': self.names,
'attrs': self.attrs,
}
def visit_Name(self, node):
self.names.setdefault(type(node.ctx).__name__, set()).add(node.id)
def visit_Attribute(self, node):
self.attrs.add(node.attr)
for child in ast.iter_child_nodes(node):
self.visit(child)
def visit_ClassDef(self, node):
visitor = DefinitionVisitor()
self.classes[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def visit_FunctionDef(self, node):
visitor = DefinitionVisitor()
self.functions[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {}
if functions: result['def'] = functions
if classes: result['class'] = classes
names = defs['names']
uses = []
for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name)
uses.extend(defs['attrs'])
if uses: result['uses'] = uses
result['names'] = names
result['attrs'] = defs['attrs']
return result
def definitions_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = DefinitionVisitor()
visitor.visit(input_ast)
definitions = non_empty(visitor.definitions)
return definitions
def definitions_in_file(filepath):
with open(filepath) as f:
return definitions_in_code(f.read())
def defined_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
def used_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
used_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", funcs, names)
for used in defs.get('uses', ()):
if used in names:
names[used].setdefault('used', []).append(prefix.rstrip('.'))
if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument(
"--unused", action="store_true", help="Only list unused definitions"
)
parser.add_argument(
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
)
parser.add_argument(
"--pattern", action="append", metavar="REGEXP",
help="Search for a pattern"
)
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
definitions = {}
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
definitions[filepath] = definitions_in_file(filepath)
names = {}
for filepath, defs in definitions.items():
defined_names(filepath + ":", defs, names)
for filepath, defs in definitions.items():
used_names(filepath + ":", defs, names)
patterns = [re.compile(pattern) for pattern in args.pattern or ()]
ignore = [re.compile(pattern) for pattern in args.ignore or ()]
result = {}
for name, definition in names.items():
if patterns and not any(pattern.match(name) for pattern in patterns):
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
if args.unused and definition.get('used'):
continue
result[name] = definition
yaml.dump(result, sys.stdout, default_flow_style=False)

View File

@ -95,8 +95,6 @@ class Store(object):
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
def r(conn): def r(conn):
try: try:

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.10.0" __version__ = "0.10.0-r2"

View File

@ -23,6 +23,7 @@ from synapse.util.logutils import log_function
from synapse.types import RoomID, UserID, EventID from synapse.types import RoomID, UserID, EventID
import logging import logging
import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,6 +41,12 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"type = ",
"time < ",
"user_id = ",
])
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -121,6 +128,20 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user is not in the room.
Returns:
A deferred membership event for the user if the user is in
the room.
"""
if current_state: if current_state:
member = current_state.get( member = current_state.get(
(EventTypes.Member, user_id), (EventTypes.Member, user_id),
@ -136,6 +157,43 @@ class Auth(object):
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None):
"""Check if the user was in the room at some point.
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user was never in the room.
Returns:
A deferred membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
@ -390,7 +448,7 @@ class Auth(object):
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -417,7 +475,7 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token): def _get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -427,6 +485,86 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try:
ret = yield self._get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon)
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
def _validate_macaroon(self, macaroon):
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
raise AuthError( raise AuthError(
@ -437,7 +575,6 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -27,16 +27,6 @@ class Membership(object):
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class Feedback(object):
"""Represents the types of feedback a user can send in response to a
message."""
DELIVERED = u"delivered"
READ = u"read"
LIST = (DELIVERED, READ)
class PresenceState(object): class PresenceState(object):
"""Represents the presence state of a user.""" """Represents the presence state of a user."""
OFFLINE = u"offline" OFFLINE = u"offline"
@ -73,7 +63,6 @@ class EventTypes(object):
PowerLevels = "m.room.power_levels" PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases" Aliases = "m.room.aliases"
Redaction = "m.room.redaction" Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility" RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias" CanonicalAlias = "m.room.canonical_alias"

View File

@ -77,11 +77,6 @@ class SynapseError(CodeMessageException):
) )
class RoomError(SynapseError):
"""An error raised when a room event fails."""
pass
class RegistrationError(SynapseError): class RegistrationError(SynapseError):
"""An error raised when a registration event fails.""" """An error raised when a registration event fails."""
pass pass

View File

@ -16,10 +16,23 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
)
if __name__ == '__main__': if __name__ == '__main__':
check_requirements() try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import ( from synapse.storage import (
@ -29,7 +42,7 @@ from synapse.storage import (
from synapse.server import HomeServer from synapse.server import HomeServer
from twisted.internet import reactor from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
@ -72,12 +85,6 @@ import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r): def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -121,6 +128,7 @@ class SynapseHomeServer(HomeServer):
# (It can stay enabled for the API resources: they call # (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight # write() with the whole body and then finish() straight
# after and so do not trigger the bug. # after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable? # return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
@ -221,7 +229,7 @@ class SynapseHomeServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
self.tls_context_factory, self.tls_server_context_factory,
interface=bind_address interface=bind_address
) )
else: else:
@ -365,7 +373,6 @@ def setup(config_options):
Args: Args:
config_options_options: The options passed to Synapse. Usually config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`. `sys.argv[1:]`.
should_run (bool): Whether to start the reactor.
Returns: Returns:
HomeServer HomeServer
@ -388,7 +395,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@ -396,7 +403,7 @@ def setup(config_options):
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_server_context_factory=tls_server_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string, version_string=version_string,
@ -665,6 +672,42 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker) ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run) reactor.run = profile(reactor.run)
start_time = hs.get_clock().time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(hs.get_clock().time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)

View File

@ -16,57 +16,67 @@
import sys import sys
import os import os
import os.path
import subprocess import subprocess
import signal import signal
import yaml import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
CONFIG = yaml.load(open(CONFIGFILE)) def start(configfile):
PIDFILE = CONFIG["pid_file"]
def start():
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE]) args.extend(["--daemonize", "-c", configfile])
subprocess.check_call(args)
print GREEN + "started" + NORMAL try:
subprocess.check_call(args)
print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
RED +
"error starting (exit code: %d); see above for logs" % e.returncode +
NORMAL
)
def stop(): def stop(pidfile):
if os.path.exists(PIDFILE): if os.path.exists(pidfile):
pid = int(open(PIDFILE).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL print GREEN + "stopped" + NORMAL
def main(): def main():
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
if not os.path.exists(configfile):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), configfile
)
)
sys.exit(1)
config = yaml.load(open(configfile))
pidfile = config["pid_file"]
action = sys.argv[1] if sys.argv[1:] else "usage" action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start": if action == "start":
start() start(configfile)
elif action == "stop": elif action == "stop":
stop() stop(pidfile)
elif action == "restart": elif action == "restart":
stop() stop(pidfile)
start() start(configfile)
else: else:
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],)) sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
sys.exit(1) sys.exit(1)

View File

@ -26,6 +26,16 @@ class ConfigError(Exception):
class Config(object): class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
@ -111,11 +121,14 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name): def generate_config(self, config_dir_path, server_name, report_stats=None):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name "default_config",
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=report_stats,
)) ))
config = yaml.load(default_config) config = yaml.load(default_config)
@ -139,6 +152,12 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--report-stats",
action="store",
help="Stuff",
choices=["yes", "no"]
)
config_parser.add_argument( config_parser.add_argument(
"--generate-keys", "--generate-keys",
action="store_true", action="store_true",
@ -189,6 +208,11 @@ class Config(object):
config_files.append(config_path) config_files.append(config_path)
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel
)
if not config_files: if not config_files:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
@ -211,7 +235,9 @@ class Config(object):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(
config_dir_path, server_name config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -261,9 +287,20 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name) _, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
)
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage "
"statistics, by setting the report_stats key in your config file "
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1)
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)

View File

@ -20,7 +20,7 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, config_dir_path, server_name): def default_config(cls, **kwargs):
return """\ return """\
# A list of application service config file to use # A list of application service config file to use
app_service_config_files: [] app_service_config_files: []

View File

@ -24,7 +24,7 @@ class CaptchaConfig(Config):
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Captcha ## ## Captcha ##

View File

@ -45,7 +45,7 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path")) self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path): def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db") database_path = self.abspath("homeserver.db")
return """\ return """\
# Database configuration # Database configuration

View File

@ -40,7 +40,7 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
return """\ return """\
## Signing Keys ## ## Signing Keys ##

View File

@ -21,6 +21,7 @@ import logging.config
import yaml import yaml
from string import Template from string import Template
import os import os
import signal
DEFAULT_LOG_CONFIG = Template(""" DEFAULT_LOG_CONFIG = Template("""
@ -69,7 +70,7 @@ class LoggingConfig(Config):
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")
log_config = self.abspath( log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")
@ -142,6 +143,19 @@ class LoggingConfig(Config):
handler = logging.handlers.RotatingFileHandler( handler = logging.handlers.RotatingFileHandler(
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
) )
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else: else:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)

View File

@ -19,13 +19,15 @@ from ._base import Config
class MetricsConfig(Config): class MetricsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.enable_metrics = config["enable_metrics"] self.enable_metrics = config["enable_metrics"]
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port") self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name): def default_config(self, report_stats=None, **kwargs):
return """\ suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
## Metrics ### ## Metrics ###
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
""" """ + suffix) % locals()

View File

@ -27,7 +27,7 @@ class RatelimitConfig(Config):
self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"] self.federation_rc_concurrent = config["federation_rc_concurrent"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##

View File

@ -34,7 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\

View File

@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config):
config["thumbnail_sizes"] config["thumbnail_sizes"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads") uploads_path = self.default_path("uploads")
return """ return """

View File

@ -41,7 +41,7 @@ class SAML2Config(Config):
self.saml2_config_path = None self.saml2_config_path = None
self.saml2_idp_redirect_url = None self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable SAML2 for registration and login. Uses pysaml2 # Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file # config_path: Path to the sp_conf.py configuration file

View File

@ -117,7 +117,7 @@ class ServerConfig(Config):
self.content_addr = content_addr self.content_addr = content_addr
def default_config(self, config_dir_path, server_name): def default_config(self, server_name, **kwargs):
if ":" in server_name: if ":" in server_name:
bind_port = int(server_name.split(":")[1]) bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400

View File

@ -42,7 +42,15 @@ class TlsConfig(Config):
config.get("tls_dh_params_path"), "tls_dh_params" config.get("tls_dh_params_path"), "tls_dh_params"
) )
def default_config(self, config_dir_path, server_name): # This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
# use only when running tests.
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt" tls_certificate_path = base_key_name + ".tls.crt"

View File

@ -22,7 +22,7 @@ class VoipConfig(Config):
self.turn_shared_secret = config["turn_shared_secret"] self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Turn ## ## Turn ##

View File

@ -228,10 +228,9 @@ class Keyring(object):
def do_iterations(): def do_iterations():
merged_results = {} merged_results = {}
missing_keys = { missing_keys = {}
group.server_name: set(group.key_ids) for group in group_id_to_group.values():
for group in group_id_to_group.values() missing_keys.setdefault(group.server_name, set()).union(group.key_ids)
}
for fn in key_fetch_fns: for fn in key_fetch_fns:
results = yield fn(missing_keys.items()) results = yield fn(missing_keys.items())
@ -470,7 +469,7 @@ class Keyring(object):
continue continue
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory, server_name, self.hs.tls_server_context_factory,
path=(b"/_matrix/key/v2/server/%s" % ( path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id), urllib.quote(requested_key_id),
)).encode("ascii"), )).encode("ascii"),
@ -604,7 +603,7 @@ class Keyring(object):
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory server_name, self.hs.tls_server_context_factory
) )
# Check the response. # Check the response.

View File

@ -19,7 +19,6 @@ from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -187,7 +186,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
try: try:
client = SimpleHttpClient(self.hs) client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json( resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api, self.hs.config.recaptcha_siteverify_api,
args={ args={

View File

@ -125,60 +125,72 @@ class FederationHandler(BaseHandler):
) )
if not is_in_room and not event.internal_metadata.is_outlier(): if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
current_state = state
event_ids = set() try:
if state: event_stream_id, max_stream_id = yield self._persist_auth_tree(
event_ids |= {e.event_id for e in state} auth_chain, state, event
if auth_chain: )
event_ids |= {e.event_id for e in auth_chain} except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
seen_ids = set( else:
(yield self.store.have_events(event_ids)).keys() event_ids = set()
) if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
if state and auth_chain is not None: seen_ids = set(
# If we have any state or auth_chain given to us by the replication (yield self.store.have_events(event_ids)).keys()
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
) )
try: if state and auth_chain is not None:
_, event_stream_id, max_stream_id = yield self._handle_new_event( # If we have any state or auth_chain given to us by the replication
origin, # layer, then we should handle them (if we haven't before.)
event,
state=state, event_infos = []
backfilled=backfilled,
current_state=current_state, for e in itertools.chain(auth_chain, state):
) if e.event_id in seen_ids:
except AuthError as e: continue
raise FederationError( e.internal_metadata.outlier = True
"ERROR", auth_ids = [e_id for e_id, _ in e.auth_events]
e.code, auth = {
e.msg, (e.type, e.state_key): e for e in auth_chain
affected=event.event_id, if e.event_id in auth_ids
) }
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
event,
state=state,
backfilled=backfilled,
current_state=current_state,
)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
# if we're receiving valid events from an origin, # if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state # it's probably a good idea to mark it as not in retry-state
@ -649,35 +661,8 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
ev_infos = [] event_stream_id, max_stream_id = yield self._persist_auth_tree(
for e in itertools.chain(state, auth_chain): auth_chain, state, event
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
origin,
new_event,
state=state,
current_state=state,
auth_events=auth_events,
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -1026,6 +1011,76 @@ class FederationHandler(BaseHandler):
is_new_state=(not outliers and not backfilled), is_new_state=(not outliers and not backfilled),
) )
@defer.inlineCallbacks
def _persist_auth_tree(self, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
event_map = {
e.event_id: e
for e in auth_events
}
create_event = None
for e in auth_events:
if (e.type, e.state_key) == (EventTypes.Create, ""):
create_event = e
break
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
try:
self.auth.check(e, auth_events=auth_for_e)
except AuthError as err:
logger.warn(
"Rejecting %s because %s",
e.event_id, err.msg
)
if e == event:
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
yield self.store.persist_events(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
backfilled=False,
is_new_state=True,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False, def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None): current_state=None, auth_events=None):
@ -1456,52 +1511,3 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View File

@ -16,13 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError, SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
from ._base import BaseHandler from ._base import BaseHandler
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
feedback=False, as_client_event=True): as_client_event=True):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -79,26 +79,52 @@ class MessageHandler(BaseHandler):
room_id (str): The room they want messages from. room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
feedback (bool): True to get compressed feedback with the messages
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if not pagin_config.from_token: if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token( yield self.hs.get_event_sources().get_current_token(
direction='b' direction='b'
) )
) )
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key) room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None: if room_token.topological is None:
raise SynapseError(400, "Invalid token") raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, room_token.topological
) )
@ -106,7 +132,7 @@ class MessageHandler(BaseHandler):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id user, source_config, room_id
) )
next_token = pagin_config.from_token.copy_and_replace( next_token = pagin_config.from_token.copy_and_replace(
@ -255,29 +281,26 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
have_joined = yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if not have_joined:
raise RoomError(403, "User not in room.") if member_event.membership == Membership.JOIN:
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
elif member_event.membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], [key]
)
data = room_state[member_event.event_id].get(key)
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def get_feedback(self, event_id):
# yield self.auth.check_joined_room(room_id, user_id)
# Pull out the feedback from the db
fb = yield self.store.get_feedback(event_id)
if fb:
defer.returnValue(fb)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
Args: Args:
user_id(str): The user requesting state events. user_id(str): The user requesting state events.
@ -285,18 +308,23 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], None
)
room_state = room_state[member_event.event_id]
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(
[serialize_event(c, now) for c in current_state.values()] [serialize_event(c, now) for c in room_state.values()]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True):
feedback=False, as_client_event=True):
"""Retrieve a snapshot of all rooms the user is invited or has joined. """Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is This snapshot may include messages for all rooms where the user is
@ -306,7 +334,6 @@ class MessageHandler(BaseHandler):
user_id (str): The ID of the user making the request. user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return. config used to determine how many messages *PER ROOM* to return.
feedback (bool): True to get feedback along with these messages.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
Returns: Returns:
A list of dicts with "room_id" and "membership" keys for all rooms A list of dicts with "room_id" and "membership" keys for all rooms
@ -316,7 +343,9 @@ class MessageHandler(BaseHandler):
""" """
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, user_id=user_id,
membership_list=[Membership.INVITE, Membership.JOIN] membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
) )
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -358,19 +387,32 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d) rooms_ret.append(d)
if event.membership != Membership.JOIN: if event.membership not in (Membership.JOIN, Membership.LEAVE):
return return
try: try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
event.room_id, [event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults( (messages, token), current_state = yield defer.gatherResults(
[ [
self.store.get_recent_events_for_room( self.store.get_recent_events_for_room(
event.room_id, event.room_id,
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=room_end_token,
),
self.state_handler.get_current_state(
event.room_id
), ),
deferred_room_state,
] ]
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -417,15 +459,85 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, def room_initial_sync(self, user_id, room_id, pagin_config=None):
feedback=False): """Capture the a snapshot of a room. If user is currently a member of
current_state = yield self.state.get_current_state( the room this will be what is currently in the room. If the user left
room_id=room_id, the room this will be what was in the room when they left.
Args:
user_id(str): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event
)
elif member_event.membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event
)
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event):
room_state = yield self.store.get_state_for_events(
member_event.room_id, [member_event.event_id], None
) )
yield self.auth.check_joined_room( room_state = room_state[member_event.event_id]
room_id, user_id,
current_state=current_state limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event):
current_state = yield self.state.get_current_state(
room_id=room_id,
) )
# TODO(paul): I wish I was called with user objects not user_id # TODO(paul): I wish I was called with user objects not user_id
@ -439,8 +551,6 @@ class MessageHandler(BaseHandler):
for x in current_state.values() for x in current_state.values()
] ]
member_event = current_state.get((EventTypes.Member, user_id,))
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None

View File

@ -25,7 +25,6 @@ from synapse.api.constants import (
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from collections import OrderedDict from collections import OrderedDict
import logging import logging
@ -39,7 +38,7 @@ class RoomCreationHandler(BaseHandler):
PRESETS_DICT = { PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: { RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE, "join_rules": JoinRules.INVITE,
"history_visibility": "invited", "history_visibility": "shared",
"original_invitees_have_ops": False, "original_invitees_have_ops": False,
}, },
RoomCreationPreset.PUBLIC_CHAT: { RoomCreationPreset.PUBLIC_CHAT: {
@ -159,6 +158,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=invite_list, invite_list=invite_list,
initial_state=initial_state, initial_state=initial_state,
creation_content=creation_content, creation_content=creation_content,
room_alias=room_alias,
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
@ -206,7 +206,8 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result) defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config, def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state, creation_content): invite_list, initial_state, creation_content,
room_alias):
config = RoomCreationHandler.PRESETS_DICT[preset_config] config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string() creator_id = creator.to_string()
@ -276,6 +277,14 @@ class RoomCreationHandler(BaseHandler):
returned_events.append(power_levels_event) returned_events.append(power_levels_event)
if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
room_alias_event = create(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
returned_events.append(room_alias_event)
if (EventTypes.JoinRules, '') not in initial_state: if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create( join_rules_event = create(
etype=EventTypes.JoinRules, etype=EventTypes.JoinRules,
@ -346,41 +355,6 @@ class RoomMemberHandler(BaseHandler):
if remotedomains is not None: if remotedomains is not None:
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks
def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
limit=0, start_tok=None,
end_tok=None):
"""Retrieve a list of room members in the room.
Args:
room_id (str): The room to get the member list for.
user_id (str): The ID of the user making the request.
limit (int): The max number of members to return.
start_tok (str): Optional. The start token if known.
end_tok (str): Optional. The end token if known.
Returns:
dict: A Pagination streamable dict.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [
serialize_event(entry, time_now)
for entry in member_list
]
chunk_data = {
"start": "START", # FIXME (erikj): START is no longer valid
"end": "END",
"chunk": event_list
}
# TODO honor Pagination stream params
# TODO snapshot this list to return on subsequent requests when
# paginating
defer.returnValue(chunk_data)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
@ -532,32 +506,6 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@defer.inlineCallbacks
def _should_invite_join(self, room_id, prev_state, do_auth):
logger.debug("_should_invite_join: room_id: %s", room_id)
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
# Only do an invite join dance if a) we were invited,
# b) the person inviting was from a differnt HS and c) we are
# not currently in the room
room_host = None
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.sender
)
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain
else:
is_remote_invite_join = False
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
@ -650,7 +598,6 @@ class RoomEventSource(object):
to_key=config.to_key, to_key=config.to_key,
direction=config.direction, direction=config.direction,
limit=config.limit, limit=config.limit,
with_feedback=True
) )
defer.returnValue((events, next_key)) defer.returnValue((events, next_key))

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
@ -19,7 +21,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor from twisted.internet import defer, reactor, ssl
from twisted.web.client import ( from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError, Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool, HTTPConnectionPool,
@ -59,7 +61,12 @@ class SimpleHttpClient(object):
# 'like a browser' # 'like a browser'
pool = HTTPConnectionPool(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10 pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool) self.agent = Agent(
reactor,
pool=pool,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory()
)
self.version_string = hs.version_string self.version_string = hs.version_string
def request(self, method, uri, *args, **kwargs): def request(self, method, uri, *args, **kwargs):
@ -252,3 +259,18 @@ def _print_ex(e):
_print_ex(ex) _print_ex(ex)
else: else:
logger.exception(e) logger.exception(e)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
Do not use this since it allows an attacker to intercept your communications.
"""
def __init__(self):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port):
return self._context

View File

@ -57,14 +57,14 @@ incoming_responses_counter = metrics.register_counter(
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory self.tls_server_context_factory = hs.tls_server_context_factory
def endpointForURI(self, uri): def endpointForURI(self, uri):
destination = uri.netloc destination = uri.netloc
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory ssl_context_factory=self.tls_server_context_factory
) )

View File

@ -18,18 +18,18 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"], "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],
"pynacl>=0.3.0": ["nacl>=0.3.0"],
"daemonize": ["daemonize"], "daemonize": ["daemonize"],
"py-bcrypt": ["bcrypt"], "py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
@ -60,7 +60,10 @@ DEPENDENCY_LINKS = {
class MissingRequirementError(Exception): class MissingRequirementError(Exception):
pass def __init__(self, message, module_name, dependency):
super(MissingRequirementError, self).__init__(message)
self.module_name = module_name
self.dependency = dependency
def check_requirements(config=None): def check_requirements(config=None):
@ -88,7 +91,7 @@ def check_requirements(config=None):
) )
raise MissingRequirementError( raise MissingRequirementError(
"Can't import %r which is part of %r" "Can't import %r which is part of %r"
% (module_name, dependency) % (module_name, dependency), module_name, dependency
) )
version = getattr(module, "__version__", None) version = getattr(module, "__version__", None)
file_path = getattr(module, "__file__", None) file_path = getattr(module, "__file__", None)
@ -101,23 +104,25 @@ def check_requirements(config=None):
if version is None: if version is None:
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r" "Version of %r isn't set as __version__ of module %r"
% (dependency, module_name) % (dependency, module_name), module_name, dependency
) )
if LooseVersion(version) < LooseVersion(required_version): if LooseVersion(version) < LooseVersion(required_version):
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r in %r is too old. %r < %r" "Version of %r in %r is too old. %r < %r"
% (dependency, file_path, version, required_version) % (dependency, file_path, version, required_version),
module_name, dependency
) )
elif version_test == "==": elif version_test == "==":
if version is None: if version is None:
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r" "Version of %r isn't set as __version__ of module %r"
% (dependency, module_name) % (dependency, module_name), module_name, dependency
) )
if LooseVersion(version) != LooseVersion(required_version): if LooseVersion(version) != LooseVersion(required_version):
raise MissingRequirementError( raise MissingRequirementError(
"Unexpected version of %r in %r. %r != %r" "Unexpected version of %r in %r. %r != %r"
% (dependency, file_path, version, required_version) % (dependency, file_path, version, required_version),
module_name, dependency
) )

View File

@ -26,14 +26,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
content = yield handler.snapshot_all_rooms( content = yield handler.snapshot_all_rooms(
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event as_client_event=as_client_event
) )

View File

@ -290,12 +290,18 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler handler = self.handlers.message_handler
members = yield handler.get_room_members_as_pagination_chunk( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string()) user_id=user.to_string(),
)
for event in members["chunk"]: chunk = []
for event in events:
if event["type"] != EventTypes.Member:
continue
chunk.append(event)
# FIXME: should probably be state_key here, not user_id # FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"]) target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it # Presence is an optional cache; don't fail if we can't fetch it
@ -308,7 +314,9 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
except: except:
pass pass
defer.returnValue((200, members)) defer.returnValue((200, {
"chunk": chunk
}))
# TODO: Needs unit testing # TODO: Needs unit testing
@ -321,14 +329,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event as_client_event=as_client_event
) )

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_pattern
@ -41,6 +42,9 @@ class ReceiptRestServlet(RestServlet):
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
yield self.receipts_handler.received_client_receipt( yield self.receipts_handler.received_client_receipt(
room_id, room_id,
receipt_type, receipt_type,

View File

@ -19,7 +19,9 @@
# partial one for unit test mocking. # partial one for unit test mocking.
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -87,6 +89,8 @@ class BaseHomeServer(object):
'pusherpool', 'pusherpool',
'event_builder_factory', 'event_builder_factory',
'filtering', 'filtering',
'http_client_context_factory',
'simple_http_client',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -174,6 +178,17 @@ class HomeServer(BaseHomeServer):
def build_auth(self): def build_auth(self):
return Auth(self) return Auth(self)
def build_http_client_context_factory(self):
config = self.get_config()
return (
InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
def build_simple_http_client(self):
return SimpleHttpClient(self)
def build_v1auth(self): def build_v1auth(self):
orf = Auth(self) orf = Auth(self)
# Matrix spec makes no reference to what HTTP status code is returned, # Matrix spec makes no reference to what HTTP status code is returned,

View File

@ -17,7 +17,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
@ -32,10 +31,6 @@ import hashlib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_state_key_from_event(event):
return event.state_key
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
@ -119,8 +114,6 @@ class StateHandler(object):
Returns: Returns:
an EventContext an EventContext
""" """
yield run_on_reactor()
context = EventContext() context = EventContext()
if outlier: if outlier:

View File

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 23 SCHEMA_VERSION = 24
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -126,6 +126,27 @@ class DataStore(RoomMemberStore, RoomStore,
lock=False, lock=False,
) )
@defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
txn.execute(
"SELECT COUNT(DISTINCT user_id) AS users"
" FROM user_ips"
" WHERE last_seen > ?",
# This is close enough to a day for our purposes.
(int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
)
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
return self._simple_select_list( return self._simple_select_list(
table="user_ips", table="user_ips",

View File

@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
import sys import sys
import time import time
import threading import threading
@ -376,9 +374,6 @@ class SQLBaseStore(object):
return self.runInteraction(desc, interaction) return self.runInteraction(desc, interaction)
def _execute_and_decode(self, desc, query, *args):
return self._execute(desc, self.cursor_to_dict, query, *args)
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
@ -691,37 +686,6 @@ class SQLBaseStore(object):
return dict(zip(retcols, row)) return dict(zip(retcols, row))
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False,
desc="_simple_selectupdate_one"):
""" Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
ret = self._simple_select_one_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=retcols,
allow_none=allow_none,
)
if updatevalues:
self._simple_update_one_txn(
txn,
table=table,
keyvalues=keyvalues,
updatevalues=updatevalues,
)
# if txn.rowcount == 0:
# raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
return self.runInteraction(desc, func)
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -743,16 +707,6 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func) return self.runInteraction(desc, func)
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues): def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
@ -761,24 +715,6 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
max value for the column "id".
Args:
table : string giving the table name
"""
sql = "SELECT MAX(id) AS id FROM %s" % table
def func(txn):
txn.execute(sql)
max_id = self.cursor_to_dict(txn)[0]["id"]
if max_id is None:
return 0
return max_id
return self.runInteraction("_simple_max_id", func)
def get_next_stream_id(self): def get_next_stream_id(self):
with self._next_stream_id_lock: with self._next_stream_id_lock:
i = self._next_stream_id i = self._next_stream_id
@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception):
something went wrong. something went wrong.
""" """
pass pass
class Table(object):
""" A base class used to store information about a particular table.
"""
table_name = None
""" str: The name of the table """
fields = None
""" list: The field names """
EntryType = None
""" Type: A tuple type used to decode the results """
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
_insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
"""
Args:
where_clause (str): The WHERE clause to use.
Returns:
str: An SQL statement to select rows from the table with the given
WHERE clause.
"""
if where_clause:
return cls._select_where_clause % (
", ".join(cls.fields),
cls.table_name,
where_clause
)
else:
return cls._select_clause % (
", ".join(cls.fields),
cls.table_name,
)
@classmethod
def insert_statement(cls):
return cls._insert_clause % (
cls.table_name,
", ".join(cls.fields),
", ".join(["?"] * len(cls.fields)),
)
@classmethod
def decode_single_result(cls, results):
""" Given an iterable of tuples, return a single instance of
`EntryType` or None if the iterable is empty
Args:
results (list): The results list to convert to `EntryType`
Returns:
EntryType: An instance of `EntryType`
"""
results = list(results)
if results:
return cls.EntryType(*results[0])
else:
return None
@classmethod
def decode_results(cls, results):
""" Given an iterable of tuples, return a list of `EntryType`
Args:
results (list): The results list to convert to `EntryType`
Returns:
list: A list of `EntryType`
"""
return [cls.EntryType(*row) for row in results]
@classmethod
def get_fields_string(cls, prefix=None):
if prefix:
to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
else:
to_join = cls.fields
return ", ".join(to_join)
class JoinHelper(object):
""" Used to help do joins on tables by looking at the tables' fields and
creating a list of unique fields to use with SELECTs and a namedtuple
to dump the results into.
Attributes:
tables (list): List of `Table` classes
EntryType (type)
"""
def __init__(self, *tables):
self.tables = tables
res = []
for table in self.tables:
res += [f for f in table.fields if f not in res]
self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
statements with the given prefixes applied to each.
For example::
JoinHelper(PdusTable, StateTable).get_fields(
PdusTable="pdus",
StateTable="state"
)
"""
res = []
for field in self.EntryType._fields:
for table in self.tables:
if field in table.fields:
res.append("%s.%s" % (prefixes[table.__name__], field))
break
return ", ".join(res)
def decode_results(self, rows):
return [self.EntryType(*row) for row in rows]

View File

@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore):
return results return results
def _get_latest_state_in_room(self, txn, room_id, type, state_key):
event_ids = self._simple_select_onecol_txn(
txn,
table="state_forward_extremities",
keyvalues={
"room_id": room_id,
"type": type,
"state_key": state_key,
},
retcol="event_id",
)
results = []
for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((event_id, prev_hashes))
return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=True,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = bool(is_state)
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
hashes = self._get_prev_event_hashes_txn(txn, event_id)
results = []
for d in res:
edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
edge_hash.update(hashes.get(d["prev_event_id"], {}))
prev_hashes = {
k: encode_base64(v)
for k, v in edge_hash.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def _get_auth_events(self, txn, event_id):
auth_ids = self._simple_select_onecol_txn(
txn,
table="event_auth",
keyvalues={
"event_id": event_id,
},
retcol="auth_id",
)
results = []
for auth_id in auth_ids:
hashes = self._get_event_reference_hashes_txn(txn, auth_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((auth_id, prev_hashes))
return results
def get_min_depth(self, room_id): def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it. """ For hte given room, get the minimum depth we have seen for it.
""" """
@ -303,6 +211,15 @@ class EventFederationStore(SQLBaseStore):
], ],
) )
self._update_extremeties(txn, events)
def _update_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers.
"""
events_by_room = {} events_by_room = {}
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import math
import ujson as json import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -281,6 +281,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,) (False, event.event_id,)
) )
self._update_extremeties(txn, [event])
events_and_contexts = filter( events_and_contexts = filter(
lambda ec: ec[0] not in to_remove, lambda ec: ec[0] not in to_remove,
events_and_contexts events_and_contexts
@ -888,18 +890,69 @@ class EventsStore(SQLBaseStore):
return ev return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows): def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows] event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids) return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event): @defer.inlineCallbacks
sql = "SELECT event_id FROM redactions WHERE redacts = ?" def count_daily_messages(self):
txn.execute(sql, (event.event_id,)) """
result = txn.fetchone() Returns an estimate of the number of messages sent in the last day.
return result[0] if result else None
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
def _count_messages(txn):
now = self.hs.get_clock().time()
txn.execute(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
last_reported = self.cursor_to_dict(txn)
txn.execute(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
now_reporting = self.cursor_to_dict(txn)
if not now_reporting:
return None
now_reporting = now_reporting[0]["stream_ordering"]
txn.execute("DELETE FROM stats_reporting")
txn.execute(
"INSERT INTO stats_reporting"
" (reported_stream_token, reported_time)"
" VALUES (?, ?)",
(now_reporting, now,)
)
if not last_reported:
return None
# Close enough to correct for our purposes.
yesterday = (now - 24 * 60 * 60)
if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60:
return None
txn.execute(
"SELECT COUNT(*) as messages"
" FROM events NATURAL JOIN event_json"
" WHERE json like '%m.room.message%'"
" AND stream_ordering > ?"
" AND stream_ordering <= ?",
(
last_reported[0]["reported_stream_token"],
now_reporting,
)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
return rows[0]["messages"]
ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore):
) )
class PushersTable(Table): class PushersTable(object):
table_name = "pushers" table_name = "pushers"

View File

@ -289,3 +289,16 @@ class RegistrationStore(SQLBaseStore):
if ret: if ret:
defer.returnValue(ret['user_id']) defer.returnValue(ret['user_id'])
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
RoomsForUser = namedtuple( RoomsForUser = namedtuple(
"RoomsForUser", "RoomsForUser",
("room_id", "sender", "membership") ("room_id", "sender", "membership", "event_id", "stream_ordering")
) )
@ -141,11 +141,13 @@ class RoomMemberStore(SQLBaseStore):
args.extend(membership_list) args.extend(membership_list)
sql = ( sql = (
"SELECT m.room_id, m.sender, m.membership" "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM room_memberships as m" " FROM current_state_events as c"
" INNER JOIN current_state_events as c" " INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id " " ON m.event_id = c.event_id"
" AND m.room_id = c.room_id " " INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key" " AND m.user_id = c.state_key"
" WHERE %s" " WHERE %s"
) % (where_clause,) ) % (where_clause,)
@ -176,12 +178,6 @@ class RoomMemberStore(SQLBaseStore):
return joined_domains return joined_domains
def _get_members_query(self, where_clause, where_values):
return self.runInteraction(
"get_members_query", self._get_members_events_txn,
where_clause, where_values
).addCallbacks(self._get_events)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn( rows = self._get_members_rows_txn(
txn, txn,

View File

@ -0,0 +1,22 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Should only ever contain one row
CREATE TABLE IF NOT EXISTS stats_reporting(
-- The stream ordering token which was most recently reported as stats
reported_stream_token INTEGER,
-- The time (seconds since epoch) stats were most recently reported
reported_time BIGINT
);

View File

@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
def _get_event_content_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given Event.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_content_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a Event
Args:
txn (cursor):
event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(
txn,
"event_content_hashes",
{
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)
def get_event_reference_hashes(self, event_ids): def get_event_reference_hashes(self, event_ids):
def f(txn): def f(txn):
return [ return [
@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore):
table="event_reference_hashes", table="event_reference_hashes",
values=vals, values=vals,
) )
def _get_event_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of sig name -> dict(key_id -> signature_bytes)
"""
query = (
"SELECT signature_name, key_id, signature"
" FROM event_signatures"
" WHERE event_id = ? "
)
txn.execute(query, (event_id, ))
rows = txn.fetchall()
res = {}
for name, key, sig in rows:
res.setdefault(name, {})[key] = sig
return res
def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
origin (str): origin of the Event.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
self._simple_insert_txn(
txn,
"event_signatures",
{
"event_id": event_id,
"signature_name": signature_name,
"key_id": key_id,
"signature": buffer(signature_bytes),
},
)
def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_event_id, algorithm, hash"
" FROM event_edge_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
results = {}
for prev_event_id, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
algorithm, hash_bytes):
self._simple_insert_txn(
txn,
"event_edge_hashes",
{
"event_id": event_id,
"prev_event_id": prev_event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)

View File

@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import (
from twisted.internet import defer from twisted.internet import defer
from synapse.util.stringutils import random_string
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -428,7 +426,3 @@ class StateStore(SQLBaseStore):
} }
defer.returnValue(results) defer.returnValue(results)
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View File

@ -159,9 +159,7 @@ class StreamStore(SQLBaseStore):
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False): limit=0):
# TODO (erikj): Handle compressed feedback
current_room_membership_sql = ( current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c" " INNER JOIN current_state_events as c"
@ -227,10 +225,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, direction='b', limit=-1):
with_feedback=False):
# TODO (erikj): Handle compressed feedback
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
@ -302,7 +297,6 @@ class StreamStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=4) @cachedInlineCallbacks(num_args=4)
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
# TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)
@ -379,6 +373,38 @@ class StreamStore(SQLBaseStore):
) )
defer.returnValue("t%d-%d" % (topo, token)) defer.returnValue("t%d-%d" % (topo, token))
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "t%d-%d" topological token.
"""
return self._simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
def _get_max_topological_txn(self, txn): def _get_max_topological_txn(self, txn):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events"

View File

@ -34,6 +34,11 @@ class SourcePaginationConfig(object):
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None self.limit = int(limit) if limit is not None else None
def __repr__(self):
return (
"StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)"
) % (self.from_key, self.to_key, self.direction, self.limit)
class PaginationConfig(object): class PaginationConfig(object):
@ -94,10 +99,10 @@ class PaginationConfig(object):
logger.exception("Failed to create pagination config") logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.") raise SynapseError(400, "Invalid request.")
def __str__(self): def __repr__(self):
return ( return (
"<PaginationConfig from_tok=%s, to_tok=%s, " "PaginationConfig(from_tok=%r, to_tok=%r,"
"direction=%s, limit=%s>" " direction=%r, limit=%r)"
) % (self.from_token, self.to_token, self.direction, self.limit) ) % (self.from_token, self.to_token, self.direction, self.limit)
def get_source_config(self, source_name): def get_source_config(self, source_name):

View File

@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class EventSources(object): class EventSources(object):
SOURCE_TYPES = { SOURCE_TYPES = {
"room": RoomEventSource, "room": RoomEventSource,
@ -70,15 +54,3 @@ class EventSources(object):
), ),
) )
defer.returnValue(token) defer.returnValue(token)
class StreamSource(object):
def get_new_events_for_user(self, user, from_key, limit):
"""from_key is the key within this event source."""
raise NotImplementedError("get_new_events_for_user")
def get_current_key(self):
raise NotImplementedError("get_current_key")
def get_pagination_rows(self, user, pagination_config, key):
raise NotImplementedError("get_rows")

View File

@ -29,34 +29,6 @@ def unwrapFirstError(failure):
return failure.value.subFailure return failure.value.subFailure
def unwrap_deferred(d):
"""Given a deferred that we know has completed, return its value or raise
the failure as an exception
"""
if not d.called:
raise RuntimeError("deferred has not finished")
res = []
def f(r):
res.append(r)
return r
d.addCallback(f)
if res:
return res[0]
def f(r):
res.append(r)
return r
d.addErrback(f)
if res:
res[0].raiseException()
else:
raise RuntimeError("deferred did not call callbacks")
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.

View File

@ -19,17 +19,21 @@ from mock import Mock
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.types import UserID
from tests.utils import setup_test_homeserver
import pymacaroons
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.state_handler = Mock() self.state_handler = Mock()
self.store = Mock() self.store = Mock()
self.hs = Mock() self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store) self.hs.get_datastore = Mock(return_value=self.store)
self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock(
return_value={"name": "@percy:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("User mismatch", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_missing_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("No user caveat", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_wrong_key(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key + "wrong")
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_unknown_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("cunning > fox")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_expired(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds
yield self.auth._get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
# with self.assertRaises(AuthError) as cm:
# yield self.auth._get_user_from_macaroon(macaroon.serialize())
# self.assertEqual(401, cm.exception.code)
# self.assertIn("Invalid macaroon", cm.exception.msg)

View File

@ -41,6 +41,22 @@ myid = "@apple:test"
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class JustPresenceHandlers(object): class JustPresenceHandlers(object):
def __init__(self, hs): def __init__(self, hs):
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
@ -76,7 +92,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock( room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[ spec=[
@ -169,7 +185,7 @@ class PresenceListTestCase(unittest.TestCase):
] ]
) )
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource) presence.register_servlets(hs, self.mock_resource)
@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# HIDEOUS HACKERY # HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system # TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import ( from synapse.streams.events import (
PresenceEventSource, NullSource, EventSources PresenceEventSource, EventSources
) )
old_SOURCE_TYPES = EventSources.SOURCE_TYPES old_SOURCE_TYPES = EventSources.SOURCE_TYPES

View File

@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase):
"PUT", topic_path, topic_content) "PUT", topic_path, topic_content)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(403, code, msg=str(response))
(code, response) = yield self.mock_resource.trigger_get(topic_path) (code, response) = yield self.mock_resource.trigger_get(topic_path)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
# get topic in PUBLIC room, not joined, expect 403 # get topic in PUBLIC room, not joined, expect 403
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200) room=room, expect_code=200)
# get membership of self, get membership of other, private room + left # get membership of self, get membership of other, private room + left
# expect all 403s # expect all 200s
yield self.leave(room=room, user=self.user_id) yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership( yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id], members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403) room=room, expect_code=200)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_public_room_perms(self): def test_membership_public_room_perms(self):
@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200) room=room, expect_code=200)
# get membership of self, get membership of other, public room + left # get membership of self, get membership of other, public room + left
# expect 403. # expect 200.
yield self.leave(room=room, user=self.user_id) yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership( yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id], members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403) room=room, expect_code=200)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invited_permissions(self): def test_invited_permissions(self):
@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase):
self.assertEquals(200, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
yield self.leave(room=room_id, user=self.user_id) yield self.leave(room=room_id, user=self.user_id)
# can no longer see list, you've left. # can see old list once left
(code, response) = yield self.mock_resource.trigger_get(room_path) (code, response) = yield self.mock_resource.trigger_get(room_path)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
class RoomsCreateTestCase(RestTestCase): class RoomsCreateTestCase(RestTestCase):
@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None self.mock_resource = None
self.auth_user_id = None self.auth_user_id = None
def mock_get_user_by_access_token(self, token=None):
return self.auth_user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None): def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id temp_id = self.auth_user_id

View File

@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_access_token = _get_user_by_access_token hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.utils import setup_test_homeserver
from mock import Mock
class EventInjector:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.message_handler = hs.get_handlers().message_handler
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"room_id": room.to_string(),
"content": {},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)

View File

@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase):
[3, 4, 1, 2] [3, 4, 1, 2]
) )
@defer.inlineCallbacks
def test_update_one_with_return(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",)
ret = yield self.datastore._simple_selectupdate_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"},
retcols=["columname"]
)
self.assertEquals({"columname": "Old Value"}, ret)
self.mock_txn.execute.assert_has_calls([
call('SELECT columname FROM tablename WHERE keycol = ?',
['TheKey']),
call("UPDATE tablename SET columname = ? WHERE keycol = ?",
["New Value", "TheKey"])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_one(self): def test_delete_one(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1

View File

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from mock.mock import Mock
from synapse.types import RoomID, UserID
from tests import unittest
from twisted.internet import defer
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
resource_for_federation=Mock(),
http_client=None,
)
self.store = self.hs.get_datastore()
self.db_pool = self.hs.get_db_pool()
self.message_handler = self.hs.get_handlers().message_handler
self.event_injector = EventInjector(self.hs)
@defer.inlineCallbacks
def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
# Never reported before, and nothing which could be reported
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
self.assertEqual([(0,)], count)
# Create something to report
room = RoomID.from_string("!abc123:test")
user = UserID.from_string("@raccoonlover:test")
yield self.event_injector.create_room(room)
self.base_event = yield self._get_last_stream_token()
yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
# Never reported before, something could be reported, but isn't because
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
yield self.event_injector.inject_message(room, user, "Incredibly!")
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
# A little over 24 hours is fine :)
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
rows = yield self.db_pool.runQuery(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _assert_stats_reporting(self, messages, time):
rows = yield self.db_pool.runQuery(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
self.assertEqual([(self.base_event + messages, time,)], rows)

View File

@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory(); self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_injector = EventInjector(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler self.message_handler = self.handlers.message_handler
@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test") self.room1 = RoomID.from_string("!abc123:test")
self.room2 = RoomID.from_string("!xyx987:test") self.room2 = RoomID.from_string("!xyx987:test")
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_other(self): def test_event_stream_get_other(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_own(self): def test_event_stream_get_own(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_join_leave(self): def test_event_stream_join_leave(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Then bob leaves again. # Then bob leaves again.
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE self.room1, self.u_bob, Membership.LEAVE
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_prev_content(self): def test_event_stream_prev_content(self):
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
event1 = yield self.inject_room_member( event1 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
event2 = yield self.inject_room_member( event2 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN, self.room1, self.u_alice, Membership.JOIN,
) )