Merge branch 'develop' of github.com:matrix-org/synapse into erikj/invite_state

This commit is contained in:
Erik Johnston 2015-09-25 11:38:28 +01:00
commit a14665bde7
52 changed files with 1142 additions and 266 deletions

View File

@ -1,3 +1,17 @@
Changes in synapse v0.10.0-r2 (2015-09-16)
==========================================
* Fix bug where we always fetched remote server signing keys instead of using
ones in our cache.
* Fix adding threepids to an existing account.
* Fix bug with invinting over federation where remote server was already in
the room. (PR #281, SYN-392)
Changes in synapse v0.10.0-r1 (2015-09-08)
==========================================
* Fix bug with python packaging
Changes in synapse v0.10.0 (2015-09-03) Changes in synapse v0.10.0 (2015-09-03)
======================================= =======================================

View File

@ -25,6 +25,7 @@ for port in 8080 8081 8082; do
--generate-config \ --generate-config \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
--report-stats no
# Check script parameters # Check script parameters
if [ $# -eq 1 ]; then if [ $# -eq 1 ]; then

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.10.0" __version__ = "0.10.0-r2"

View File

@ -23,6 +23,7 @@ from synapse.util.logutils import log_function
from synapse.types import UserID, EventID from synapse.types import UserID, EventID
import logging import logging
import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,6 +41,12 @@ class Auth(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"type = ",
"time < ",
"user_id = ",
])
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -65,6 +72,14 @@ class Auth(object):
# FIXME # FIXME
return True return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
# FIXME: Temp hack # FIXME: Temp hack
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
return True return True
@ -104,6 +119,20 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user is not in the room.
Returns:
A deferred membership event for the user if the user is in
the room.
"""
if current_state: if current_state:
member = current_state.get( member = current_state.get(
(EventTypes.Member, user_id), (EventTypes.Member, user_id),
@ -119,6 +148,43 @@ class Auth(object):
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None):
"""Check if the user was in the room at some point.
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user was never in the room.
Returns:
A deferred membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
@ -359,7 +425,7 @@ class Auth(object):
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -386,7 +452,7 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token): def _get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -396,6 +462,86 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try:
ret = yield self._get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
@defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon)
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
def _validate_macaroon(self, macaroon):
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
raise AuthError( raise AuthError(
@ -406,7 +552,6 @@ class Auth(object):
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -27,16 +27,6 @@ class Membership(object):
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN) LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class Feedback(object):
"""Represents the types of feedback a user can send in response to a
message."""
DELIVERED = u"delivered"
READ = u"read"
LIST = (DELIVERED, READ)
class PresenceState(object): class PresenceState(object):
"""Represents the presence state of a user.""" """Represents the presence state of a user."""
OFFLINE = u"offline" OFFLINE = u"offline"
@ -73,7 +63,6 @@ class EventTypes(object):
PowerLevels = "m.room.power_levels" PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases" Aliases = "m.room.aliases"
Redaction = "m.room.redaction" Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility" RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias" CanonicalAlias = "m.room.canonical_alias"

View File

@ -16,10 +16,23 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
)
if __name__ == '__main__': if __name__ == '__main__':
try:
check_requirements() check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import ( from synapse.storage import (
@ -29,7 +42,7 @@ from synapse.storage import (
from synapse.server import HomeServer from synapse.server import HomeServer
from twisted.internet import reactor from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
@ -221,7 +234,7 @@ class SynapseHomeServer(HomeServer):
listener_config, listener_config,
root_resource, root_resource,
), ),
self.tls_context_factory, self.tls_server_context_factory,
interface=bind_address interface=bind_address
) )
else: else:
@ -365,7 +378,6 @@ def setup(config_options):
Args: Args:
config_options_options: The options passed to Synapse. Usually config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`. `sys.argv[1:]`.
should_run (bool): Whether to start the reactor.
Returns: Returns:
HomeServer HomeServer
@ -388,7 +400,7 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
tls_context_factory = context_factory.ServerContextFactory(config) tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
@ -396,7 +408,7 @@ def setup(config_options):
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_server_context_factory=tls_server_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string, version_string=version_string,
@ -665,6 +677,42 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker) ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run) reactor.run = profile(reactor.run)
start_time = hs.get_clock().time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(hs.get_clock().time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)

View File

@ -25,6 +25,7 @@ SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml" CONFIGFILE = "homeserver.yaml"
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE): if not os.path.exists(CONFIGFILE):
@ -45,8 +46,15 @@ def start():
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE]) args.extend(["--daemonize", "-c", CONFIGFILE])
try:
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
RED +
"error starting (exit code: %d); see above for logs" % e.returncode +
NORMAL
)
def stop(): def stop():

View File

@ -26,6 +26,16 @@ class ConfigError(Exception):
class Config(object): class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
@ -111,11 +121,14 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name): def generate_config(self, config_dir_path, server_name, report_stats=None):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name "default_config",
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=report_stats,
)) ))
config = yaml.load(default_config) config = yaml.load(default_config)
@ -139,6 +152,12 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--report-stats",
action="store",
help="Stuff",
choices=["yes", "no"]
)
config_parser.add_argument( config_parser.add_argument(
"--generate-keys", "--generate-keys",
action="store_true", action="store_true",
@ -189,6 +208,11 @@ class Config(object):
config_files.append(config_path) config_files.append(config_path)
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel
)
if not config_files: if not config_files:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
@ -211,7 +235,9 @@ class Config(object):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(
config_dir_path, server_name config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -261,9 +287,20 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name) _, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
)
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage "
"statistics, by setting the report_stats key in your config file "
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1)
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)

View File

@ -20,7 +20,7 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, config_dir_path, server_name): def default_config(cls, **kwargs):
return """\ return """\
# A list of application service config file to use # A list of application service config file to use
app_service_config_files: [] app_service_config_files: []

View File

@ -24,7 +24,7 @@ class CaptchaConfig(Config):
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Captcha ## ## Captcha ##

View File

@ -45,7 +45,7 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path")) self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path): def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db") database_path = self.abspath("homeserver.db")
return """\ return """\
# Database configuration # Database configuration

View File

@ -40,7 +40,7 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
return """\ return """\
## Signing Keys ## ## Signing Keys ##

View File

@ -21,6 +21,7 @@ import logging.config
import yaml import yaml
from string import Template from string import Template
import os import os
import signal
DEFAULT_LOG_CONFIG = Template(""" DEFAULT_LOG_CONFIG = Template("""
@ -69,7 +70,7 @@ class LoggingConfig(Config):
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")
log_config = self.abspath( log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")
@ -142,6 +143,19 @@ class LoggingConfig(Config):
handler = logging.handlers.RotatingFileHandler( handler = logging.handlers.RotatingFileHandler(
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3 self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
) )
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else: else:
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(formatter) handler.setFormatter(formatter)

View File

@ -19,13 +19,15 @@ from ._base import Config
class MetricsConfig(Config): class MetricsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.enable_metrics = config["enable_metrics"] self.enable_metrics = config["enable_metrics"]
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port") self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name): def default_config(self, report_stats=None, **kwargs):
return """\ suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
## Metrics ### ## Metrics ###
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
""" """ + suffix) % locals()

View File

@ -27,7 +27,7 @@ class RatelimitConfig(Config):
self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"] self.federation_rc_concurrent = config["federation_rc_concurrent"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##

View File

@ -34,7 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\

View File

@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config):
config["thumbnail_sizes"] config["thumbnail_sizes"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads") uploads_path = self.default_path("uploads")
return """ return """

View File

@ -41,7 +41,7 @@ class SAML2Config(Config):
self.saml2_config_path = None self.saml2_config_path = None
self.saml2_idp_redirect_url = None self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable SAML2 for registration and login. Uses pysaml2 # Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file # config_path: Path to the sp_conf.py configuration file

View File

@ -117,7 +117,7 @@ class ServerConfig(Config):
self.content_addr = content_addr self.content_addr = content_addr
def default_config(self, config_dir_path, server_name): def default_config(self, server_name, **kwargs):
if ":" in server_name: if ":" in server_name:
bind_port = int(server_name.split(":")[1]) bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400

View File

@ -42,7 +42,15 @@ class TlsConfig(Config):
config.get("tls_dh_params_path"), "tls_dh_params" config.get("tls_dh_params_path"), "tls_dh_params"
) )
def default_config(self, config_dir_path, server_name): # This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
# use only when running tests.
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt" tls_certificate_path = base_key_name + ".tls.crt"

View File

@ -22,7 +22,7 @@ class VoipConfig(Config):
self.turn_shared_secret = config["turn_shared_secret"] self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Turn ## ## Turn ##

View File

@ -228,10 +228,9 @@ class Keyring(object):
def do_iterations(): def do_iterations():
merged_results = {} merged_results = {}
missing_keys = { missing_keys = {}
group.server_name: set(group.key_ids) for group in group_id_to_group.values():
for group in group_id_to_group.values() missing_keys.setdefault(group.server_name, set()).union(group.key_ids)
}
for fn in key_fetch_fns: for fn in key_fetch_fns:
results = yield fn(missing_keys.items()) results = yield fn(missing_keys.items())
@ -470,7 +469,7 @@ class Keyring(object):
continue continue
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory, server_name, self.hs.tls_server_context_factory,
path=(b"/_matrix/key/v2/server/%s" % ( path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id), urllib.quote(requested_key_id),
)).encode("ascii"), )).encode("ascii"),
@ -604,7 +603,7 @@ class Keyring(object):
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key( (response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory server_name, self.hs.tls_server_context_factory
) )
# Check the response. # Check the response.

View File

@ -19,7 +19,6 @@ from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -187,7 +186,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
try: try:
client = SimpleHttpClient(self.hs) client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json( resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api, self.hs.config.recaptcha_siteverify_api,
args={ args={

View File

@ -16,13 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError, SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken from synapse.types import UserID, RoomStreamToken, StreamToken
from ._base import BaseHandler from ._base import BaseHandler
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
feedback=False, as_client_event=True): as_client_event=True):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -79,26 +79,52 @@ class MessageHandler(BaseHandler):
room_id (str): The room they want messages from. room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
feedback (bool): True to get compressed feedback with the messages
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if not pagin_config.from_token: if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = ( pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token( yield self.hs.get_event_sources().get_current_token(
direction='b' direction='b'
) )
) )
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key) room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None: if room_token.topological is None:
raise SynapseError(400, "Invalid token") raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, room_token.topological
) )
@ -106,7 +132,7 @@ class MessageHandler(BaseHandler):
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id user, source_config, room_id
) )
next_token = pagin_config.from_token.copy_and_replace( next_token = pagin_config.from_token.copy_and_replace(
@ -255,29 +281,26 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
have_joined = yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if not have_joined:
raise RoomError(403, "User not in room.")
if member_event.membership == Membership.JOIN:
data = yield self.state_handler.get_current_state( data = yield self.state_handler.get_current_state(
room_id, event_type, state_key room_id, event_type, state_key
) )
elif member_event.membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], [key]
)
data = room_state[member_event.event_id].get(key)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def get_feedback(self, event_id):
# yield self.auth.check_joined_room(room_id, user_id)
# Pull out the feedback from the db
fb = yield self.store.get_feedback(event_id)
if fb:
defer.returnValue(fb)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
Args: Args:
user_id(str): The user requesting state events. user_id(str): The user requesting state events.
@ -285,18 +308,23 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
yield self.auth.check_joined_room(room_id, user_id) member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], None
)
room_state = room_state[member_event.event_id]
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(
[serialize_event(c, now) for c in current_state.values()] [serialize_event(c, now) for c in room_state.values()]
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True):
feedback=False, as_client_event=True):
"""Retrieve a snapshot of all rooms the user is invited or has joined. """Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is This snapshot may include messages for all rooms where the user is
@ -306,7 +334,6 @@ class MessageHandler(BaseHandler):
user_id (str): The ID of the user making the request. user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return. config used to determine how many messages *PER ROOM* to return.
feedback (bool): True to get feedback along with these messages.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
Returns: Returns:
A list of dicts with "room_id" and "membership" keys for all rooms A list of dicts with "room_id" and "membership" keys for all rooms
@ -316,7 +343,9 @@ class MessageHandler(BaseHandler):
""" """
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, user_id=user_id,
membership_list=[Membership.INVITE, Membership.JOIN] membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
) )
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
@ -362,19 +391,32 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d) rooms_ret.append(d)
if event.membership != Membership.JOIN: if event.membership not in (Membership.JOIN, Membership.LEAVE):
return return
try: try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
event.room_id, [event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults( (messages, token), current_state = yield defer.gatherResults(
[ [
self.store.get_recent_events_for_room( self.store.get_recent_events_for_room(
event.room_id, event.room_id,
limit=limit, limit=limit,
end_token=now_token.room_key, end_token=room_end_token,
),
self.state_handler.get_current_state(
event.room_id
), ),
deferred_room_state,
] ]
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -421,15 +463,85 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, def room_initial_sync(self, user_id, room_id, pagin_config=None):
feedback=False): """Capture the a snapshot of a room. If user is currently a member of
current_state = yield self.state.get_current_state( the room this will be what is currently in the room. If the user left
room_id=room_id, the room this will be what was in the room when they left.
Args:
user_id(str): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event
)
elif member_event.membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event
)
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event):
room_state = yield self.store.get_state_for_events(
member_event.room_id, [member_event.event_id], None
) )
yield self.auth.check_joined_room( room_state = room_state[member_event.event_id]
room_id, user_id,
current_state=current_state limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event):
current_state = yield self.state.get_current_state(
room_id=room_id,
) )
# TODO(paul): I wish I was called with user objects not user_id # TODO(paul): I wish I was called with user objects not user_id
@ -443,8 +555,6 @@ class MessageHandler(BaseHandler):
for x in current_state.values() for x in current_state.values()
] ]
member_event = current_state.get((EventTypes.Member, user_id,))
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None limit = pagin_config.limit if pagin_config else None

View File

@ -25,7 +25,6 @@ from synapse.api.constants import (
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from collections import OrderedDict from collections import OrderedDict
import logging import logging
@ -39,7 +38,7 @@ class RoomCreationHandler(BaseHandler):
PRESETS_DICT = { PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: { RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE, "join_rules": JoinRules.INVITE,
"history_visibility": "invited", "history_visibility": "shared",
"original_invitees_have_ops": False, "original_invitees_have_ops": False,
}, },
RoomCreationPreset.PUBLIC_CHAT: { RoomCreationPreset.PUBLIC_CHAT: {
@ -342,41 +341,6 @@ class RoomMemberHandler(BaseHandler):
if remotedomains is not None: if remotedomains is not None:
remotedomains.add(member.domain) remotedomains.add(member.domain)
@defer.inlineCallbacks
def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
limit=0, start_tok=None,
end_tok=None):
"""Retrieve a list of room members in the room.
Args:
room_id (str): The room to get the member list for.
user_id (str): The ID of the user making the request.
limit (int): The max number of members to return.
start_tok (str): Optional. The start token if known.
end_tok (str): Optional. The end token if known.
Returns:
dict: A Pagination streamable dict.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [
serialize_event(entry, time_now)
for entry in member_list
]
chunk_data = {
"start": "START", # FIXME (erikj): START is no longer valid
"end": "END",
"chunk": event_list
}
# TODO honor Pagination stream params
# TODO snapshot this list to return on subsequent requests when
# paginating
defer.returnValue(chunk_data)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
@ -646,7 +610,6 @@ class RoomEventSource(object):
to_key=config.to_key, to_key=config.to_key,
direction=config.direction, direction=config.direction,
limit=config.limit, limit=config.limit,
with_feedback=True
) )
defer.returnValue((events, next_key)) defer.returnValue((events, next_key))

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
@ -19,7 +21,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor from twisted.internet import defer, reactor, ssl
from twisted.web.client import ( from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError, Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool, HTTPConnectionPool,
@ -59,7 +61,12 @@ class SimpleHttpClient(object):
# 'like a browser' # 'like a browser'
pool = HTTPConnectionPool(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10 pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool) self.agent = Agent(
reactor,
pool=pool,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory()
)
self.version_string = hs.version_string self.version_string = hs.version_string
def request(self, method, uri, *args, **kwargs): def request(self, method, uri, *args, **kwargs):
@ -252,3 +259,18 @@ def _print_ex(e):
_print_ex(ex) _print_ex(ex)
else: else:
logger.exception(e) logger.exception(e)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
Do not use this since it allows an attacker to intercept your communications.
"""
def __init__(self):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port):
return self._context

View File

@ -57,14 +57,14 @@ incoming_responses_counter = metrics.register_counter(
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
def __init__(self, hs): def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory self.tls_server_context_factory = hs.tls_server_context_factory
def endpointForURI(self, uri): def endpointForURI(self, uri):
destination = uri.netloc destination = uri.netloc
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory ssl_context_factory=self.tls_server_context_factory
) )

View File

@ -18,18 +18,18 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"], "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"], "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"], "signedjson>=1.0.0": ["signedjson>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"], "pynacl>=0.3.0": ["nacl>=0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
"pyasn1": ["pyasn1"], "pyasn1": ["pyasn1"],
"pynacl>=0.3.0": ["nacl>=0.3.0"],
"daemonize": ["daemonize"], "daemonize": ["daemonize"],
"py-bcrypt": ["bcrypt"], "py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
@ -60,7 +60,10 @@ DEPENDENCY_LINKS = {
class MissingRequirementError(Exception): class MissingRequirementError(Exception):
pass def __init__(self, message, module_name, dependency):
super(MissingRequirementError, self).__init__(message)
self.module_name = module_name
self.dependency = dependency
def check_requirements(config=None): def check_requirements(config=None):
@ -88,7 +91,7 @@ def check_requirements(config=None):
) )
raise MissingRequirementError( raise MissingRequirementError(
"Can't import %r which is part of %r" "Can't import %r which is part of %r"
% (module_name, dependency) % (module_name, dependency), module_name, dependency
) )
version = getattr(module, "__version__", None) version = getattr(module, "__version__", None)
file_path = getattr(module, "__file__", None) file_path = getattr(module, "__file__", None)
@ -101,23 +104,25 @@ def check_requirements(config=None):
if version is None: if version is None:
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r" "Version of %r isn't set as __version__ of module %r"
% (dependency, module_name) % (dependency, module_name), module_name, dependency
) )
if LooseVersion(version) < LooseVersion(required_version): if LooseVersion(version) < LooseVersion(required_version):
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r in %r is too old. %r < %r" "Version of %r in %r is too old. %r < %r"
% (dependency, file_path, version, required_version) % (dependency, file_path, version, required_version),
module_name, dependency
) )
elif version_test == "==": elif version_test == "==":
if version is None: if version is None:
raise MissingRequirementError( raise MissingRequirementError(
"Version of %r isn't set as __version__ of module %r" "Version of %r isn't set as __version__ of module %r"
% (dependency, module_name) % (dependency, module_name), module_name, dependency
) )
if LooseVersion(version) != LooseVersion(required_version): if LooseVersion(version) != LooseVersion(required_version):
raise MissingRequirementError( raise MissingRequirementError(
"Unexpected version of %r in %r. %r != %r" "Unexpected version of %r in %r. %r != %r"
% (dependency, file_path, version, required_version) % (dependency, file_path, version, required_version),
module_name, dependency
) )

View File

@ -26,14 +26,12 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
content = yield handler.snapshot_all_rooms( content = yield handler.snapshot_all_rooms(
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event as_client_event=as_client_event
) )

View File

@ -290,12 +290,18 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request) user, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler handler = self.handlers.message_handler
members = yield handler.get_room_members_as_pagination_chunk( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string()) user_id=user.to_string(),
)
for event in members["chunk"]: chunk = []
for event in events:
if event["type"] != EventTypes.Member:
continue
chunk.append(event)
# FIXME: should probably be state_key here, not user_id # FIXME: should probably be state_key here, not user_id
target_user = UserID.from_string(event["user_id"]) target_user = UserID.from_string(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it # Presence is an optional cache; don't fail if we can't fetch it
@ -308,7 +314,9 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
except: except:
pass pass
defer.returnValue((200, members)) defer.returnValue((200, {
"chunk": chunk
}))
# TODO: Needs unit testing # TODO: Needs unit testing
@ -321,14 +329,12 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event as_client_event=as_client_event
) )

View File

@ -19,7 +19,9 @@
# partial one for unit test mocking. # partial one for unit test mocking.
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -87,6 +89,8 @@ class BaseHomeServer(object):
'pusherpool', 'pusherpool',
'event_builder_factory', 'event_builder_factory',
'filtering', 'filtering',
'http_client_context_factory',
'simple_http_client',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -174,6 +178,17 @@ class HomeServer(BaseHomeServer):
def build_auth(self): def build_auth(self):
return Auth(self) return Auth(self)
def build_http_client_context_factory(self):
config = self.get_config()
return (
InsecureInterceptableContextFactory()
if config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
def build_simple_http_client(self):
return SimpleHttpClient(self)
def build_v1auth(self): def build_v1auth(self):
orf = Auth(self) orf = Auth(self)
# Matrix spec makes no reference to what HTTP status code is returned, # Matrix spec makes no reference to what HTTP status code is returned,

View File

@ -17,7 +17,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
@ -119,8 +118,6 @@ class StateHandler(object):
Returns: Returns:
an EventContext an EventContext
""" """
yield run_on_reactor()
context = EventContext() context = EventContext()
if outlier: if outlier:

View File

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 23 SCHEMA_VERSION = 24
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -126,6 +126,27 @@ class DataStore(RoomMemberStore, RoomStore,
lock=False, lock=False,
) )
@defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
txn.execute(
"SELECT COUNT(DISTINCT user_id) AS users"
" FROM user_ips"
" WHERE last_seen > ?",
# This is close enough to a day for our purposes.
(int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
)
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
return self._simple_select_list( return self._simple_select_list(
table="user_ips", table="user_ips",

View File

@ -303,6 +303,15 @@ class EventFederationStore(SQLBaseStore):
], ],
) )
self._update_extremeties(txn, events)
def _update_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers.
"""
events_by_room = {} events_by_room = {}
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import math
import ujson as json import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -281,6 +281,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,) (False, event.event_id,)
) )
self._update_extremeties(txn, [event])
events_and_contexts = filter( events_and_contexts = filter(
lambda ec: ec[0] not in to_remove, lambda ec: ec[0] not in to_remove,
events_and_contexts events_and_contexts
@ -903,3 +905,65 @@ class EventsStore(SQLBaseStore):
txn.execute(sql, (event.event_id,)) txn.execute(sql, (event.event_id,))
result = txn.fetchone() result = txn.fetchone()
return result[0] if result else None return result[0] if result else None
@defer.inlineCallbacks
def count_daily_messages(self):
"""
Returns an estimate of the number of messages sent in the last day.
If it has been significantly less or more than one day since the last
call to this function, it will return None.
"""
def _count_messages(txn):
now = self.hs.get_clock().time()
txn.execute(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
last_reported = self.cursor_to_dict(txn)
txn.execute(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
now_reporting = self.cursor_to_dict(txn)
if not now_reporting:
return None
now_reporting = now_reporting[0]["stream_ordering"]
txn.execute("DELETE FROM stats_reporting")
txn.execute(
"INSERT INTO stats_reporting"
" (reported_stream_token, reported_time)"
" VALUES (?, ?)",
(now_reporting, now,)
)
if not last_reported:
return None
# Close enough to correct for our purposes.
yesterday = (now - 24 * 60 * 60)
if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60:
return None
txn.execute(
"SELECT COUNT(*) as messages"
" FROM events NATURAL JOIN event_json"
" WHERE json like '%m.room.message%'"
" AND stream_ordering > ?"
" AND stream_ordering <= ?",
(
last_reported[0]["reported_stream_token"],
now_reporting,
)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
return rows[0]["messages"]
ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)

View File

@ -289,3 +289,16 @@ class RegistrationStore(SQLBaseStore):
if ret: if ret:
defer.returnValue(ret['user_id']) defer.returnValue(ret['user_id'])
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)

View File

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
RoomsForUser = namedtuple( RoomsForUser = namedtuple(
"RoomsForUser", "RoomsForUser",
("room_id", "sender", "membership", "event_id") ("room_id", "sender", "membership", "event_id", "stream_ordering")
) )
@ -141,11 +141,13 @@ class RoomMemberStore(SQLBaseStore):
args.extend(membership_list) args.extend(membership_list)
sql = ( sql = (
"SELECT m.event_id, m.room_id, m.sender, m.membership" "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM room_memberships as m" " FROM current_state_events as c"
" INNER JOIN current_state_events as c" " INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id " " ON m.event_id = c.event_id"
" AND m.room_id = c.room_id " " INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key" " AND m.user_id = c.state_key"
" WHERE %s" " WHERE %s"
) % (where_clause,) ) % (where_clause,)

View File

@ -0,0 +1,16 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP INDEX IF EXISTS state_groups_state_tuple;

View File

@ -0,0 +1,22 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Should only ever contain one row
CREATE TABLE IF NOT EXISTS stats_reporting(
-- The stream ordering token which was most recently reported as stats
reported_stream_token INTEGER,
-- The time (seconds since epoch) stats were most recently reported
reported_time BIGINT
);

View File

@ -159,9 +159,7 @@ class StreamStore(SQLBaseStore):
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, room_id, def get_room_events_stream(self, user_id, from_key, to_key, room_id,
limit=0, with_feedback=False): limit=0):
# TODO (erikj): Handle compressed feedback
current_room_membership_sql = ( current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m " "SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c" " INNER JOIN current_state_events as c"
@ -227,10 +225,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None, def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, direction='b', limit=-1):
with_feedback=False):
# TODO (erikj): Handle compressed feedback
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
@ -302,7 +297,6 @@ class StreamStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=4) @cachedInlineCallbacks(num_args=4)
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
# TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)
@ -379,6 +373,38 @@ class StreamStore(SQLBaseStore):
) )
defer.returnValue("t%d-%d" % (topo, token)) defer.returnValue("t%d-%d" % (topo, token))
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
"""The stream token for an event
Args:
event_id(str): The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A deferred "t%d-%d" topological token.
"""
return self._simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
)
def _get_max_topological_txn(self, txn): def _get_max_topological_txn(self, txn):
txn.execute( txn.execute(
"SELECT MAX(topological_ordering) FROM events" "SELECT MAX(topological_ordering) FROM events"

View File

@ -34,6 +34,11 @@ class SourcePaginationConfig(object):
self.direction = 'f' if direction == 'f' else 'b' self.direction = 'f' if direction == 'f' else 'b'
self.limit = int(limit) if limit is not None else None self.limit = int(limit) if limit is not None else None
def __repr__(self):
return (
"StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)"
) % (self.from_key, self.to_key, self.direction, self.limit)
class PaginationConfig(object): class PaginationConfig(object):
@ -94,10 +99,10 @@ class PaginationConfig(object):
logger.exception("Failed to create pagination config") logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.") raise SynapseError(400, "Invalid request.")
def __str__(self): def __repr__(self):
return ( return (
"<PaginationConfig from_tok=%s, to_tok=%s, " "PaginationConfig(from_tok=%r, to_tok=%r,"
"direction=%s, limit=%s>" " direction=%r, limit=%r)"
) % (self.from_token, self.to_token, self.direction, self.limit) ) % (self.from_token, self.to_token, self.direction, self.limit)
def get_source_config(self, source_name): def get_source_config(self, source_name):

View File

@ -19,17 +19,21 @@ from mock import Mock
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.types import UserID
from tests.utils import setup_test_homeserver
import pymacaroons
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self): def setUp(self):
self.state_handler = Mock() self.state_handler = Mock()
self.store = Mock() self.store = Mock()
self.hs = Mock() self.hs = yield setup_test_homeserver(handlers=None)
self.hs.get_datastore = Mock(return_value=self.store) self.hs.get_datastore = Mock(return_value=self.store)
self.hs.get_state_handler = Mock(return_value=self.state_handler)
self.auth = Auth(self.hs) self.auth = Auth(self.hs)
self.test_user = "@foo:bar" self.test_user = "@foo:bar"
@ -133,3 +137,140 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
user_info = yield self.auth._get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock(
return_value={"name": "@percy:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("User mismatch", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_missing_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("No user caveat", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_wrong_key(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key + "wrong")
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_unknown_caveat(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("cunning > fox")
with self.assertRaises(AuthError) as cm:
yield self.auth._get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_expired(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("time < 1") # ms
self.hs.clock.now = 5000 # seconds
yield self.auth._get_user_from_macaroon(macaroon.serialize())
# TODO(daniel): Turn on the check that we validate expiration, when we
# validate expiration (and remove the above line, which will start
# throwing).
# with self.assertRaises(AuthError) as cm:
# yield self.auth._get_user_from_macaroon(macaroon.serialize())
# self.assertEqual(401, cm.exception.code)
# self.assertIn("Invalid macaroon", cm.exception.msg)

View File

@ -76,7 +76,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock( room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[ spec=[
@ -169,7 +169,7 @@ class PresenceListTestCase(unittest.TestCase):
] ]
) )
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource) presence.register_servlets(hs, self.mock_resource)

View File

@ -59,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -239,7 +239,7 @@ class RoomPermissionsTestCase(RestTestCase):
"PUT", topic_path, topic_content) "PUT", topic_path, topic_content)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(403, code, msg=str(response))
(code, response) = yield self.mock_resource.trigger_get(topic_path) (code, response) = yield self.mock_resource.trigger_get(topic_path)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
# get topic in PUBLIC room, not joined, expect 403 # get topic in PUBLIC room, not joined, expect 403
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
@ -301,11 +301,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200) room=room, expect_code=200)
# get membership of self, get membership of other, private room + left # get membership of self, get membership of other, private room + left
# expect all 403s # expect all 200s
yield self.leave(room=room, user=self.user_id) yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership( yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id], members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403) room=room, expect_code=200)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_public_room_perms(self): def test_membership_public_room_perms(self):
@ -326,11 +326,11 @@ class RoomPermissionsTestCase(RestTestCase):
room=room, expect_code=200) room=room, expect_code=200)
# get membership of self, get membership of other, public room + left # get membership of self, get membership of other, public room + left
# expect 403. # expect 200.
yield self.leave(room=room, user=self.user_id) yield self.leave(room=room, user=self.user_id)
yield self._test_get_membership( yield self._test_get_membership(
members=[self.user_id, self.rmcreator_id], members=[self.user_id, self.rmcreator_id],
room=room, expect_code=403) room=room, expect_code=200)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invited_permissions(self): def test_invited_permissions(self):
@ -444,7 +444,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -492,9 +492,9 @@ class RoomsMemberListTestCase(RestTestCase):
self.assertEquals(200, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
yield self.leave(room=room_id, user=self.user_id) yield self.leave(room=room_id, user=self.user_id)
# can no longer see list, you've left. # can see old list once left
(code, response) = yield self.mock_resource.trigger_get(room_path) (code, response) = yield self.mock_resource.trigger_get(room_path)
self.assertEquals(403, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
class RoomsCreateTestCase(RestTestCase): class RoomsCreateTestCase(RestTestCase):
@ -522,7 +522,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -718,7 +718,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -843,7 +843,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -938,7 +938,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -67,7 +67,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -37,9 +37,6 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None self.mock_resource = None
self.auth_user_id = None self.auth_user_id = None
def mock_get_user_by_access_token(self, token=None):
return self.auth_user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_as(self, room_creator, is_public=True, tok=None): def create_room_as(self, room_creator, is_public=True, tok=None):
temp_id = self.auth_user_id temp_id = self.auth_user_id

View File

@ -48,7 +48,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_access_token = _get_user_by_access_token hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.utils import setup_test_homeserver
from mock import Mock
class EventInjector:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.message_handler = hs.get_handlers().message_handler
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"room_id": room.to_string(),
"content": {},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)

View File

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from mock.mock import Mock
from synapse.types import RoomID, UserID
from tests import unittest
from twisted.internet import defer
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
resource_for_federation=Mock(),
http_client=None,
)
self.store = self.hs.get_datastore()
self.db_pool = self.hs.get_db_pool()
self.message_handler = self.hs.get_handlers().message_handler
self.event_injector = EventInjector(self.hs)
@defer.inlineCallbacks
def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
# Never reported before, and nothing which could be reported
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
self.assertEqual([(0,)], count)
# Create something to report
room = RoomID.from_string("!abc123:test")
user = UserID.from_string("@raccoonlover:test")
yield self.event_injector.create_room(room)
self.base_event = yield self._get_last_stream_token()
yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
# Never reported before, something could be reported, but isn't because
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
yield self.event_injector.inject_message(room, user, "Incredibly!")
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
# A little over 24 hours is fine :)
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
rows = yield self.db_pool.runQuery(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _assert_stats_reporting(self, messages, time):
rows = yield self.db_pool.runQuery(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
self.assertEqual([(self.base_event + messages, time,)], rows)

View File

@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory(); self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_injector = EventInjector(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler self.message_handler = self.handlers.message_handler
@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test") self.room1 = RoomID.from_string("!abc123:test")
self.room2 = RoomID.from_string("!xyx987:test") self.room2 = RoomID.from_string("!xyx987:test")
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_other(self): def test_event_stream_get_other(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_own(self): def test_event_stream_get_own(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_join_leave(self): def test_event_stream_join_leave(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Then bob leaves again. # Then bob leaves again.
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE self.room1, self.u_bob, Membership.LEAVE
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_prev_content(self): def test_event_stream_prev_content(self):
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
event1 = yield self.inject_room_member( event1 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
event2 = yield self.inject_room_member( event2 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN, self.room1, self.u_alice, Membership.JOIN,
) )

View File

@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
nodes={ nodes={
"START": DictObj( "START": DictObj(
type=EventTypes.Create, type=EventTypes.Create,
state_key="creator", state_key="",
content={"membership": "@user_id:example.com"}, content={"creator": "@user_id:example.com"},
depth=1, depth=1,
), ),
"A": DictObj( "A": DictObj(
@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
nodes={ nodes={
"START": DictObj( "START": DictObj(
type=EventTypes.Create, type=EventTypes.Create,
state_key="creator", state_key="",
content={"membership": "@user_id:example.com"}, content={"creator": "@user_id:example.com"},
depth=1, depth=1,
), ),
"A": DictObj( "A": DictObj(
@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase):
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
creation = create_event(
type=EventTypes.Create, state_key=""
)
old_state_1 = [ old_state_1 = [
creation,
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
creation,
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase):
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 6)
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase):
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = create_event(type="test4", state_key="", name="event") event = create_event(type="test4", state_key="", name="event")
creation = create_event(
type=EventTypes.Create, state_key=""
)
old_state_1 = [ old_state_1 = [
creation,
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
creation,
create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase):
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 6)
self.assertIsNone(context.state_group) self.assertIsNone(context.state_group)
@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase):
} }
) )
creation = create_event(
type=EventTypes.Create, state_key="",
content={"creator": "@foo:bar"}
)
old_state_1 = [ old_state_1 = [
creation,
member_event, member_event,
create_event(type="test1", state_key="1", depth=1), create_event(type="test1", state_key="1", depth=1),
] ]
old_state_2 = [ old_state_2 = [
creation,
member_event, member_event,
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths # Reverse the depth to make sure we are actually using the depths
# during state resolution. # during state resolution.
old_state_1 = [ old_state_1 = [
creation,
member_event, member_event,
create_event(type="test1", state_key="1", depth=2), create_event(type="test1", state_key="1", depth=2),
] ]
old_state_2 = [ old_state_2 = [
creation,
member_event, member_event,
create_event(type="test1", state_key="1", depth=1), create_event(type="test1", state_key="1", depth=1),
] ]
context = yield self._get_context(event, old_state_1, old_state_2) context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
def _get_context(self, event, old_state_1, old_state_2): def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"