Implement configurable stats reporting

SYN-287

This requires that HS owners either opt in or out of stats reporting.

When --generate-config is passed, --report-stats must be specified
If an already-generated config is used, and doesn't have the
report_stats key, it is requested to be set.
This commit is contained in:
Daniel Wagner-Hall 2015-09-22 12:57:40 +01:00
parent ee2d722f0f
commit 7213588083
24 changed files with 425 additions and 78 deletions

View File

@ -42,7 +42,7 @@ from synapse.storage import (
from synapse.server import HomeServer from synapse.server import HomeServer
from twisted.internet import reactor from twisted.internet import reactor, task, defer
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
@ -677,6 +677,39 @@ def run(hs):
ThreadPool._worker = profile(ThreadPool._worker) ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run) reactor.run = profile(reactor.run)
start_time = hs.get_clock().time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(hs.get_clock().time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
def in_thread(): def in_thread():
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)

View File

@ -25,6 +25,7 @@ SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml" CONFIGFILE = "homeserver.yaml"
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE): if not os.path.exists(CONFIGFILE):
@ -45,8 +46,15 @@ def start():
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE]) args.extend(["--daemonize", "-c", CONFIGFILE])
subprocess.check_call(args) try:
print GREEN + "started" + NORMAL subprocess.check_call(args)
print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
RED +
"error starting (exit code: %d); see above for logs" % e.returncode +
NORMAL
)
def stop(): def stop():

View File

@ -26,6 +26,16 @@ class ConfigError(Exception):
class Config(object): class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by "
"reporting anonymized usage statistics from your homeserver. Only very "
"basic aggregate data (e.g. number of users) will be reported, but it "
"helps us to track the growth of the Matrix community, and helps us to "
"make Matrix a success, as well as to convince other networks that they "
"should peer with us.\n"
"Thank you."
)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
@ -111,11 +121,14 @@ class Config(object):
results.append(getattr(cls, name)(self, *args, **kargs)) results.append(getattr(cls, name)(self, *args, **kargs))
return results return results
def generate_config(self, config_dir_path, server_name): def generate_config(self, config_dir_path, server_name, report_stats=None):
default_config = "# vim:ft=yaml\n" default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all( default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name "default_config",
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=report_stats,
)) ))
config = yaml.load(default_config) config = yaml.load(default_config)
@ -139,6 +152,12 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--report-stats",
action="store",
help="Stuff",
choices=["yes", "no"]
)
config_parser.add_argument( config_parser.add_argument(
"--generate-keys", "--generate-keys",
action="store_true", action="store_true",
@ -189,6 +208,11 @@ class Config(object):
config_files.append(config_path) config_files.append(config_path)
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel
)
if not config_files: if not config_files:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
@ -211,7 +235,9 @@ class Config(object):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(
config_dir_path, server_name config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
) )
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
@ -261,9 +287,20 @@ class Config(object):
specified_config.update(yaml_config) specified_config.update(yaml_config)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name) _, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
)
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage "
"statistics, by setting the report_stats key in your config file "
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1)
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)

View File

@ -20,7 +20,7 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, config_dir_path, server_name): def default_config(cls, **kwargs):
return """\ return """\
# A list of application service config file to use # A list of application service config file to use
app_service_config_files: [] app_service_config_files: []

View File

@ -24,7 +24,7 @@ class CaptchaConfig(Config):
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Captcha ## ## Captcha ##

View File

@ -45,7 +45,7 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path")) self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path): def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db") database_path = self.abspath("homeserver.db")
return """\ return """\
# Database configuration # Database configuration

View File

