Merge branch 'develop' of github.com:matrix-org/synapse into erikj/state_ids_api

This commit is contained in:
Erik Johnston 2016-08-04 14:04:35 +01:00
commit b4e2290d89
18 changed files with 581 additions and 314 deletions

View File

@ -4,62 +4,19 @@ set -eux
: ${WORKSPACE:="$(pwd)"} : ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1 export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
./jenkins/prepare_synapse.sh ./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git ./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
./dendron/jenkins/build_dendron.sh
./sytest/jenkins/prep_sytest_for_postgres.sh
: ${GOPATH:=${WORKSPACE}/.gopath} ./sytest/jenkins/install_and_run.sh \
if [[ "${GOPATH}" != *:* ]]; then --synapse-directory $WORKSPACE \
mkdir -p "${GOPATH}" --dendron $WORKSPACE/dendron/bin/dendron \
export PATH="${GOPATH}/bin:${PATH}" --pusher \
fi --synchrotron \
export GOPATH --federation-reader \
cd dendron
go get github.com/constabulary/gb/...
gb generate
gb build
cd ../sytest
: ${PORT_BASE:=20000}
: ${PORT_COUNT=100}
export PORT_BASE
export PORT_COUNT
./jenkins/prep_sytest_for_postgres.sh
mkdir -p var
echo >&2 "Running sytest with PostgreSQL";
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..

View File

@ -4,50 +4,14 @@ set -eux
: ${WORKSPACE:="$(pwd)"} : ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1 export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
./jenkins/prepare_synapse.sh ./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
: ${PORT_BASE:=20000} ./sytest/jenkins/prep_sytest_for_postgres.sh
: ${PORT_COUNT=100}
export PORT_BASE
export PORT_COUNT
cd sytest ./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@ -8,43 +8,8 @@ export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1 export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
./jenkins/prepare_synapse.sh ./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
: ${PORT_BASE:=20000} ./sytest/jenkins/install_and_run.sh \
: ${PORT_COUNT=100} --synapse-directory $WORKSPACE \
export PORT_BASE
export PORT_COUNT
cd sytest
TOX_BIN=$WORKSPACE/.tox/py27/bin
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@ -1,24 +1,44 @@
#! /bin/bash #! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1 NAME=$1
PROJECT=$2 PROJECT=$2
BASE=".$NAME-base" BASE=".$NAME-base"
# update our clone # Update our mirror.
if [ ! -d .$NAME-base ]; then if [ ! -d ".$NAME-base" ]; then
git clone $PROJECT $BASE --mirror # Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else else
(cd $BASE; git fetch -p) # Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi fi
rm -rf $NAME # Remove the existing repository so that we have a clean copy
git clone $BASE $NAME --shared rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd $NAME cd "$NAME"
# check out the relevant branch # check out the relevant branch
git checkout "${GIT_BRANCH}" || ( git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop" git checkout "origin/develop"
) )
git clean -df .

View File