@ -40,7 +40,7 @@ class KeyConfig(Config):
config["perspectives"] config["perspectives"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
return """\ return """\
## Signing Keys ## ## Signing Keys ##

View File

@ -70,7 +70,7 @@ class LoggingConfig(Config):
self.log_config = self.abspath(config.get("log_config")) self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file")) self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log") log_file = self.abspath("homeserver.log")
log_config = self.abspath( log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config") os.path.join(config_dir_path, server_name + ".log.config")

View File

@ -19,13 +19,15 @@ from ._base import Config
class MetricsConfig(Config): class MetricsConfig(Config):
def read_config(self, config): def read_config(self, config):
self.enable_metrics = config["enable_metrics"] self.enable_metrics = config["enable_metrics"]
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port") self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1") self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name): def default_config(self, report_stats=None, **kwargs):
return """\ suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
## Metrics ### ## Metrics ###
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
""" """ + suffix) % locals()

View File

@ -27,7 +27,7 @@ class RatelimitConfig(Config):
self.federation_rc_reject_limit = config["federation_rc_reject_limit"] self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"] self.federation_rc_concurrent = config["federation_rc_concurrent"]
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Ratelimiting ## ## Ratelimiting ##

View File

@ -34,7 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50) macaroon_secret_key = random_string_with_symbols(50)
return """\ return """\

View File

@ -60,7 +60,7 @@ class ContentRepositoryConfig(Config):
config["thumbnail_sizes"] config["thumbnail_sizes"]
) )
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads") uploads_path = self.default_path("uploads")
return """ return """

View File

@ -41,7 +41,7 @@ class SAML2Config(Config):
self.saml2_config_path = None self.saml2_config_path = None
self.saml2_idp_redirect_url = None self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable SAML2 for registration and login. Uses pysaml2 # Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file # config_path: Path to the sp_conf.py configuration file

View File

@ -117,7 +117,7 @@ class ServerConfig(Config):
self.content_addr = content_addr self.content_addr = content_addr
def default_config(self, config_dir_path, server_name): def default_config(self, server_name, **kwargs):
if ":" in server_name: if ":" in server_name:
bind_port = int(server_name.split(":")[1]) bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400

View File

@ -50,7 +50,7 @@ class TlsConfig(Config):
"use_insecure_ssl_client_just_for_testing_do_not_use" "use_insecure_ssl_client_just_for_testing_do_not_use"
) )
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name) base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt" tls_certificate_path = base_key_name + ".tls.crt"

View File

@ -22,7 +22,7 @@ class VoipConfig(Config):
self.turn_shared_secret = config["turn_shared_secret"] self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"]) self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, config_dir_path, server_name): def default_config(self, **kwargs):
return """\ return """\
## Turn ## ## Turn ##

View File

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 23 SCHEMA_VERSION = 24
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -126,6 +126,24 @@ class DataStore(RoomMemberStore, RoomStore,
lock=False, lock=False,
) )
@defer.inlineCallbacks
def count_daily_users(self):
def _count_users(txn):
txn.execute(
"SELECT COUNT(DISTINCT user_id) AS users"
" FROM user_ips"
" WHERE last_seen > ?",
# This is close enough to a day for our purposes.
(int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),)
)
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def get_user_ip_and_agents(self, user): def get_user_ip_and_agents(self, user):
return self._simple_select_list( return self._simple_select_list(
table="user_ips", table="user_ips",

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore, _RollbackButIsFineException from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -28,6 +27,7 @@ from canonicaljson import encode_canonical_json
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import math
import ujson as json import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -905,3 +905,59 @@ class EventsStore(SQLBaseStore):
txn.execute(sql, (event.event_id,)) txn.execute(sql, (event.event_id,))
result = txn.fetchone() result = txn.fetchone()
return result[0] if result else None return result[0] if result else None
@defer.inlineCallbacks
def count_daily_messages(self):
def _count_messages(txn):
now = self.hs.get_clock().time()
txn.execute(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
last_reported = self.cursor_to_dict(txn)
txn.execute(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
now_reporting = self.cursor_to_dict(txn)
if not now_reporting:
return None
now_reporting = now_reporting[0]["stream_ordering"]
txn.execute("DELETE FROM stats_reporting")
txn.execute(
"INSERT INTO stats_reporting"
" (reported_stream_token, reported_time)"
" VALUES (?, ?)",
(now_reporting, now,)
)
if not last_reported:
return None
# Close enough to correct for our purposes.
yesterday = (now - 24 * 60 * 60)
if math.fabs(yesterday - last_reported[0]["reported_time"]) > 60 * 60:
return None
txn.execute(
"SELECT COUNT(*) as messages"
" FROM events NATURAL JOIN event_json"
" WHERE json like '%m.room.message%'"
" AND stream_ordering > ?"
" AND stream_ordering <= ?",
(
last_reported[0]["reported_stream_token"],
now_reporting,
)
)
rows = self.cursor_to_dict(txn)
if not rows:
return None
return rows[0]["messages"]
ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)

View File

@ -289,3 +289,15 @@ class RegistrationStore(SQLBaseStore):
if ret: if ret:
defer.returnValue(ret['user_id']) defer.returnValue(ret['user_id'])
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def count_all_users(self):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)

View File

@ -0,0 +1,22 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Should only ever contain one row
CREATE TABLE IF NOT EXISTS stats_reporting(
-- The stream ordering token which was most recently reported as stats
reported_stream_token INTEGER,
-- The time (seconds since epoch) stats were most recently reported
reported_time BIGINT
);

View File

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID
from tests.utils import setup_test_homeserver
from mock import Mock
class EventInjector:
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.message_handler = hs.get_handlers().message_handler
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"room_id": room.to_string(),
"content": {},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)

View File

@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from mock.mock import Mock
from synapse.types import RoomID, UserID
from tests import unittest
from twisted.internet import defer
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver
class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
resource_for_federation=Mock(),
http_client=None,
)
self.store = self.hs.get_datastore()
self.db_pool = self.hs.get_db_pool()
self.message_handler = self.hs.get_handlers().message_handler
self.event_injector = EventInjector(self.hs)
@defer.inlineCallbacks
def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
# Never reported before, and nothing which could be reported
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
count = yield self.db_pool.runQuery("SELECT COUNT(*) FROM stats_reporting")
self.assertEqual([(0,)], count)
# Create something to report
room = RoomID.from_string("!abc123:test")
user = UserID.from_string("@raccoonlover:test")
yield self.event_injector.create_room(room)
self.base_event = yield self._get_last_stream_token()
yield self.event_injector.inject_message(room, user, "Raccoons are really cute")
# Never reported before, something could be reported, but isn't because
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
yield self.event_injector.inject_message(room, user, "Incredibly!")
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
yield self.event_injector.inject_message(room, user, "Indeed.")
# A little over 24 hours is fine :)
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):
rows = yield self.db_pool.runQuery(
"SELECT stream_ordering"
" FROM events"
" ORDER BY stream_ordering DESC"
" LIMIT 1"
)
if not rows:
defer.returnValue(0)
else:
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _assert_stats_reporting(self, messages, time):
rows = yield self.db_pool.runQuery(
"SELECT reported_stream_token, reported_time FROM stats_reporting"
)
self.assertEqual([(self.base_event + messages, time,)], rows)

View File

@ -85,7 +85,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
# Room events need the full datastore, for persist_event() and # Room events need the full datastore, for persist_event() and
# get_room_state() # get_room_state()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory(); self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID from synapse.types import UserID, RoomID
from tests.storage.event_injector import EventInjector
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -36,6 +37,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.event_injector = EventInjector(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler self.message_handler = self.handlers.message_handler
@ -45,60 +47,20 @@ class StreamStoreTestCase(unittest.TestCase):
self.room1 = RoomID.from_string("!abc123:test") self.room1 = RoomID.from_string("!abc123:test")
self.room2 = RoomID.from_string("!xyx987:test") self.room2 = RoomID.from_string("!xyx987:test")
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_message(self, room, user, body):
self.depth += 1
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_other(self): def test_event_stream_get_other(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -125,17 +87,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_own(self): def test_event_stream_get_own(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -162,22 +124,22 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_join_leave(self): def test_event_stream_join_leave(self):
# Both bob and alice joins the room # Both bob and alice joins the room
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
# Then bob leaves again. # Then bob leaves again.
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.LEAVE self.room1, self.u_bob, Membership.LEAVE
) )
# Initial stream key: # Initial stream key:
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
yield self.inject_message(self.room1, self.u_alice, u"test") yield self.event_injector.inject_message(self.room1, self.u_alice, u"test")
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -193,17 +155,17 @@ class StreamStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_prev_content(self): def test_event_stream_prev_content(self):
yield self.inject_room_member( yield self.event_injector.inject_room_member(
self.room1, self.u_bob, Membership.JOIN self.room1, self.u_bob, Membership.JOIN
) )
event1 = yield self.inject_room_member( event1 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN self.room1, self.u_alice, Membership.JOIN
) )
start = yield self.store.get_room_events_max_id() start = yield self.store.get_room_events_max_id()
event2 = yield self.inject_room_member( event2 = yield self.event_injector.inject_room_member(
self.room1, self.u_alice, Membership.JOIN, self.room1, self.u_alice, Membership.JOIN,
) )