@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = { BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"], "events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
@ -92,8 +92,12 @@ class Store(object):
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"] _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"] _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"] _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
"_simple_select_one_onecol_txn"
]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@ -158,31 +162,40 @@ class Porter(object):
def setup_table(self, table): def setup_table(self, table):
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting. # It's safe to just carry on inserting.
next_chunk = yield self.postgres_store._simple_select_one_onecol( row = yield self.postgres_store._simple_select_one(
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
retcol="rowid", retcols=("forward_rowid", "backward_rowid"),
allow_none=True, allow_none=True,
) )
total_to_port = None total_to_port = None
if next_chunk is None: if row is None:
if table == "sent_transactions": if table == "sent_transactions":
next_chunk, already_ported, total_to_port = ( forward_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions() yield self._setup_sent_transactions()
) )
backward_chunk = 0
else: else:
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "rowid": 1} values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
next_chunk = 1 forward_chunk = 1
backward_chunk = 0
already_ported = 0 already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
if total_to_port is None: if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk table, forward_chunk, backward_chunk
) )
else: else:
def delete_all(txn): def delete_all(txn):
@ -196,46 +209,85 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "rowid": 0} values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
next_chunk = 1 forward_chunk = 1
backward_chunk = 0
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk table, forward_chunk, backward_chunk
) )
defer.returnValue((table, already_ported, total_to_port, next_chunk)) defer.returnValue(
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, next_chunk): def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk):
if not table_size: if not table_size:
return return
self.progress.add_table(table, postgres_size, table_size) self.progress.add_table(table, postgres_size, table_size)
if table == "event_search": if table == "event_search":
yield self.handle_search_table(postgres_size, table_size, next_chunk) yield self.handle_search_table(
postgres_size, table_size, forward_chunk, backward_chunk
)
return return
select = ( forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,) % (table,)
) )
backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
% (table,)
)
do_forward = [True]
do_backward = [True]
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (next_chunk, self.batch_size,)) forward_rows = []
rows = txn.fetchall() backward_rows = []
headers = [column[0] for column in txn.description] if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,))
forward_rows = txn.fetchall()
if not forward_rows:
do_forward[0] = False
return headers, rows if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,))
backward_rows = txn.fetchall()
if not backward_rows:
do_backward[0] = False
headers, rows = yield self.sqlite_store.runInteraction("select", r) if forward_rows or backward_rows:
headers = [column[0] for column in txn.description]
else:
headers = None
if rows: return headers, forward_rows, backward_rows
next_chunk = rows[-1][0] + 1
headers, frows, brows = yield self.sqlite_store.runInteraction(
"select", r
)
if frows or brows:
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
backward_chunk = min(row[0] for row in brows) - 1
rows = frows + brows
self._convert_rows(table, headers, rows) self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):
@ -247,7 +299,10 @@ class Porter(object):
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk}, updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
) )
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
@ -259,7 +314,8 @@ class Porter(object):
return return
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, next_chunk): def handle_search_table(self, postgres_size, table_size, forward_chunk,
backward_chunk):
select = ( select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es" " FROM event_search as es"
@ -270,7 +326,7 @@ class Porter(object):
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (next_chunk, self.batch_size,)) txn.execute(select, (forward_chunk, self.batch_size,))
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers = [column[0] for column in txn.description]
@ -279,7 +335,7 @@ class Porter(object):
headers, rows = yield self.sqlite_store.runInteraction("select", r) headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows: if rows:
next_chunk = rows[-1][0] + 1 forward_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a # We have to treat event_search differently since it has a
# different structure in the two different databases. # different structure in the two different databases.
@ -312,7 +368,10 @@ class Porter(object):
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": "event_search"}, keyvalues={"table_name": "event_search"},
updatevalues={"rowid": next_chunk}, updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
) )
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
@ -324,7 +383,6 @@ class Porter(object):
else: else:
return return
def setup_db(self, db_config, database_engine): def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect( db_conn = database_engine.module.connect(
**{ **{
@ -395,10 +453,32 @@ class Porter(object):
txn.execute( txn.execute(
"CREATE TABLE port_from_sqlite3 (" "CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE," " table_name varchar(100) NOT NULL UNIQUE,"
" rowid bigint NOT NULL" " forward_rowid bigint NOT NULL,"
" backward_rowid bigint NOT NULL"
")" ")"
) )
# The old port script created a table with just a "rowid" column.
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table(txn):
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
)
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
)
try:
yield self.postgres_store.runInteraction(
"alter_table", alter_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
try: try:
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction(
"create_port_table", create_port_table "create_port_table", create_port_table
@ -458,7 +538,7 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _setup_sent_transactions(self): def _setup_sent_transactions(self):
# Only save things from the last day # Only save things from the last day
yesterday = int(time.time()*1000) - 86400000 yesterday = int(time.time() * 1000) - 86400000
# And save the max transaction id from each destination # And save the max transaction id from each destination
select = ( select = (
@ -514,7 +594,11 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": "sent_transactions", "rowid": next_chunk} values={
"table_name": "sent_transactions",
"forward_rowid": next_chunk,
"backward_rowid": 0,
}
) )
def get_sent_table_size(txn): def get_sent_table_size(txn):
@ -535,13 +619,18 @@ class Porter(object):
defer.returnValue((next_chunk, inserted_rows, total_count)) defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, next_chunk): def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
rows = yield self.sqlite_store.execute_sql( frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
next_chunk, forward_chunk,
) )
defer.returnValue(rows[0][0]) brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
backward_chunk,
)
defer.returnValue(frows[0][0] + brows[0][0])
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_already_ported_count(self, table): def _get_already_ported_count(self, table):
@ -552,10 +641,10 @@ class Porter(object):
defer.returnValue(rows[0][0]) defer.returnValue(rows[0][0])
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_total_count_to_port(self, table, next_chunk): def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
remaining, done = yield defer.gatherResults( remaining, done = yield defer.gatherResults(
[ [
self._get_remaining_count_to_port(table, next_chunk), self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
self._get_already_ported_count(table), self._get_already_ported_count(table),
], ],
consumeErrors=True, consumeErrors=True,
@ -686,7 +775,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr( self.stdscr.addstr(
i+2, left_margin + max_len - len(table), i + 2, left_margin + max_len - len(table),
table, table,
curses.A_BOLD | color, curses.A_BOLD | color,
) )
@ -694,18 +783,18 @@ class CursesProgress(Progress):
size = 20 size = 20
progress = "[%s%s]" % ( progress = "[%s%s]" % (
"#" * int(perc*size/100), "#" * int(perc * size / 100),
" " * (size - int(perc*size/100)), " " * (size - int(perc * size / 100)),
) )
self.stdscr.addstr( self.stdscr.addstr(
i+2, left_margin + max_len + middle_space, i + 2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
) )
if self.finished: if self.finished:
self.stdscr.addstr( self.stdscr.addstr(
rows-1, 0, rows - 1, 0,
"Press any key to exit...", "Press any key to exit...",
) )

View File

@ -393,27 +393,9 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function @log_function
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
query = [] return self.on_query_request("client_keys", content)
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function

View File

@ -378,10 +378,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query): def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content) return self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationClientKeysClaimServlet(BaseFederationServlet):

View File

@ -29,7 +29,7 @@ class DeviceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_device_registered(self, user_id, device_id, def check_device_registered(self, user_id, device_id,
initial_device_display_name): initial_device_display_name=None):
""" """
If the given device has not been registered, register it with the If the given device has not been registered, register it with the
supplied display name. supplied display name.

View File

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import logging
from twisted.internet import defer
from synapse.api import errors
import synapse.types
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id)
queries_by_domain[user.domain][user_id] = device_ids
# do the queries
# TODO: do these in parallel
results = {}
for destination, destination_query in queries_by_domain.items():
if destination == self.server_name:
res = yield self.query_local_devices(destination_query)
else:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})

View File

@ -130,9 +130,7 @@ class KeyUploadServlet(RestServlet):
# old access_token without an associated device_id. Either way, we # old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with # need to double-check the device is registered to avoid ending up with
# keys without a corresponding device. # keys without a corresponding device.
self.device_handler.check_device_registered( self.device_handler.check_device_registered(user_id, device_id)
user_id, device_id, "unknown device"
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result})) defer.returnValue((200, {"one_time_key_counts": result}))
@ -186,17 +184,19 @@ class KeyQueryServlet(RestServlet):
) )
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(KeyQueryServlet, self).__init__() super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.federation = hs.get_replication_layer() self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.is_mine = hs.is_mine
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id, device_id): def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.handle_request(body) result = yield self.e2e_keys_handler.query_devices(body)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -205,45 +205,11 @@ class KeyQueryServlet(RestServlet):
auth_user_id = requester.user.to_string() auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []
result = yield self.handle_request( result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}} {"device_keys": {user_id: device_ids}}
) )
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet): class OneTimeKeyServlet(RestServlet):
""" """

View File

@ -19,39 +19,38 @@
# partial one for unit test mocking. # partial one for unit test mocking.
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.enterprise import adbapi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from synapse.federation import initialize_http_replication
from synapse.handlers.device import DeviceHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
import logging import logging
from twisted.enterprise import adbapi
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.pusherpool import PusherPool
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,6 +93,7 @@ class HomeServer(object):
'room_list_handler', 'room_list_handler',
'auth_handler', 'auth_handler',
'device_handler', 'device_handler',
'e2e_keys_handler',
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
@ -202,6 +202,9 @@ class HomeServer(object):
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)
def build_e2e_keys_handler(self):
return E2eKeysHandler(self)
def build_application_service_api(self): def build_application_service_api(self):
return ApplicationServiceApi(self) return ApplicationServiceApi(self)

View File

@ -1,6 +1,7 @@
import synapse.handlers import synapse.handlers
import synapse.handlers.auth import synapse.handlers.auth
import synapse.handlers.device import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.storage import synapse.storage
import synapse.state import synapse.state
@ -14,6 +15,9 @@ class HomeServer(object):
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass pass
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
pass
def get_handlers(self) -> synapse.handlers.Handlers: def get_handlers(self) -> synapse.handlers.Handlers:
pass pass

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections
import twisted.internet.defer import twisted.internet.defer
@ -38,24 +39,49 @@ class EndToEndKeyStore(SQLBaseStore):
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
key json byte strings. dict containing "key_json", "device_display_name".
""" """
def _get_e2e_device_keys(txn): if not query_list:
result = {} return {}
for user_id, device_id in query_list:
user_result = result.setdefault(user_id, {}) return self.runInteraction(
keyvalues = {"user_id": user_id} "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
if device_id: )
keyvalues["device_id"] = device_id
rows = self._simple_select_list_txn( def _get_e2e_device_keys_txn(self, txn, query_list):
txn, table="e2e_device_keys_json", query_clauses = []
keyvalues=keyvalues, query_params = []
retcols=["device_id", "key_json"]
) for (user_id, device_id) in query_list:
for row in rows: query_clause = "k.user_id = ?"
user_result[row["device_id"]] = row["key_json"] query_params.append(user_id)
return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) if device_id:
query_clause += " AND k.device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
sql = (
"SELECT k.user_id, k.device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM e2e_device_keys_json k"
" LEFT JOIN devices d ON d.user_id = k.user_id"
" AND d.device_id = k.device_id"
" WHERE %s"
) % (
" OR ".join("(" + q + ")" for q in query_clauses)
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict)
for row in rows:
result[row["user_id"]][row["device_id"]] = row
return result
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn): def _add_e2e_one_time_keys(txn):

View File

@ -26,7 +26,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import deque, namedtuple from collections import deque, namedtuple, OrderedDict
import synapse import synapse
import synapse.metrics import synapse.metrics
@ -403,6 +403,23 @@ class EventsStore(SQLBaseStore):
and the rejections table. Things reading from those table will need to check and the rejections table. Things reading from those table will need to check
whether the event was rejected. whether the event was rejected.
""" """
# Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict()
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
if not event.internal_metadata.is_outlier():
if prev_event_context[0].internal_metadata.is_outlier():
# To ensure correct ordering we pop, as OrderedDict is
# ordered by first insertion.
new_events_and_contexts.pop(event.event_id, None)
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
events_and_contexts = new_events_and_contexts.values()
depth_updates = {} depth_updates = {}
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
@ -433,8 +450,6 @@ class EventsStore(SQLBaseStore):
for event_id, outlier in txn.fetchall() for event_id, outlier in txn.fetchall()
} }
# Remove the events that we've seen before.
event_map = {}
to_remove = set() to_remove = set()
for event, context in events_and_contexts: for event, context in events_and_contexts:
if context.rejected: if context.rejected:
@ -445,23 +460,6 @@ class EventsStore(SQLBaseStore):
to_remove.add(event) to_remove.add(event)
continue continue
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
if event.event_id in event_map:
other = event_map[event.event_id]
if not other.internal_metadata.is_outlier():
to_remove.add(event)
continue
elif not event.internal_metadata.is_outlier():
to_remove.add(event)
continue
else:
to_remove.add(other)
event_map[event.event_id] = event
if event.event_id not in have_persisted: if event.event_id not in have_persisted:
continue continue

View File

@ -16,4 +16,4 @@
-- make sure that we have a device record for each set of E2E keys, so that the -- make sure that we have a device record for each set of E2E keys, so that the
-- user can delete them if they like. -- user can delete them if they like.
INSERT INTO devices INSERT INTO devices
SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json; SELECT user_id, device_id, NULL FROM e2e_device_keys_json;

View File

@ -0,0 +1,20 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- a previous version of the "devices_for_e2e_keys" delta set all the device
-- names to "unknown device". This wasn't terribly helpful
UPDATE devices
SET display_name = NULL
WHERE display_name = 'unknown device';

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mock
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.e2e_keys
import synapse.storage
from tests import unittest, utils
class E2eKeysHandlerTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
handlers=None,
replication_layer=mock.Mock(),
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
@defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None})
self.assertDictEqual(res, {local_user: {}})

View File

@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"device_display_name": None,
}, dev)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
yield self.store.store_device(
"user", "device", "display_name"
)
res = yield self.store.get_e2e_device_keys((("user", "device"),))
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"device_display_name": "display_name",
}, dev)
@defer.inlineCallbacks
def test_multiple_devices(self):
now = 1470174257070
yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys(
"user1", "device2", now, 'json12')
yield self.store.set_e2e_device_keys(
"user2", "device1", now, 'json21')
yield self.store.set_e2e_device_keys(
"user2", "device2", now, 'json22')
res = yield self.store.get_e2e_device_keys((("user1", "device1"),
("user2", "device2")))
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
self.assertNotIn("device2", res["user1"])
self.assertIn("user2", res)
self.assertNotIn("device1", res["user2"])
self.assertIn("device2", res["user2"])