Merge branch 'release-v0.6.0'

This commit is contained in:
Mark Haines 2014-12-19 13:40:02 +00:00
commit 45a6869cb4
101 changed files with 4768 additions and 2241 deletions

2
.gitignore vendored
View File

@ -38,3 +38,5 @@ graph/*.dot
**/webclient/test/environment-protractor.js **/webclient/test/environment-protractor.js
uploads uploads
.idea/

View File

@ -1,3 +1,12 @@
Changes in synapse 0.6.0 (2014-12-16)
=====================================
* Add new API for media upload and download that supports thumbnailing.
* Implement typing notifications.
* Fix bugs where we sent events with invalid signatures due to bugs where
we incorrectly persisted events.
* Improve performance of database queries involving retrieving events.
Changes in synapse 0.5.4a (2014-12-13) Changes in synapse 0.5.4a (2014-12-13)
====================================== ======================================

View File

@ -133,6 +133,37 @@ failing, e.g.::
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- gcc
- git
- libffi-devel
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
The content repository requires additional packages and will be unable to process
uploads without them:
- libjpeg8
- libjpeg8-devel
- zlib
If you choose to install Synapse without these packages, you will need to reinstall
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
pillow --user``
Troubleshooting:
- You may need to upgrade ``setuptools`` to get this to work correctly:
``pip install setuptools --upgrade``.
- You may encounter errors indicating that ``ffi.h`` is missing, even with
``libffi-devel`` installed. If you do, copy the ``.h`` files:
``cp /usr/lib/libffi-3.0.13/include/*.h /usr/include``
- You may need to install libsodium from source in order to install PyNacl. If
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
Running Your Homeserver Running Your Homeserver
======================= =======================

View File

@ -1,3 +1,23 @@
Upgrading to v0.6.0
===================
To pull in new dependencies, run::
python setup.py develop --user
This update includes a change to the database schema. To upgrade you first need
to upgrade the database by running::
python scripts/upgrade_db_to_v0.6.0.py <db> <server_name> <signing_key>
Where `<db>` is the location of the database, `<server_name>` is the
server name as specified in the synapse configuration, and `<signing_key>` is
the location of the signing key as specified in the synapse configuration.
This may take some time to complete. Failures of signatures and content hashes
can safely be ignored.
Upgrading to v0.5.1 Upgrading to v0.5.1
=================== ===================

View File

@ -1 +1 @@
0.5.4a 0.6.0

View File

@ -1,10 +1,14 @@
Basically, PEP8 Basically, PEP8
- Max line width: 80 chars. - NEVER tabs. 4 spaces to indent.
- Max line width: 79 chars (with flexibility to overflow by a "few chars" if
the overflowing content is not semantically significant and avoids an
explosion of vertical whitespace).
- Use camel case for class and type names - Use camel case for class and type names
- Use underscores for functions and variables. - Use underscores for functions and variables.
- Use double quotes. - Use double quotes.
- Use parentheses instead of '\' for line continuation where ever possible (which is pretty much everywhere) - Use parentheses instead of '\\' for line continuation where ever possible
(which is pretty much everywhere)
- There should be max a single new line between: - There should be max a single new line between:
- statements - statements
- functions in a class - functions in a class
@ -14,5 +18,32 @@ Basically, PEP8
- a single space after a comma - a single space after a comma
- a single space before and after for '=' when used as assignment - a single space before and after for '=' when used as assignment
- no spaces before and after for '=' for default values and keyword arguments. - no spaces before and after for '=' for default values and keyword arguments.
- Indenting must follow PEP8; either hanging indent or multiline-visual indent
depending on the size and shape of the arguments and what makes more sense to
the author. In other words, both this::
Comments should follow the google code style. This is so that we can generate documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/) print("I am a fish %s" % "moo")
and this::
print("I am a fish %s" %
"moo")
and this::
print(
"I am a fish %s" %
"moo"
)
...are valid, although given each one takes up 2x more vertical space than
the previous, it's up to the author's discretion as to which layout makes most
sense for their function invocation. (e.g. if they want to add comments
per-argument, or put expressions in the arguments, or group related arguments
together, or want to deliberately extend or preserve vertical/horizontal
space)
Comments should follow the google code style. This is so that we can generate
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
Code should pass pep8 --max-line-length=100 without any warnings.

25
docs/media_repository.rst Normal file
View File

@ -0,0 +1,25 @@
Media Repository
================
The media repository is where attachments and avatar photos are stored.
It stores attachment content and thumbnails for media uploaded by local users.
It caches attachment content and thumbnails for media uploaded by remote users.
Storage
-------
Each item of media is assigned a ``media_id`` when it is uploaded.
The ``media_id`` is a randomly chosen, URL safe 24 character string.
Metadata such as the MIME type, upload time and length are stored in the
sqlite3 database indexed by ``media_id``.
Content is stored on the filesystem under a ``"local_content"`` directory.
Thumbnails are stored under a ``"local_thumbnails"`` directory.
The item with ``media_id`` ``"aabbccccccccdddddddddddd"`` is stored under
``"local_content/aa/bb/ccccccccdddddddddddd"``. Its thumbnail with width
``128`` and height ``96`` and type ``"image/jpeg"`` is stored under
``"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"``
Remote content is cached under ``"remote_content"`` directory. Each item of
remote content is assigned a local "``filesystem_id``" to ensure that the
directory structure ``"remote_content/server_name/aa/bb/ccccccccdddddddddddd"``
is appropriate. Thumbnails for remote content are stored under
``"remote_thumbnails/server_name/..."``

138
graph/graph2.py Normal file
View File

@ -0,0 +1,138 @@
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import pydot
import cgi
import json
import datetime
import argparse
from synapse.events import FrozenEvent
def make_graph(db_name, room_id, file_prefix):
conn = sqlite3.connect(db_name)
c = conn.execute(
"SELECT json FROM event_json where room_id = ?",
(room_id,)
)
events = [FrozenEvent(json.loads(e[0])) for e in c.fetchall()]
events.sort(key=lambda e: e.depth)
node_map = {}
state_groups = {}
graph = pydot.Dot(graph_name="Test")
for event in events:
c = conn.execute(
"SELECT state_group FROM event_to_state_groups "
"WHERE event_id = ?",
(event.event_id,)
)
res = c.fetchone()
state_group = res[0] if res else None
if state_group is not None:
state_groups.setdefault(state_group, []).append(event.event_id)
t = datetime.datetime.fromtimestamp(
float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(event.get_dict()["content"])
label = (
"<"
"<b>%(name)s </b><br/>"
"Type: <b>%(type)s </b><br/>"
"State key: <b>%(state_key)s </b><br/>"
"Content: <b>%(content)s </b><br/>"
"Time: <b>%(time)s </b><br/>"
"Depth: <b>%(depth)s </b><br/>"
"State group: %(state_group)s<br/>"
">"
) % {
"name": event.event_id,
"type": event.type,
"state_key": event.get("state_key", None),
"content": cgi.escape(content, quote=True),
"time": t,
"depth": event.depth,
"state_group": state_group,
}
node = pydot.Node(
name=event.event_id,
label=label,
)
node_map[event.event_id] = node
graph.add_node(node)
for event in events:
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
end_node = pydot.Node(
name=prev_id,
label="<<b>%s</b>>" % (prev_id,),
)
node_map[prev_id] = end_node
graph.add_node(end_node)
edge = pydot.Edge(node_map[event.event_id], end_node)
graph.add_edge(edge)
for group, event_ids in state_groups.items():
if len(event_ids) <= 1:
continue
cluster = pydot.Cluster(
str(group),
label="<State Group: %s>" % (str(group),)
)
for event_id in event_ids:
cluster.add_node(node_map[event_id])
graph.add_subgraph(cluster)
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
graph.write_svg("%s.svg" % file_prefix, prog='dot')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate a PDU graph for a given room by talking "
"to the given homeserver to get the list of PDUs. \n"
"Requires pydot."
)
parser.add_argument(
"-p", "--prefix", dest="prefix",
help="String to prefix output files with"
)
parser.add_argument('db')
parser.add_argument('room')
args = parser.parse_args()
make_graph(args.db, args.room, args.prefix)

View File

@ -18,6 +18,9 @@ class dictobj(dict):
def get_full_dict(self): def get_full_dict(self):
return dict(self) return dict(self)
def get_pdu_json(self):
return dict(self)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -0,0 +1,143 @@
import nacl.signing
import json
import base64
import requests
import sys
import srvlookup
def encode_base64(input_bytes):
"""Encode bytes as a base64 string without any padding."""
input_len = len(input_bytes)
output_len = 4 * ((input_len + 2) // 3) + (input_len + 2) % 3 - 2
output_bytes = base64.b64encode(input_bytes)
output_string = output_bytes[:output_len].decode("ascii")
return output_string
def decode_base64(input_string):
"""Decode a base64 string to bytes inferring padding from the length of the
string."""
input_bytes = input_string.encode("ascii")
input_len = len(input_bytes)
padding = b"=" * (3 - ((input_len + 3) % 4))
output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
output_bytes = base64.b64decode(input_bytes + padding)
return output_bytes[:output_len]
def encode_canonical_json(value):
return json.dumps(
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
ensure_ascii=False,
# Remove unecessary white space.
separators=(',',':'),
# Sort the keys of dictionaries.
sort_keys=True,
# Encode the resulting unicode as UTF-8 bytes.
).encode("UTF-8")
def sign_json(json_object, signing_key, signing_name):
signatures = json_object.pop("signatures", {})
unsigned = json_object.pop("unsigned", None)
signed = signing_key.sign(encode_canonical_json(json_object))
signature_base64 = encode_base64(signed.signature)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
signatures.setdefault(signing_name, {})[key_id] = signature_base64
json_object["signatures"] = signatures
if unsigned is not None:
json_object["unsigned"] = unsigned
return json_object
NACL_ED25519 = "ed25519"
def decode_signing_key_base64(algorithm, version, key_base64):
"""Decode a base64 encoded signing key
Args:
algorithm (str): The algorithm the key is for (currently "ed25519").
version (str): Identifies this key out of the keys for this entity.
key_base64 (str): Base64 encoded bytes of the key.
Returns:
A SigningKey object.
"""
if algorithm == NACL_ED25519:
key_bytes = decode_base64(key_base64)
key = nacl.signing.SigningKey(key_bytes)
key.version = version
key.alg = NACL_ED25519
return key
else:
raise ValueError("Unsupported algorithm %s" % (algorithm,))
def read_signing_keys(stream):
"""Reads a list of keys from a stream
Args:
stream : A stream to iterate for keys.
Returns:
list of SigningKey objects.
"""
keys = []
for line in stream:
algorithm, version, key_base64 = line.split()
keys.append(decode_signing_key_base64(algorithm, version, key_base64))
return keys
def lookup(destination, path):
if ":" in destination:
return "https://%s%s" % (destination, path)
else:
srv = srvlookup.lookup("matrix", "tcp", destination)[0]
return "https://%s:%d%s" % (srv.host, srv.port, path)
def get_json(origin_name, origin_key, destination, path):
request_json = {
"method": "GET",
"uri": path,
"origin": origin_name,
"destination": destination,
}
signed_json = sign_json(request_json, origin_key, origin_name)
authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items():
authorization_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
origin_name, key, sig,
)
))
result = requests.get(
lookup(destination, path),
headers={"Authorization": authorization_headers[0]},
verify=False,
)
return result.json()
def main():
origin_name, keyfile, destination, path = sys.argv[1:]
with open(keyfile) as f:
key = read_signing_keys(f)[0]
result = get_json(
origin_name, key, destination, "/_matrix/federation/v1/" + path
)
json.dump(result, sys.stdout)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,331 @@
from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
# import dns.resolver
import hashlib
import httplib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
PRAGMA user_version = 10;
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
# def get_key(server_name):
# print "Getting keys for: %s" % (server_name,)
# targets = []
# if ":" in server_name:
# target, port = server_name.split(":")
# targets.append((target, int(port)))
# try:
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
# for srv in answers:
# targets.append((srv.target, srv.port))
# except dns.resolver.NXDOMAIN:
# targets.append((server_name, 8448))
# except:
# print "Failed to lookup keys for %s" % (server_name,)
# return {}
#
# for target, port in targets:
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
# try:
# keys = json.load(urllib2.urlopen(url, timeout=2))
# verify_keys = {}
# for key_id, key_base64 in keys["verify_keys"].items():
# verify_key = decode_verify_key_bytes(
# key_id, decode_base64(key_base64)
# )
# verify_signed_json(keys, server_name, verify_key)
# verify_keys[key_id] = verify_key
# print "Got keys for: %s" % (server_name,)
# return verify_keys
# except urllib2.URLError:
# pass
# except urllib2.HTTPError:
# pass
# except httplib.HTTPException:
# pass
#
# print "Failed to get keys for %s" % (server_name,)
# return {}
def reinsert_events(cursor, server_name, signing_key):
print "Running delta: v10"
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
print "Getting events..."
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
i = 0
N = len(events)
for event in events:
if i % 100 == 0:
print "Processed: %d/%d events" % (i,N,)
i += 1
# for alg_name in event.hashes:
# if check_event_content_hash(event, algorithms[alg_name]):
# pass
# else:
# pass
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = {} # get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
# Do other deltas:
cursor.execute("PRAGMA user_version")
row = cursor.fetchone()
if row and row[0]:
user_version = row[0]
# Run every version since after the current version.
for v in range(user_version + 1, 10):
print "Running delta: %d" % (v,)
sql_script = read_schema("delta/v%d" % (v,))
cursor.executescript(sql_script)
reinsert_events(cursor, server_name, signing_key)
conn.commit()
print "Success!"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])

View File

@ -32,7 +32,7 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=[ install_requires=[
"syutil==0.0.2", "syutil==0.0.2",
"matrix_angular_sdk==0.5.1", "matrix_angular_sdk==0.6.0",
"Twisted>=14.0.0", "Twisted>=14.0.0",
"service_identity>=1.0.0", "service_identity>=1.0.0",
"pyopenssl>=0.14", "pyopenssl>=0.14",
@ -41,11 +41,13 @@ setup(
"pynacl", "pynacl",
"daemonize", "daemonize",
"py-bcrypt", "py-bcrypt",
"frozendict>=0.4",
"pillow",
], ],
dependency_links=[ dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2", "https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0", "https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1", "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
], ],
setup_requires=[ setup_requires=[
"setuptools_trial", "setuptools_trial",
@ -59,6 +61,6 @@ setup(
entry_points=""" entry_points="""
[console_scripts] [console_scripts]
synctl=synapse.app.synctl:main synctl=synapse.app.synctl:main
synapse-homeserver=synapse.app.homeserver:run synapse-homeserver=synapse.app.homeserver:main
""" """
) )

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a synapse home server. """ This is a reference implementation of a synapse home server.
""" """
__version__ = "0.5.4a" __version__ = "0.6.0"

View File

@ -17,14 +17,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomCreateEvent, RoomAliasesEvent,
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from syutil.base64util import encode_base64 from synapse.util.async import run_on_reactor
import logging import logging
@ -53,15 +49,17 @@ class Auth(object):
logger.warn("Trusting event: %s", event.event_id) logger.warn("Trusting event: %s", event.event_id)
return True return True
if event.type == RoomCreateEvent.TYPE: if event.type == EventTypes.Create:
# FIXME # FIXME
return True return True
# FIXME: Temp hack # FIXME: Temp hack
if event.type == RoomAliasesEvent.TYPE: if event.type == EventTypes.Aliases:
return True return True
if event.type == RoomMemberEvent.TYPE: logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed( allowed = self.is_membership_change_allowed(
event, auth_events event, auth_events
) )
@ -74,10 +72,10 @@ class Auth(object):
self.check_event_sender_in_room(event, auth_events) self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event, auth_events) self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE: if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events) self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE: if event.type == EventTypes.Redaction:
self._check_redaction(event, auth_events) self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event) logger.debug("Allowing! %s", event)
@ -93,7 +91,7 @@ class Auth(object):
def check_joined_room(self, room_id, user_id): def check_joined_room(self, room_id, user_id):
member = yield self.state.get_current_state( member = yield self.state.get_current_state(
room_id=room_id, room_id=room_id,
event_type=RoomMemberEvent.TYPE, event_type=EventTypes.Member,
state_key=user_id state_key=user_id
) )
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
@ -104,7 +102,7 @@ class Auth(object):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
for event in curr_state: for event in curr_state:
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
try: try:
if self.hs.parse_userid(event.state_key).domain != host: if self.hs.parse_userid(event.state_key).domain != host:
continue continue
@ -118,7 +116,7 @@ class Auth(object):
defer.returnValue(False) defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key) member_event = auth_events.get(key)
return self._check_joined_room( return self._check_joined_room(
@ -140,7 +138,7 @@ class Auth(object):
# Check if this is the room creator joining: # Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership: if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event: # Get room creation event:
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create = auth_events.get(key) create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id: if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key: if create.content["creator"] == event.state_key:
@ -149,19 +147,19 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
# get info about the caller # get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key) caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target # get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, ) key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key) target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key) join_rule_event = auth_events.get(key)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get( join_rule = join_rule_event.content.get(
@ -256,7 +254,7 @@ class Auth(object):
return True return True
def _get_power_level_from_event_state(self, event, user_id, auth_events): def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key) power_level_event = auth_events.get(key)
level = None level = None
if power_level_event: if power_level_event:
@ -264,7 +262,7 @@ class Auth(object):
if not level: if not level:
level = power_level_event.content.get("users_default", 0) level = power_level_event.content.get("users_default", 0)
else: else:
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create_event = auth_events.get(key) create_event = auth_events.get(key)
if (create_event is not None and if (create_event is not None and
create_event.content["creator"] == user_id): create_event.content["creator"] == user_id):
@ -273,7 +271,7 @@ class Auth(object):
return level return level
def _get_ops_level_from_event_state(self, event, auth_events): def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key) power_level_event = auth_events.get(key)
if power_level_event: if power_level_event:
@ -351,29 +349,31 @@ class Auth(object):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, event): def add_auth_events(self, builder, context):
if event.type == RoomCreateEvent.TYPE: yield run_on_reactor()
event.auth_events = []
if builder.type == EventTypes.Create:
builder.auth_events = []
return return
auth_events = [] auth_ids = []
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key) power_level_event = context.current_state.get(key)
if power_level_event: if power_level_event:
auth_events.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = event.old_state_events.get(key) join_rule_event = context.current_state.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (EventTypes.Member, builder.user_id, )
member_event = event.old_state_events.get(key) member_event = context.current_state.get(key)
key = (RoomCreateEvent.TYPE, "", ) key = (EventTypes.Create, "", )
create_event = event.old_state_events.get(key) create_event = context.current_state.get(key)
if create_event: if create_event:
auth_events.append(create_event.event_id) auth_ids.append(create_event.event_id)
if join_rule_event: if join_rule_event:
join_rule = join_rule_event.content.get("join_rule") join_rule = join_rule_event.content.get("join_rule")
@ -381,33 +381,37 @@ class Auth(object):
else: else:
is_public = False is_public = False
if event.type == RoomMemberEvent.TYPE: if builder.type == EventTypes.Member:
e_type = event.content["membership"] e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event:
auth_events.append(join_rule_event.event_id) auth_ids.append(join_rule_event.event_id)
if e_type == Membership.JOIN:
if member_event and not is_public: if member_event and not is_public:
auth_events.append(member_event.event_id) auth_ids.append(member_event.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event: elif member_event:
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id) auth_ids.append(member_event.event_id)
hashes = yield self.store.get_event_reference_hashes( auth_events_entries = yield self.store.add_event_hashes(
auth_events auth_ids
) )
hashes = [
{ builder.auth_events = auth_events_entries
k: encode_base64(v) for k, v in h.items()
if k == "sha256" context.auth_events = {
} k: v
for h in hashes for k, v in context.current_state.items()
] if v.event_id in auth_ids
event.auth_events = zip(auth_events, hashes) }
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key) send_level_event = auth_events.get(key)
send_level = None send_level = None
if send_level_event: if send_level_event:

View File

@ -59,3 +59,18 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
class EventTypes(object):
Member = "m.room.member"
Create = "m.room.create"
JoinRules = "m.room.join_rules"
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
# These are used for validation
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"

View File

@ -34,6 +34,7 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
TOO_LARGE = "M_TOO_LARGE"
class CodeMessageException(Exception): class CodeMessageException(Exception):

View File

@ -1,148 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.jsonobject import JsonEncodedObject
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, SynapseEvent):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d:
d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"]
del d["age_ts"]
return d
class SynapseEvent(JsonEncodedObject):
"""Base class for Synapse events. These are JSON objects which must abide
by a certain well-defined structure.
"""
# Attributes that are currently assumed by the federation side:
# Mandatory:
# - event_id
# - room_id
# - type
# - is_state
#
# Optional:
# - state_key (mandatory when is_state is True)
# - prev_events (these can be filled out by the federation layer itself.)
# - prev_state
valid_keys = [
"event_id",
"type",
"room_id",
"user_id", # sender/initiator
"content", # HTTP body, JSON
"state_key",
"age_ts",
"prev_content",
"replaces_state",
"redacted_because",
"origin_server_ts",
]
internal_keys = [
"is_state",
"depth",
"destinations",
"origin",
"outlier",
"redacted",
"prev_events",
"hashes",
"signatures",
"prev_state",
"auth_events",
"state_hash",
]
required_keys = [
"event_id",
"room_id",
"content",
]
outlier = False
def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs)
# if "content" in kwargs:
# self.check_json(self.content, raises=raises)
def get_content_template(self):
""" Retrieve the JSON template for this event as a dict.
The template must be a dict representing the JSON to match. Only
required keys should be present. The values of the keys in the template
are checked via type() to the values of the same keys in the actual
event JSON.
NB: If loading content via json.loads, you MUST define strings as
unicode.
For example:
Content:
{
"name": u"bob",
"age": 18,
"friends": [u"mike", u"jill"]
}
Template:
{
"name": u"string",
"age": 0,
"friends": [u"string"]
}
The values "string" and 0 could be anything, so long as the types
are the same as the content.
"""
raise NotImplementedError("get_content_template not implemented.")
def get_pdu_json(self, time_now=None):
pdu_json = self.get_full_dict()
pdu_json.pop("destinations", None)
pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
content = pdu_json.get("content", {})
content.pop("prev", None)
if time_now is not None and "age_ts" in pdu_json:
age = time_now - pdu_json["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["age_ts"]
user_id = pdu_json.pop("user_id")
pdu_json["sender"] = user_id
return pdu_json
class SynapseStateEvent(SynapseEvent):
def __init__(self, **kwargs):
if "state_key" not in kwargs:
kwargs["state_key"] = ""
super(SynapseStateEvent, self).__init__(**kwargs)

View File

@ -1,90 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events.room import (
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
RoomPowerLevelsEvent, RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
)
from synapse.types import EventID
from synapse.util.stringutils import random_string
class EventFactory(object):
_event_classes = [
RoomTopicEvent,
RoomNameEvent,
MessageEvent,
RoomMemberEvent,
FeedbackEvent,
InviteJoinEvent,
RoomConfigEvent,
RoomPowerLevelsEvent,
RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
]
def __init__(self, hs):
self._event_list = {} # dict of TYPE to event class
for event_class in EventFactory._event_classes:
self._event_list[event_class.TYPE] = event_class
self.clock = hs.get_clock()
self.hs = hs
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create_local(local_part, self.hs)
return e_id.to_string()
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
kwargs["event_id"] = self.create_event_id()
kwargs["origin"] = self.hs.hostname
else:
ev_id = self.hs.parse_eventid(kwargs["event_id"])
kwargs["origin"] = ev_id.domain
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
# The "age" key is a delta timestamp that should be converted into an
# absolute timestamp the minute we see it.
if "age" in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"])
del kwargs["age"]
elif "age_ts" not in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec())
if etype in self._event_list:
handler = self._event_list[etype]
else:
handler = GenericEvent
return handler(**kwargs)

View File

@ -1,170 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import Feedback, Membership
from synapse.api.errors import SynapseError
from . import SynapseEvent, SynapseStateEvent
class GenericEvent(SynapseEvent):
def get_content_template(self):
return {}
class RoomTopicEvent(SynapseEvent):
TYPE = "m.room.topic"
internal_keys = SynapseEvent.internal_keys + [
"topic",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "topic" in kwargs["content"]:
kwargs["topic"] = kwargs["content"]["topic"]
super(RoomTopicEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"topic": u"string"}
class RoomNameEvent(SynapseEvent):
TYPE = "m.room.name"
internal_keys = SynapseEvent.internal_keys + [
"name",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "name" in kwargs["content"]:
kwargs["name"] = kwargs["content"]["name"]
super(RoomNameEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"name": u"string"}
class RoomMemberEvent(SynapseEvent):
TYPE = "m.room.member"
valid_keys = SynapseEvent.valid_keys + [
# target is the state_key
"membership", # action
]
def __init__(self, **kwargs):
if "membership" not in kwargs:
kwargs["membership"] = kwargs.get("content", {}).get("membership")
if not kwargs["membership"] in Membership.LIST:
raise SynapseError(400, "Bad membership value.")
super(RoomMemberEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"membership": u"string"}
class MessageEvent(SynapseEvent):
TYPE = "m.room.message"
valid_keys = SynapseEvent.valid_keys + [
"msg_id", # unique per room + user combo
]
def __init__(self, **kwargs):
super(MessageEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"msgtype": u"string"}
class FeedbackEvent(SynapseEvent):
TYPE = "m.room.message.feedback"
valid_keys = SynapseEvent.valid_keys
def __init__(self, **kwargs):
super(FeedbackEvent, self).__init__(**kwargs)
if not kwargs["content"]["type"] in Feedback.LIST:
raise SynapseError(400, "Bad feedback value.")
def get_content_template(self):
return {
"type": u"string",
"target_event_id": u"string"
}
class InviteJoinEvent(SynapseEvent):
TYPE = "m.room.invite_join"
valid_keys = SynapseEvent.valid_keys + [
# target_user_id is the state_key
"target_host",
]
def __init__(self, **kwargs):
super(InviteJoinEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomConfigEvent(SynapseEvent):
TYPE = "m.room.config"
def __init__(self, **kwargs):
kwargs["state_key"] = ""
super(RoomConfigEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomCreateEvent(SynapseStateEvent):
TYPE = "m.room.create"
def get_content_template(self):
return {}
class RoomJoinRulesEvent(SynapseStateEvent):
TYPE = "m.room.join_rules"
def get_content_template(self):
return {}
class RoomPowerLevelsEvent(SynapseStateEvent):
TYPE = "m.room.power_levels"
def get_content_template(self):
return {}
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"
def get_content_template(self):
return {}
class RoomRedactionEvent(SynapseEvent):
TYPE = "m.room.redaction"
valid_keys = SynapseEvent.valid_keys + ["redacts"]
def get_content_template(self):
return {}

View File

@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
class EventValidator(object):
def __init__(self, hs):
pass
def validate(self, event):
"""Checks the given JSON content abides by the rules of the template.
Args:
content : A JSON object to check.
raises: True to raise a SynapseError if the check fails.
Returns:
True if the content passes the template. Returns False if the check
fails and raises=False.
Raises:
SynapseError if the check fails and raises=True.
"""
# recursively call to inspect each layer
err_msg = self._check_json_template(
event.content,
event.get_content_template()
)
if err_msg:
raise SynapseError(400, err_msg, Codes.BAD_JSON)
else:
return True
def _check_json_template(self, content, template):
"""Check content and template matches.
If the template is a dict, each key in the dict will be validated with
the content, else it will just compare the types of content and
template. This basic type check is required because this function will
be recursively called and could be called with just strs or ints.
Args:
content: The content to validate.
template: The validation template.
Returns:
str: An error message if the validation fails, else None.
"""
if type(content) != type(template):
return "Mismatched types: %s" % template
if type(template) == dict:
for key in template:
if key not in content:
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
msg = self._check_json_template(
content[key],
template[key]
)
if msg:
return msg
elif type(content[key]) == list:
# make sure each item type in content matches the template
for entry in content[key]:
msg = self._check_json_template(
entry,
template[key][0]
)
if msg:
return msg

View File

@ -20,3 +20,4 @@ FEDERATION_PREFIX = "/_matrix/federation/v1"
WEB_CLIENT_PREFIX = "/_matrix/client" WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
MEDIA_PREFIX = "/_matrix/media/v1"

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import prepare_database from synapse.storage import prepare_database, UpgradeDatabaseException
from synapse.server import HomeServer from synapse.server import HomeServer
@ -24,12 +24,13 @@ from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.http.content_repository import ContentRepoResource from synapse.media.v0.content_repository import ContentRepoResource
from synapse.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -69,6 +70,9 @@ class SynapseHomeServer(HomeServer):
self, self.upload_dir, self.auth, self.content_addr self, self.upload_dir, self.auth, self.content_addr
) )
def build_resource_for_media_repository(self):
return MediaRepositoryResource(self)
def build_resource_for_server_key(self): def build_resource_for_server_key(self):
return LocalKey(self) return LocalKey(self)
@ -99,6 +103,7 @@ class SynapseHomeServer(HomeServer):
(FEDERATION_PREFIX, self.get_resource_for_federation()), (FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
] ]
if web_client: if web_client:
logger.info("Adding the web client.") logger.info("Adding the web client.")
@ -223,8 +228,15 @@ def setup():
logger.info("Preparing database: %s...", db_name) logger.info("Preparing database: %s...", db_name)
with sqlite3.connect(db_name) as db_conn: try:
prepare_database(db_conn) with sqlite3.connect(db_name) as db_conn:
prepare_database(db_conn)
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
"Have you followed any instructions in UPGRADES.rst?\n"
)
sys.exit(1)
logger.info("Database prepared in %s.", db_name) logger.info("Database prepared in %s.", db_name)

View File

@ -50,12 +50,26 @@ class Config(object):
) )
return cls.abspath(file_path) return cls.abspath(file_path)
@staticmethod
def ensure_directory(dir_path):
if not os.path.exists(dir_path):
os.makedirs(dir_path)
if not os.path.isdir(dir_path):
raise ConfigError(
"%s is not a directory" % (dir_path,)
)
return dir_path
@classmethod @classmethod
def read_file(cls, file_path, config_name): def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name) cls.check_file(file_path, config_name)
with open(file_path) as file_stream: with open(file_path) as file_stream:
return file_stream.read() return file_stream.read()
@staticmethod
def default_path(name):
return os.path.abspath(os.path.join(os.path.curdir, name))
@staticmethod @staticmethod
def read_config_file(file_path): def read_config_file(file_path):
with open(file_path) as file_stream: with open(file_path) as file_stream:

View File

@ -36,7 +36,7 @@ class LoggingConfig(Config):
help="The verbosity level." help="The verbosity level."
) )
logging_group.add_argument( logging_group.add_argument(
'-f', '--log-file', dest="log_file", default=None, '-f', '--log-file', dest="log_file", default="homeserver.log",
help="File to log to." help="File to log to."
) )
logging_group.add_argument( logging_group.add_argument(

View File

@ -20,6 +20,8 @@ class ContentRepositoryConfig(Config):
def __init__(self, args): def __init__(self, args):
super(ContentRepositoryConfig, self).__init__(args) super(ContentRepositoryConfig, self).__init__(args)
self.max_upload_size = self.parse_size(args.max_upload_size) self.max_upload_size = self.parse_size(args.max_upload_size)
self.max_image_pixels = self.parse_size(args.max_image_pixels)
self.media_store_path = self.ensure_directory(args.media_store_path)
def parse_size(self, string): def parse_size(self, string):
sizes = {"K": 1024, "M": 1024 * 1024} sizes = {"K": 1024, "M": 1024 * 1024}
@ -37,3 +39,10 @@ class ContentRepositoryConfig(Config):
db_group.add_argument( db_group.add_argument(
"--max-upload-size", default="1M" "--max-upload-size", default="1M"
) )
db_group.add_argument(
"--media-store-path", default=cls.default_path("media_store")
)
db_group.add_argument(
"--max-image-pixels", default="32M",
help="Maximum number of pixels that will be thumbnailed"
)

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from synapse.api.events.utils import prune_event from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64 from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json from syutil.crypto.jsonsign import sign_json
@ -29,17 +29,17 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256): def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm) name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest())) logger.debug("Expecting hash: %s", encode_base64(expected_hash))
if computed_hash.name not in event.hashes: if name not in event.hashes:
raise SynapseError( raise SynapseError(
400, 400,
"Algorithm %s not in hashes %s" % ( "Algorithm %s not in hashes %s" % (
computed_hash.name, list(event.hashes), name, list(event.hashes),
), ),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
message_hash_base64 = event.hashes[computed_hash.name] message_hash_base64 = event.hashes[name]
try: try:
message_hash_bytes = decode_base64(message_hash_base64) message_hash_bytes = decode_base64(message_hash_base64)
except: except:
@ -48,10 +48,10 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"Invalid base64: %s" % (message_hash_base64,), "Invalid base64: %s" % (message_hash_base64,),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
return message_hash_bytes == computed_hash.digest() return message_hash_bytes == expected_hash
def _compute_content_hash(event, hash_algorithm): def compute_content_hash(event, hash_algorithm):
event_json = event.get_pdu_json() event_json = event.get_pdu_json()
event_json.pop("age_ts", None) event_json.pop("age_ts", None)
event_json.pop("unsigned", None) event_json.pop("unsigned", None)
@ -59,8 +59,11 @@ def _compute_content_hash(event, hash_algorithm):
event_json.pop("hashes", None) event_json.pop("hashes", None)
event_json.pop("outlier", None) event_json.pop("outlier", None)
event_json.pop("destinations", None) event_json.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_json) event_json_bytes = encode_canonical_json(event_json)
return hash_algorithm(event_json_bytes)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
@ -79,27 +82,28 @@ def compute_event_signature(event, signature_name, signing_key):
redact_json = tmp_event.get_pdu_json() redact_json = tmp_event.get_pdu_json()
redact_json.pop("age_ts", None) redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None) redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", redact_json) logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key) redact_json = sign_json(redact_json, signature_name, signing_key)
logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"] return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key, def add_hashes_and_signatures(event, signature_name, signing_key,
hash_algorithm=hashlib.sha256): hash_algorithm=hashlib.sha256):
if hasattr(event, "old_state_events"): # if hasattr(event, "old_state_events"):
state_json_bytes = encode_canonical_json( # state_json_bytes = encode_canonical_json(
[e.event_id for e in event.old_state_events.values()] # [e.event_id for e in event.old_state_events.values()]
) # )
hashed = hash_algorithm(state_json_bytes) # hashed = hash_algorithm(state_json_bytes)
event.state_hash = { # event.state_hash = {
hashed.name: encode_base64(hashed.digest()) # hashed.name: encode_base64(hashed.digest())
} # }
hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm) name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
if not hasattr(event, "hashes"): if not hasattr(event, "hashes"):
event.hashes = {} event.hashes = {}
event.hashes[hashed.name] = encode_base64(hashed.digest()) event.hashes[name] = encode_base64(digest)
event.signatures = compute_event_signature( event.signatures = compute_event_signature(
event, event,

149
synapse/events/__init__.py Normal file
View File

@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.frozenutils import freeze, unfreeze
import copy
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = copy.deepcopy(internal_metadata_dict)
def get_dict(self):
return dict(self.__dict__)
def is_outlier(self):
return hasattr(self, "outlier") and self.outlier
def _event_dict_property(key):
def getter(self):
return self._event_dict[key]
def setter(self, v):
self._event_dict[key] = v
def delete(self):
del self._event_dict[key]
return property(
getter,
setter,
delete,
)
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}):
self.signatures = copy.deepcopy(signatures)
self.unsigned = copy.deepcopy(unsigned)
self._event_dict = copy.deepcopy(event_dict)
self.internal_metadata = _EventInternalMetadata(
internal_metadata_dict
)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
prev_events = _event_dict_property("prev_events")
prev_state = _event_dict_property("prev_state")
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
def membership(self):
return self.content["membership"]
def is_state(self):
return hasattr(self, "state_key")
def get_dict(self):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": self.unsigned,
})
return d
def get(self, key, default):
return self._event_dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None):
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
age = time_now - pdu_json["unsigned"]["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"]
return pdu_json
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}):
event_dict = copy.deepcopy(event_dict)
signatures = copy.deepcopy(event_dict.pop("signatures", {}))
unsigned = copy.deepcopy(event_dict.pop("unsigned", {}))
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
)
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def get_dict(self):
# We need to unfreeze what we return
return unfreeze(super(FrozenEvent, self).get_dict())
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
self.event_id, self.type, self.get("state_key", None),
)

77
synapse/events/builder.py Normal file
View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from synapse.types import EventID
from synapse.util.stringutils import random_string
import copy
class EventBuilder(EventBase):
def __init__(self, key_values={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned
)
def update_event_key(self, key, value):
self._event_dict[key] = value
def update_event_keys(self, other_dict):
self._event_dict.update(other_dict)
def build(self):
return FrozenEvent.from_event(self)
class EventBuilderFactory(object):
def __init__(self, clock, hostname):
self.clock = clock
self.hostname = hostname
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create(local_part, self.hostname)
return e_id.to_string()
def new(self, key_values={}):
key_values["event_id"] = self.create_event_id()
time_now = int(self.clock.time_msec())
key_values.setdefault("origin", self.hostname)
key_values.setdefault("origin_server_ts", time_now)
key_values.setdefault("unsigned", {})
age = key_values["unsigned"].pop("age", 0)
key_values["unsigned"].setdefault("age_ts", time_now - age)
key_values["signatures"] = {}
return EventBuilder(key_values=key_values,)

View File

@ -13,3 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
class EventContext(object):
def __init__(self, current_state=None, auth_events=None):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None

View File

@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .room import ( from synapse.api.constants import EventTypes
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent, from . import EventBase
RoomAliasesEvent, RoomCreateEvent,
)
def prune_event(event): def prune_event(event):
@ -31,7 +29,7 @@ def prune_event(event):
allowed_keys = [ allowed_keys = [
"event_id", "event_id",
"user_id", "sender",
"room_id", "room_id",
"hashes", "hashes",
"signatures", "signatures",
@ -44,6 +42,7 @@ def prune_event(event):
"auth_events", "auth_events",
"origin", "origin",
"origin_server_ts", "origin_server_ts",
"membership",
] ]
new_content = {} new_content = {}
@ -53,13 +52,13 @@ def prune_event(event):
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event.content[field]
if event_type == RoomMemberEvent.TYPE: if event_type == EventTypes.Member:
add_fields("membership") add_fields("membership")
elif event_type == RoomCreateEvent.TYPE: elif event_type == EventTypes.Create:
add_fields("creator") add_fields("creator")
elif event_type == RoomJoinRulesEvent.TYPE: elif event_type == EventTypes.JoinRules:
add_fields("join_rule") add_fields("join_rule")
elif event_type == RoomPowerLevelsEvent.TYPE: elif event_type == EventTypes.PowerLevels:
add_fields( add_fields(
"users", "users",
"users_default", "users_default",
@ -71,15 +70,64 @@ def prune_event(event):
"kick", "kick",
"redact", "redact",
) )
elif event_type == RoomAliasesEvent.TYPE: elif event_type == EventTypes.Aliases:
add_fields("aliases") add_fields("aliases")
allowed_fields = { allowed_fields = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_dict().items()
if k in allowed_keys if k in allowed_keys
} }
allowed_fields["content"] = new_content allowed_fields["content"] = new_content
return type(event)(**allowed_fields) allowed_fields["unsigned"] = {}
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(allowed_fields)
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"]
)
del d["unsigned"]["redacted_because"]
if "redacted_by" in e.unsigned:
d["redacted_by"] = e.unsigned["redacted_by"]
del d["unsigned"]["redacted_by"]
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
d.pop("depth", None)
d.pop("unsigned", None)
d.pop("origin", None)
return d

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
class EventValidator(object):
def validate(self, event):
EventID.from_string(event.event_id)
RoomID.from_string(event.room_id)
required = [
# "auth_events",
"content",
# "hashes",
"origin",
# "prev_events",
"sender",
"type",
]
for k in required:
if not hasattr(event, k):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
strings = [
"origin",
"sender",
"type",
]
if hasattr(event, "state_key"):
strings.append("state_key")
for s in strings:
if not isinstance(getattr(event, s), basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Check that the following keys have dictionary values
# TODO
# Check that the following keys have the correct format for DAGs
# TODO
def validate_new(self, event):
self.validate(event)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
strings = [
"body",
"msgtype",
]
self._ensure_strings(event.content, strings)
elif event.type == EventTypes.Topic:
self._ensure_strings(event.content, ["topic"])
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))

View File

@ -25,6 +25,7 @@ from .persistence import TransactionActions
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging import logging
@ -73,7 +74,7 @@ class ReplicationLayer(object):
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.event_factory = hs.get_event_factory() self.event_builder_factory = hs.get_event_builder_factory()
def set_handler(self, handler): def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate """Sets the handler that the replication layer will use to communicate
@ -112,7 +113,7 @@ class ReplicationLayer(object):
self.query_handlers[query_type] = handler self.query_handlers[query_type] = handler
@log_function @log_function
def send_pdu(self, pdu): def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the """Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others. home server that should be transmitted to others.
@ -131,7 +132,7 @@ class ReplicationLayer(object):
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc. # TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order) self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug( logger.debug(
"[%s] transaction_layer.enqueue_pdu... done", "[%s] transaction_layer.enqueue_pdu... done",
@ -255,31 +256,35 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_state_for_context(self, destination, context, event_id=None): def get_state_for_context(self, destination, context, event_id):
"""Requests all of the `current` state PDUs for a given context from """Requests all of the `current` state PDUs for a given context from
a remote home server. a remote home server.
Args: Args:
destination (str): The remote homeserver to query for the state. destination (str): The remote homeserver to query for the state.
context (str): The context we're interested in. context (str): The context we're interested in.
event_id (str): The id of the event we want the state at.
Returns: Returns:
Deferred: Results in a list of PDUs. Deferred: Results in a list of PDUs.
""" """
transaction_data = yield self.transport_layer.get_context_state( result = yield self.transport_layer.get_context_state(
destination, destination,
context, context,
event_id=event_id, event_id=event_id,
) )
transaction = Transaction(**transaction_data)
pdus = [ pdus = [
self.event_from_pdu_json(p, outlier=True) self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
for p in transaction.pdus
] ]
defer.returnValue(pdus) auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -334,7 +339,7 @@ class ReplicationLayer(object):
defer.returnValue(response) defer.returnValue(response)
return return
logger.debug("[%s] Transacition is new", transaction.transaction_id) logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext(): with PreserveLoggingContext():
dl = [] dl = []
@ -382,10 +387,16 @@ class ReplicationLayer(object):
context, context,
event_id, event_id,
) )
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else: else:
raise NotImplementedError("Specify an event") raise NotImplementedError("Specify an event")
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -438,7 +449,9 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content) pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu) res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue((200, { defer.returnValue((200, {
@ -557,13 +570,21 @@ class ReplicationLayer(object):
origin, pdu.event_id, do_auth=False origin, pdu.event_id, do_auth=False
) )
if existing and (not existing.outlier or pdu.outlier): already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id) logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({}) defer.returnValue({})
return return
state = None state = None
auth_chain = []
# We need to make sure we have all the auth events. # We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events: # for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu( # exists = yield self._get_persisted_pdu(
@ -595,7 +616,7 @@ class ReplicationLayer(object):
# ) # )
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.outlier: if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context( min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id pdu.room_id
@ -636,7 +657,7 @@ class ReplicationLayer(object):
"_handle_new_pdu getting state for %s", "_handle_new_pdu getting state for %s",
pdu.room_id pdu.room_id
) )
state = yield self.get_state_for_context( state, auth_chain = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id, origin, pdu.room_id, pdu.event_id,
) )
@ -646,6 +667,7 @@ class ReplicationLayer(object):
pdu, pdu,
backfilled=backfilled, backfilled=backfilled,
state=state, state=state,
auth_chain=auth_chain,
) )
else: else:
ret = None ret = None
@ -658,19 +680,14 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False): def event_from_pdu_json(self, pdu_json, outlier=False):
#TODO: Check we have all the PDU keys here event = FrozenEvent(
pdu_json.setdefault("hashes", {}) pdu_json
pdu_json.setdefault("signatures", {})
sender = pdu_json.pop("sender", None)
if sender is not None:
pdu_json["user_id"] = sender
state_hash = pdu_json.get("unsigned", {}).pop("state_hash", None)
if state_hash is not None:
pdu_json["state_hash"] = state_hash
return self.event_factory.create_event(
pdu_json["type"], outlier=outlier, **pdu_json
) )
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object): class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at """This class makes sure we only have one transaction in flight at
@ -685,6 +702,7 @@ class _TransactionQueue(object):
self.transport_layer = transport_layer self.transport_layer = transport_layer
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.store = hs.get_datastore()
# Is a mapping from destinations -> deferreds. Used to keep track # Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are # of which destinations have transactions in flight and when they are
@ -705,15 +723,13 @@ class _TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def enqueue_pdu(self, pdu, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later. # table and we'll get back to it later.
destinations = set([ destinations = set(destinations)
d for d in pdu.destinations destinations.discard(self.server_name)
if d != self.server_name
])
logger.debug("Sending to: %s", str(destinations)) logger.debug("Sending to: %s", str(destinations))
@ -728,8 +744,14 @@ class _TransactionQueue(object):
(pdu, deferred, order) (pdu, deferred, order)
) )
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext(): with PreserveLoggingContext():
self._attempt_new_transaction(destination) self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred) deferreds.append(deferred)
@ -739,6 +761,9 @@ class _TransactionQueue(object):
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred() deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append( self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred) (edu, deferred)
@ -748,7 +773,7 @@ class _TransactionQueue(object):
if not deferred.called: if not deferred.called:
deferred.errback(failure) deferred.errback(failure)
else: else:
logger.exception("Failed to send edu", failure) logger.warn("Failed to send edu", failure)
with PreserveLoggingContext(): with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb) self._attempt_new_transaction(destination).addErrback(eb)
@ -770,10 +795,33 @@ class _TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
return return
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
@ -781,7 +829,14 @@ class _TransactionQueue(object):
if not pending_pdus and not pending_edus and not pending_failures: if not pending_pdus and not pending_edus and not pending_failures:
return return
logger.debug("TX [%s] Attempting new transaction", destination) logger.debug(
"TX [%s] Attempting new transaction "
"(pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[2]) pending_pdus.sort(key=lambda t: t[2])
@ -814,7 +869,11 @@ class _TransactionQueue(object):
yield self.transaction_actions.prepare_to_send(transaction) yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination) logger.debug("TX [%s] Persisted transaction", destination)
logger.debug("TX [%s] Sending transaction...", destination) logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction # Actually send the transaction
@ -835,6 +894,8 @@ class _TransactionQueue(object):
transaction, json_data_cb transaction, json_data_cb
) )
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination) logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination) logger.debug("TX [%s] Marking as delivered...", destination)
@ -847,8 +908,14 @@ class _TransactionQueue(object):
for deferred in deferreds: for deferred in deferreds:
if code == 200: if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None) deferred.callback(None)
else: else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code)) deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that # Ensures we don't continue until all callbacks on that
@ -861,11 +928,15 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Yielded to callbacks", destination) logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e: except Exception as e:
logger.error("TX Problem in _attempt_transaction")
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
logger.exception(e) logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds: for deferred in deferreds:
if not deferred.called: if not deferred.called:
@ -877,3 +948,22 @@ class _TransactionQueue(object):
# Check to see if there is anything else to send. # Check to see if there is anything else to send.
self._attempt_new_transaction(destination) self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@ -155,7 +155,7 @@ class TransportLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_transaction(self, transaction, json_data_callback=None): def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to it's destination """ Sends the given Transaction to its destination
Args: Args:
transaction (Transaction) transaction (Transaction)

View File

@ -15,11 +15,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent from synapse.api.constants import Membership, EventTypes
from synapse.api.constants import Membership
import logging import logging
@ -31,10 +30,8 @@ class BaseHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
@ -44,6 +41,8 @@ class BaseHandler(object):
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
@ -57,62 +56,95 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[], def _create_new_client_event(self, builder):
extra_users=[], suppress_auth=False,
do_invite_host=None):
yield run_on_reactor() yield run_on_reactor()
snapshot.fill_out_prev_events(event) latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
logger.debug("Signing event...")
add_hashes_and_signatures(
event, self.server_name, self.signing_key
) )
logger.debug("Signed event.") if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures(
builder, self.server_name, self.signing_key
)
event = builder.build()
logger.debug(
"Created event %s with auth_events: %s, current state: %s",
event.event_id, context.auth_events, context.current_state,
)
defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth: if not suppress_auth:
logger.debug("Authing...") self.auth.check(event, auth_events=context.auth_events)
self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
if do_invite_host: yield self.store.persist_event(event, context=context)
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)
# FIXME: We need to check if the remote changed anything else federation_handler = self.hs.get_handlers().federation_handler
event.signatures = invite_event.signatures
yield self.store.persist_event(event) if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
invitee = self.hs.parse_userid(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
destinations = set(extra_destinations) destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room. for k, s in context.current_state.items():
for k, s in event.state_events.items():
try: try:
if k[0] == RoomMemberEvent.TYPE: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain self.hs.parse_userid(s.state_key).domain
) )
except: except SynapseError:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
event.destinations = list(destinations)
yield self.notifier.on_new_room_event(event, extra_users=extra_users) yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler yield federation_handler.handle_new_event(
yield federation_handler.handle_new_event(event, snapshot) event,
None,
destinations=destinations,
)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.events.room import RoomAliasesEvent from synapse.api.constants import EventTypes
import logging import logging
@ -40,7 +40,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Do auth. # TODO(erikj): Do auth.
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.
@ -64,7 +64,7 @@ class DirectoryHandler(BaseHandler):
def delete_association(self, user_id, room_alias): def delete_association(self, user_id, room_alias):
# TODO Check if server admin # TODO Check if server admin
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias) room_id = yield self.store.delete_room_alias(room_alias)
@ -75,7 +75,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
room_id = None room_id = None
if room_alias.is_mine: if self.hs.is_mine(room_alias):
result = yield self.store.get_association_from_room_alias( result = yield self.store.get_association_from_room_alias(
room_alias room_alias
) )
@ -123,7 +123,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_directory_query(self, args): def on_directory_query(self, args):
room_alias = self.hs.parse_roomalias(args["room_alias"]) room_alias = self.hs.parse_roomalias(args["room_alias"])
if not room_alias.is_mine: if not self.hs.is_mine(room_alias):
raise SynapseError( raise SynapseError(
400, "Room Alias is not hosted on this Home Server" 400, "Room Alias is not hosted on this Home Server"
) )
@ -148,16 +148,11 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
event = self.event_factory.create_event( msg_handler = self.hs.get_handlers().message_handler
etype=RoomAliasesEvent.TYPE, yield msg_handler.create_and_send_event({
state_key=self.hs.hostname, "type": EventTypes.Aliases,
room_id=room_id, "state_key": self.hs.hostname,
user_id=user_id, "room_id": room_id,
content={"aliases": aliases}, "sender": user_id,
) "content": {"aliases": aliases},
})
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True
)

View File

@ -17,12 +17,11 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, SynapseError, StoreError,
) )
from synapse.api.events.room import RoomMemberEvent, RoomCreateEvent from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Membership
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
@ -76,7 +75,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_event(self, event, snapshot): def handle_new_event(self, event, snapshot, destinations):
""" Takes in an event from the client to server side, that has already """ Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any been authed and handled by the state module, and sends it to any
remote home servers that may be interested. remote home servers that may be interested.
@ -92,16 +91,12 @@ class FederationHandler(BaseHandler):
yield run_on_reactor() yield run_on_reactor()
pdu = event self.replication_layer.send_pdu(event, destinations)
if not hasattr(pdu, "destinations") or not pdu.destinations:
pdu.destinations = []
yield self.replication_layer.send_pdu(pdu)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, backfilled, state=None): def on_receive_pdu(self, origin, pdu, backfilled, state=None,
auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
""" """
@ -140,7 +135,7 @@ class FederationHandler(BaseHandler):
if not check_event_content_hash(event): if not check_event_content_hash(event):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s, %s", "Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_full_dict()) event.event_id, encode_canonical_json(event.get_dict())
) )
event = redacted_event event = redacted_event
@ -153,43 +148,44 @@ class FederationHandler(BaseHandler):
event.room_id, event.room_id,
self.server_name self.server_name
) )
if not is_in_room and not event.outlier: if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication_layer = self.replication_layer replication = self.replication_layer
auth_chain = yield replication_layer.get_event_auth(
origin,
context=event.room_id,
event_id=event.event_id,
)
for e in auth_chain:
e.outlier = True
try:
yield self._handle_new_event(e, fetch_missing=False)
except:
logger.exception(
"Failed to parse auth event %s",
e.event_id,
)
if not state: if not state:
state = yield replication_layer.get_state_for_context( state, auth_chain = yield replication.get_state_for_context(
origin, context=event.room_id, event_id=event.event_id,
)
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin, origin,
context=event.room_id, context=event.room_id,
event_id=event.event_id, event_id=event.event_id,
) )
for e in auth_chain:
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e, fetch_auth_from=origin)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
current_state = state current_state = state
if state: if state:
for e in state: for e in state:
e.outlier = True logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) yield self._handle_new_event(e)
except: except:
logger.exception( logger.exception(
"Failed to parse state event %s", "Failed to handle state event %s",
e.event_id, e.event_id,
) )
@ -208,6 +204,13 @@ class FederationHandler(BaseHandler):
affected=event.event_id, affected=event.event_id,
) )
# if we're receiving valid events from an origin,
# it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin)
if (retry_timings and retry_timings.retry_last_ts):
self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)
if not room: if not room:
@ -222,7 +225,7 @@ class FederationHandler(BaseHandler):
if not backfilled: if not backfilled:
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
@ -231,7 +234,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users event, extra_users=extra_users
) )
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
@ -258,11 +261,15 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# FIXME (erikj): Not sure this actually works :/ # FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_event_with_state(event) context = yield self.state_handler.compute_event_context(event)
events.append(event) events.append((event, context))
yield self.store.persist_event(event, backfilled=True) yield self.store.persist_event(
event,
context=context,
backfilled=True
)
defer.returnValue(events) defer.returnValue(events)
@ -279,13 +286,11 @@ class FederationHandler(BaseHandler):
pdu=event pdu=event
) )
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, event_id): def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain(event_id) auth = yield self.store.get_auth_chain([event_id])
for event in auth: for event in auth:
event.signatures.update( event.signatures.update(
@ -325,42 +330,55 @@ class FederationHandler(BaseHandler):
event = pdu event = pdu
# We should assert some things. # We should assert some things.
assert(event.type == RoomMemberEvent.TYPE) # FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee) assert(event.user_id == joinee)
assert(event.state_key == joinee) assert(event.state_key == joinee)
assert(event.room_id == room_id) assert(event.room_id == room_id)
event.outlier = False event.internal_metadata.outlier = False
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
event.get_pdu_json()
)
handled_events = set()
try: try:
event.event_id = self.event_factory.create_event_id() builder.event_id = self.event_builder_factory.create_event_id()
event.origin = self.hs.hostname builder.origin = self.hs.hostname
event.content = content builder.content = content
if not hasattr(event, "signatures"): if not hasattr(event, "signatures"):
event.signatures = {} builder.signatures = {}
add_hashes_and_signatures( add_hashes_and_signatures(
event, builder,
self.hs.hostname, self.hs.hostname,
self.hs.config.signing_key[0], self.hs.config.signing_key[0],
) )
new_event = builder.build()
ret = yield self.replication_layer.send_join( ret = yield self.replication_layer.send_join(
target_host, target_host,
event new_event
) )
state = ret["state"] state = ret["state"]
auth_chain = ret["auth_chain"] auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
handled_events.update([s.event_id for s in state])
handled_events.update([a.event_id for a in auth_chain])
handled_events.add(new_event.event_id)
logger.debug("do_invite_join auth_chain: %s", auth_chain) logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state) logger.debug("do_invite_join state: %s", state)
logger.debug("do_invite_join event: %s", event) logger.debug("do_invite_join event: %s", new_event)
try: try:
yield self.store.store_room( yield self.store.store_room(
@ -373,37 +391,36 @@ class FederationHandler(BaseHandler):
pass pass
for e in auth_chain: for e in auth_chain:
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e, fetch_missing=False) yield self._handle_new_event(e)
except: except:
logger.exception( logger.exception(
"Failed to parse auth event %s", "Failed to handle auth event %s",
e.event_id, e.event_id,
) )
for e in state: for e in state:
# FIXME: Auth these. # FIXME: Auth these.
e.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event( yield self._handle_new_event(
e, e, fetch_auth_from=target_host
fetch_missing=True
) )
except: except:
logger.exception( logger.exception(
"Failed to parse state event %s", "Failed to handle state event %s",
e.event_id, e.event_id,
) )
yield self._handle_new_event( yield self._handle_new_event(
event, new_event,
state=state, state=state,
current_state=state, current_state=state,
) )
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
event, extra_users=[joinee] new_event, extra_users=[joinee]
) )
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
@ -412,6 +429,9 @@ class FederationHandler(BaseHandler):
del self.room_queues[room_id] del self.room_queues[room_id]
for p, origin in room_queue: for p, origin in room_queue:
if p.event_id in handled_events:
continue
try: try:
self.on_receive_pdu(origin, p, backfilled=False) self.on_receive_pdu(origin, p, backfilled=False)
except: except:
@ -421,25 +441,24 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, context, user_id): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
content={"membership": Membership.JOIN}, "content": {"membership": Membership.JOIN},
room_id=context, "room_id": room_id,
user_id=user_id, "sender": user_id,
state_key=user_id, "state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
) )
snapshot = yield self.store.snapshot_room(event) self.auth.check(event, auth_events=context.auth_events)
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, auth_events=event.old_state_events)
pdu = event pdu = event
@ -453,12 +472,24 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
event.outlier = False logger.debug(
"on_send_join_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
yield self._handle_new_event(event) event.internal_metadata.outlier = False
context = yield self._handle_new_event(event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = [] extra_users = []
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
@ -467,7 +498,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users event, extra_users=extra_users
) )
if event.type == RoomMemberEvent.TYPE: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
@ -478,9 +509,9 @@ class FederationHandler(BaseHandler):
destinations = set() destinations = set()
for k, s in event.state_events.items(): for k, s in context.current_state.items():
try: try:
if k[0] == RoomMemberEvent.TYPE: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain self.hs.parse_userid(s.state_key).domain
@ -490,14 +521,21 @@ class FederationHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
new_pdu.destinations = list(destinations) logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
yield self.replication_layer.send_pdu(new_pdu) self.replication_layer.send_pdu(new_pdu, destinations)
auth_chain = yield self.store.get_auth_chain(event.event_id) state_ids = [e.event_id for e in context.current_state.values()]
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
defer.returnValue({ defer.returnValue({
"state": event.state_events.values(), "state": context.current_state.values(),
"auth_chain": auth_chain, "auth_chain": auth_chain,
}) })
@ -509,7 +547,7 @@ class FederationHandler(BaseHandler):
""" """
event = pdu event = pdu
event.outlier = True event.internal_metadata.outlier = True
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
@ -519,10 +557,11 @@ class FederationHandler(BaseHandler):
) )
) )
yield self.state_handler.annotate_event_with_state(event) context = yield self.state_handler.compute_event_context(event)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context,
backfilled=False, backfilled=False,
) )
@ -552,13 +591,13 @@ class FederationHandler(BaseHandler):
} }
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"): if event and event.is_state():
# Get previous state # Get previous state
if hasattr(event, "replaces_state") and event.replaces_state: if "replaces_state" in event.unsigned:
prev_event = yield self.store.get_event( prev_id = event.unsigned["replaces_state"]
event.replaces_state if prev_id != event.event_id:
) prev_event = yield self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event results[(event.type, event.state_key)] = prev_event
else: else:
del results[(event.type, event.state_key)] del results[(event.type, event.state_key)]
@ -643,75 +682,88 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True): current_state=None, fetch_auth_from=None):
is_new_state = yield self.state_handler.annotate_event_with_state(
event, logger.debug(
old_state=state "_handle_new_event: Before annotate: %s, sigs: %s",
event.event_id, event.signatures,
) )
if event.old_state_events: context = yield self.state_handler.compute_event_context(
known_ids = set( event, old_state=state
[s.event_id for s in event.old_state_events.values()] )
)
for e_id, _ in event.auth_events: logger.debug(
if e_id not in known_ids: "_handle_new_event: Before auth fetch: %s, sigs: %s",
e = yield self.store.get_event( event.event_id, event.signatures,
e_id, )
allow_none=True,
is_new_state = not event.internal_metadata.is_outlier()
known_ids = set(
[s.event_id for s in context.auth_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(e_id, allow_none=True)
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
) )
for auth_event in auth_chain:
if not e: yield self._handle_new_event(auth_event)
# TODO: Do some conflict res to make sure that we're e = yield self.store.get_event(e_id, allow_none=True)
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs?
auth_events = {}
for e_id, _ in event.auth_events:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
if not e: if not e:
e = yield self.replication_layer.get_pdu( # TODO: Do some conflict res to make sure that we're
event.origin, e_id, outlier=True # not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
) )
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Cannot find auth event")
if e and fetch_missing: context.auth_events[(e.type, e.state_key)] = e
try:
yield self.on_receive_pdu(event.origin, e, False)
except:
logger.exception(
"Failed to parse auth event %s",
e_id,
)
if not e: logger.debug(
logger.warn("Can't find auth event %s.", e_id) "_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
auth_events[(e.type, e.state_key)] = e if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
context.auth_events[(c.type, c.state_key)] = c
if event.type == RoomMemberEvent.TYPE and not event.auth_events: logger.debug(
if len(event.prev_events) == 1: "_handle_new_event: Before auth check: %s, sigs: %s",
c = yield self.store.get_event(event.prev_events[0][0]) event.event_id, event.signatures,
if c.type == RoomCreateEvent.TYPE: )
auth_events[(c.type, c.state_key)] = c
self.auth.check(event, auth_events=auth_events) self.auth.check(event, auth_events=context.auth_events)
logger.debug(
"_handle_new_event: Before persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.store.persist_event( yield self.store.persist_event(
event, event,
context=context,
backfilled=backfilled, backfilled=backfilled,
is_new_state=(is_new_state and not backfilled), is_new_state=(is_new_state and not backfilled),
current_state=current_state, current_state=current_state,
) )
logger.debug(
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
defer.returnValue(context)

View File

@ -15,10 +15,11 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext from synapse.events.validator import EventValidator
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
@ -32,7 +33,7 @@ class MessageHandler(BaseHandler):
super(MessageHandler, self).__init__(hs) super(MessageHandler, self).__init__(hs)
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_factory = hs.get_event_factory() self.validator = EventValidator()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None, def get_message(self, msg_id=None, room_id=None, sender_id=None,
@ -63,35 +64,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def send_message(self, event=None, suppress_auth=False):
""" Send a message.
Args:
event : The message event to store.
suppress_auth (bool) : True to suppress auth for this message. This
is primarily so the home server can inject messages into rooms at
will.
Raises:
SynapseError if something went wrong.
"""
self.ratelimit(event.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, suppress_auth=suppress_auth
)
with PreserveLoggingContext():
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
feedback=False): feedback=False):
@ -134,19 +106,53 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_room_data(self, event=None): def create_and_send_event(self, event_dict):
""" Stores data for a room. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args: Args:
event : The room path event event_dict (dict): An entire event
stamp_event (bool) : True to stamp event content with server keys.
Raises:
SynapseError if something went wrong.
""" """
builder = self.event_builder_factory.new(event_dict)
snapshot = yield self.store.snapshot_room(event) self.validator.validate_new(builder)
yield self._on_new_room_event(event, snapshot) self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
joinee = self.hs.parse_userid(builder.state_key)
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data",
joinee,
builder.content
)
event, context = yield self._create_new_client_event(
builder=builder,
)
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
yield self.handle_new_client_event(
event=event,
context=context,
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
@ -180,13 +186,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(fb) defer.returnValue(fb)
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event)
# store message in db
yield self._on_new_room_event(event, snapshot)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. """Retrieve all state events for a given room.

View File

@ -147,7 +147,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def is_presence_visible(self, observer_user, observed_user): def is_presence_visible(self, observer_user, observed_user):
assert(observed_user.is_mine) assert(self.hs.is_mine(observed_user))
if observer_user == observed_user: if observer_user == observed_user:
defer.returnValue(True) defer.returnValue(True)
@ -165,7 +165,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False): def get_state(self, target_user, auth_user, as_event=False):
if target_user.is_mine: if self.hs.is_mine(target_user):
visible = yield self.is_presence_visible( visible = yield self.is_presence_visible(
observer_user=auth_user, observer_user=auth_user,
observed_user=target_user observed_user=target_user
@ -212,7 +212,7 @@ class PresenceHandler(BaseHandler):
# TODO (erikj): Turn this back on. Why did we end up sending EDUs # TODO (erikj): Turn this back on. Why did we end up sending EDUs
# everywhere? # everywhere?
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != auth_user:
@ -291,7 +291,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_joined_room(self, user, room_id): def user_joined_room(self, user, room_id):
if user.is_mine: if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user) statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the # No actual update but we need to bump the serial anyway for the
@ -309,7 +309,7 @@ class PresenceHandler(BaseHandler):
rm_handler = self.homeserver.get_handlers().room_member_handler rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id) curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if c.is_mine]: for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
self.push_update_to_local_and_remote( self.push_update_to_local_and_remote(
observed_user=local_user, observed_user=local_user,
users_to_push=[user], users_to_push=[user],
@ -318,14 +318,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, observer_user, observed_user): def send_invite(self, observer_user, observed_user):
if not observer_user.is_mine: if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.add_presence_list_pending( yield self.store.add_presence_list_pending(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
if observed_user.is_mine: if self.hs.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user) yield self.invite_presence(observed_user, observer_user)
else: else:
yield self.federation.send_edu( yield self.federation.send_edu(
@ -339,7 +339,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _should_accept_invite(self, observed_user, observer_user): def _should_accept_invite(self, observed_user, observer_user):
if not observed_user.is_mine: if not self.hs.is_mine(observed_user):
defer.returnValue(False) defer.returnValue(False)
row = yield self.store.has_presence_state(observed_user.localpart) row = yield self.store.has_presence_state(observed_user.localpart)
@ -359,7 +359,7 @@ class PresenceHandler(BaseHandler):
observed_user.localpart, observer_user.to_string() observed_user.localpart, observer_user.to_string()
) )
if observer_user.is_mine: if self.hs.is_mine(observer_user):
if accept: if accept:
yield self.accept_presence(observed_user, observer_user) yield self.accept_presence(observed_user, observer_user)
else: else:
@ -396,7 +396,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def drop(self, observed_user, observer_user): def drop(self, observed_user, observer_user):
if not observer_user.is_mine: if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list( yield self.store.del_presence_list(
@ -410,7 +410,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
if not observer_user.is_mine: if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
presence = yield self.store.get_presence_list( presence = yield self.store.get_presence_list(
@ -465,7 +465,7 @@ class PresenceHandler(BaseHandler):
) )
for target_user in target_users: for target_user in target_users:
if target_user.is_mine: if self.hs.is_mine(target_user):
self._start_polling_local(user, target_user) self._start_polling_local(user, target_user)
# We want to tell the person that just came online # We want to tell the person that just came online
@ -477,7 +477,7 @@ class PresenceHandler(BaseHandler):
) )
deferreds = [] deferreds = []
remote_users = [u for u in target_users if not u.is_mine] remote_users = [u for u in target_users if not self.hs.is_mine(u)]
remoteusers_by_domain = partition(remote_users, lambda u: u.domain) remoteusers_by_domain = partition(remote_users, lambda u: u.domain)
# Only poll for people in our get_presence_list # Only poll for people in our get_presence_list
for domain in remoteusers_by_domain: for domain in remoteusers_by_domain:
@ -520,7 +520,7 @@ class PresenceHandler(BaseHandler):
def stop_polling_presence(self, user, target_user=None): def stop_polling_presence(self, user, target_user=None):
logger.debug("Stop polling for presence from %s", user) logger.debug("Stop polling for presence from %s", user)
if not target_user or target_user.is_mine: if not target_user or self.hs.is_mine(target_user):
self._stop_polling_local(user, target_user=target_user) self._stop_polling_local(user, target_user=target_user)
deferreds = [] deferreds = []
@ -579,7 +579,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def push_presence(self, user, statuscache): def push_presence(self, user, statuscache):
assert(user.is_mine) assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user) logger.debug("Pushing presence update from %s", user)
@ -659,10 +659,6 @@ class PresenceHandler(BaseHandler):
if room_ids: if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids) logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
continue
state = dict(push) state = dict(push)
del state["user_id"] del state["user_id"]
@ -683,6 +679,10 @@ class PresenceHandler(BaseHandler):
self._user_cachemap_latest_serial += 1 self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial) statuscache.update(state, serial=self._user_cachemap_latest_serial)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
continue
self.push_update_to_clients( self.push_update_to_clients(
observed_user=user, observed_user=user,
users_to_push=observers, users_to_push=observers,
@ -696,7 +696,7 @@ class PresenceHandler(BaseHandler):
for poll in content.get("poll", []): for poll in content.get("poll", []):
user = self.hs.parse_userid(poll) user = self.hs.parse_userid(poll)
if not user.is_mine: if not self.hs.is_mine(user):
continue continue
# TODO(paul) permissions checks # TODO(paul) permissions checks
@ -711,7 +711,7 @@ class PresenceHandler(BaseHandler):
for unpoll in content.get("unpoll", []): for unpoll in content.get("unpoll", []):
user = self.hs.parse_userid(unpoll) user = self.hs.parse_userid(unpoll)
if not user.is_mine: if not self.hs.is_mine(user):
continue continue
if user in self._remote_sendmap: if user in self._remote_sendmap:
@ -730,7 +730,7 @@ class PresenceHandler(BaseHandler):
localusers, remoteusers = partitionbool( localusers, remoteusers = partitionbool(
users_to_push, users_to_push,
lambda u: u.is_mine lambda u: self.hs.is_mine(u)
) )
localusers = set(localusers) localusers = set(localusers)
@ -788,7 +788,7 @@ class PresenceEventSource(object):
[u.to_string() for u in observer_user, observed_user])): [u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True) defer.returnValue(True)
if observed_user.is_mine: if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap pushmap = presence._local_pushmap
defer.returnValue( defer.returnValue(
@ -804,6 +804,7 @@ class PresenceEventSource(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
@ -816,7 +817,8 @@ class PresenceEventSource(object):
# TODO(paul): use a DeferredList ? How to limit concurrency. # TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys(): for observed_user in cachemap.keys():
cached = cachemap[observed_user] cached = cachemap[observed_user]
if not (from_key < cached.serial):
if cached.serial <= from_key:
continue continue
if (yield self.is_visible(observer_user, observed_user)): if (yield self.is_visible(observer_user, observed_user)):

View File

@ -51,7 +51,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_displayname(self, target_user): def get_displayname(self, target_user):
if target_user.is_mine: if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname( displayname = yield self.store.get_profile_displayname(
target_user.localpart target_user.localpart
) )
@ -81,7 +81,7 @@ class ProfileHandler(BaseHandler):
def set_displayname(self, target_user, auth_user, new_displayname): def set_displayname(self, target_user, auth_user, new_displayname):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != auth_user:
@ -101,7 +101,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
if target_user.is_mine: if self.hs.is_mine(target_user):
avatar_url = yield self.store.get_profile_avatar_url( avatar_url = yield self.store.get_profile_avatar_url(
target_user.localpart target_user.localpart
) )
@ -130,7 +130,7 @@ class ProfileHandler(BaseHandler):
def set_avatar_url(self, target_user, auth_user, new_avatar_url): def set_avatar_url(self, target_user, auth_user, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != auth_user:
@ -150,7 +150,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def collect_presencelike_data(self, user, state): def collect_presencelike_data(self, user, state):
if not user.is_mine: if not self.hs.is_mine(user):
defer.returnValue(None) defer.returnValue(None)
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -170,7 +170,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_profile_query(self, args): def on_profile_query(self, args):
user = self.hs.parse_userid(args["user_id"]) user = self.hs.parse_userid(args["user_id"])
if not user.is_mine: if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
just_field = args.get("field", None) just_field = args.get("field", None)
@ -191,7 +191,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_join_states(self, user): def _update_join_states(self, user):
if not user.is_mine: if not self.hs.is_mine(user):
return return
joins = yield self.store.get_rooms_for_user_where_membership_is( joins = yield self.store.get_rooms_for_user_where_membership_is(
@ -200,8 +200,6 @@ class ProfileHandler(BaseHandler):
) )
for j in joins: for j in joins:
snapshot = yield self.store.snapshot_room(j)
content = { content = {
"membership": j.content["membership"], "membership": j.content["membership"],
} }
@ -210,14 +208,11 @@ class ProfileHandler(BaseHandler):
"collect_presencelike_data", user, content "collect_presencelike_data", user, content
) )
new_event = self.event_factory.create_event( msg_handler = self.hs.get_handlers().message_handler
etype=j.type, yield msg_handler.create_and_send_event({
room_id=j.room_id, "type": j.type,
state_key=j.state_key, "room_id": j.room_id,
content=content, "state_key": j.state_key,
user_id=j.state_key, "content": content,
) "sender": j.state_key,
})
yield self._on_new_room_event(
new_event, snapshot, suppress_auth=True
)

View File

@ -22,6 +22,7 @@ from synapse.api.errors import (
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
@ -54,12 +55,13 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
yield run_on_reactor()
password_hash = None password_hash = None
if password: if password:
password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart: if localpart:
user = UserID(localpart, self.hs.hostname, True) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self._generate_token(user_id) token = self._generate_token(user_id)
@ -78,7 +80,7 @@ class RegistrationHandler(BaseHandler):
while not user_id and not token: while not user_id and not token:
try: try:
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname, True) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self._generate_token(user_id) token = self._generate_token(user_id)

View File

@ -17,12 +17,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
)
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import BaseHandler from ._base import BaseHandler
@ -52,9 +48,9 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id) self.ratelimit(user_id)
if "room_alias_name" in config: if "room_alias_name" in config:
room_alias = RoomAlias.create_local( room_alias = RoomAlias.create(
config["room_alias_name"], config["room_alias_name"],
self.hs self.hs.hostname,
) )
mapping = yield self.store.get_association_from_room_alias( mapping = yield self.store.get_association_from_room_alias(
room_alias room_alias
@ -76,8 +72,8 @@ class RoomCreationHandler(BaseHandler):
if room_id: if room_id:
# Ensure room_id is the correct type # Ensure room_id is the correct type
room_id_obj = RoomID.from_string(room_id, self.hs) room_id_obj = RoomID.from_string(room_id)
if not room_id_obj.is_mine: if not self.hs.is_mine(room_id_obj):
raise SynapseError(400, "Room id must be local") raise SynapseError(400, "Room id must be local")
yield self.store.store_room( yield self.store.store_room(
@ -93,7 +89,10 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5: while attempts < 5:
try: try:
random_string = stringutils.random_string(18) random_string = stringutils.random_string(18)
gen_room_id = RoomID.create_local(random_string, self.hs) gen_room_id = RoomID.create(
random_string,
self.hs.hostname,
)
yield self.store.store_room( yield self.store.store_room(
room_id=gen_room_id.to_string(), room_id=gen_room_id.to_string(),
room_creator_user_id=user_id, room_creator_user_id=user_id,
@ -120,59 +119,37 @@ class RoomCreationHandler(BaseHandler):
user, room_id, is_public=is_public user, room_id, is_public=is_public
) )
room_member_handler = self.hs.get_handlers().room_member_handler msg_handler = self.hs.get_handlers().message_handler
@defer.inlineCallbacks
def handle_event(event):
snapshot = yield self.store.snapshot_room(event)
logger.debug("Event: %s", event)
if event.type == RoomMemberEvent.TYPE:
yield room_member_handler.change_membership(
event,
do_auth=True
)
else:
yield self._on_new_room_event(
event, snapshot, extra_users=[user], suppress_auth=True
)
for event in creation_events: for event in creation_events:
yield handle_event(event) yield msg_handler.create_and_send_event(event)
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
name_event = self.event_factory.create_event( yield msg_handler.create_and_send_event({
etype=RoomNameEvent.TYPE, "type": EventTypes.Name,
room_id=room_id, "room_id": room_id,
user_id=user_id, "sender": user_id,
content={"name": name}, "content": {"name": name},
) })
yield handle_event(name_event)
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
topic_event = self.event_factory.create_event( yield msg_handler.create_and_send_event({
etype=RoomTopicEvent.TYPE, "type": EventTypes.Topic,
room_id=room_id, "room_id": room_id,
user_id=user_id, "sender": user_id,
content={"topic": topic}, "content": {"topic": topic},
) })
yield handle_event(topic_event)
content = {"membership": Membership.INVITE}
for invitee in invite_list: for invitee in invite_list:
invite_event = self.event_factory.create_event( yield msg_handler.create_and_send_event({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
state_key=invitee, "state_key": invitee,
room_id=room_id, "room_id": room_id,
user_id=user_id, "sender": user_id,
content=content "content": {"membership": Membership.INVITE},
) })
yield handle_event(invite_event)
result = {"room_id": room_id} result = {"room_id": room_id}
@ -189,40 +166,44 @@ class RoomCreationHandler(BaseHandler):
event_keys = { event_keys = {
"room_id": room_id, "room_id": room_id,
"user_id": creator_id, "sender": creator_id,
"state_key": "",
} }
def create(etype, **content): def create(etype, content, **kwargs):
return self.event_factory.create_event( e = {
etype=etype, "type": etype,
content=content, "content": content,
**event_keys }
)
e.update(event_keys)
e.update(kwargs)
return e
creation_event = create( creation_event = create(
etype=RoomCreateEvent.TYPE, etype=EventTypes.Create,
creator=creator.to_string(), content={"creator": creator.to_string()},
) )
join_event = self.event_factory.create_event( join_event = create(
etype=RoomMemberEvent.TYPE, etype=EventTypes.Member,
state_key=creator_id, state_key=creator_id,
content={ content={
"membership": Membership.JOIN, "membership": Membership.JOIN,
}, },
**event_keys
) )
power_levels_event = self.event_factory.create_event( power_levels_event = create(
etype=RoomPowerLevelsEvent.TYPE, etype=EventTypes.PowerLevels,
content={ content={
"users": { "users": {
creator.to_string(): 100, creator.to_string(): 100,
}, },
"users_default": 0, "users_default": 0,
"events": { "events": {
RoomNameEvent.TYPE: 100, EventTypes.Name: 100,
RoomPowerLevelsEvent.TYPE: 100, EventTypes.PowerLevels: 100,
}, },
"events_default": 0, "events_default": 0,
"state_default": 50, "state_default": 50,
@ -230,13 +211,12 @@ class RoomCreationHandler(BaseHandler):
"kick": 50, "kick": 50,
"redact": 50 "redact": 50
}, },
**event_keys
) )
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
join_rules_event = create( join_rules_event = create(
etype=RoomJoinRulesEvent.TYPE, etype=EventTypes.JoinRules,
join_rule=join_rule, content={"join_rule": join_rule},
) )
return [ return [
@ -260,6 +240,7 @@ class RoomMemberHandler(BaseHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room") self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_members(self, room_id, membership=Membership.JOIN): def get_room_members(self, room_id, membership=Membership.JOIN):
@ -287,7 +268,7 @@ class RoomMemberHandler(BaseHandler):
if ignore_user is not None and member == ignore_user: if ignore_user is not None and member == ignore_user:
continue continue
if member.is_mine: if self.hs.is_mine(member):
if localusers is not None: if localusers is not None:
localusers.add(member) localusers.add(member)
else: else:
@ -348,7 +329,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event=None, do_auth=True): def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -358,11 +339,9 @@ class RoomMemberHandler(BaseHandler):
""" """
target_user_id = event.state_key target_user_id = event.state_key
snapshot = yield self.store.snapshot_room(event) prev_state = context.current_state.get(
(EventTypes.Member, target_user_id),
## TODO(markjh): get prev state from snapshot. None
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
) )
room_id = event.room_id room_id = event.room_id
@ -371,10 +350,11 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the # if this HS is not currently in the room, i.e. we have to do the
# invite/join dance. # invite/join dance.
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
yield self._do_join(event, snapshot, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. # This is not a JOIN, so we can handle it normally.
# FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP. # double same action, treat this event as a NOOP.
defer.returnValue({}) defer.returnValue({})
@ -383,10 +363,16 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot, context=context,
do_auth=do_auth, do_auth=do_auth,
) )
if prev_state and prev_state.membership == Membership.JOIN:
user = self.hs.parse_userid(event.user_id)
self.distributor.fire(
"user_left_room", user=user, room_id=event.room_id
)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
@ -404,33 +390,32 @@ class RoomMemberHandler(BaseHandler):
host = hosts[0] host = hosts[0]
content.update({"membership": Membership.JOIN}) # If event doesn't include a display name, add one.
new_event = self.event_factory.create_event( yield self.distributor.fire(
etype=RoomMemberEvent.TYPE, "collect_presencelike_data", joinee, content
state_key=joinee.to_string(),
room_id=room_id,
user_id=joinee.to_string(),
membership=Membership.JOIN,
content=content,
) )
snapshot = yield self.store.snapshot_room(new_event) content.update({"membership": Membership.JOIN})
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"state_key": joinee.to_string(),
"room_id": room_id,
"sender": joinee.to_string(),
"membership": Membership.JOIN,
"content": content,
})
event, context = yield self._create_new_client_event(builder)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) yield self._do_join(event, context, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, snapshot, room_host=None, do_auth=True): def _do_join(self, event, context, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.state_key) joinee = self.hs.parse_userid(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs) # room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data", joinee, event.content
)
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking # join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
@ -452,31 +437,29 @@ class RoomMemberHandler(BaseHandler):
) )
if prev_state and prev_state.membership == Membership.INVITE: if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id) inviter = UserID.from_string(prev_state.user_id)
inviter = UserID.from_string(
prev_state.user_id, self.hs
)
should_do_dance = not inviter.is_mine and not room should_do_dance = not self.hs.is_mine(inviter)
room_host = inviter.domain room_host = inviter.domain
else: else:
should_do_dance = False should_do_dance = False
have_joined = False
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
have_joined = yield handler.do_invite_join( yield handler.do_invite_join(
room_host, room_id, event.user_id, event.content, snapshot room_host,
room_id,
event.user_id,
event.get_dict()["content"], # FIXME To get a non-frozen dict
context
) )
else:
# We want to do the _do_update inside the room lock.
if not have_joined:
logger.debug("Doing normal join") logger.debug("Doing normal join")
yield self._do_local_membership_update( yield self._do_local_membership_update(
event, event,
membership=event.content["membership"], membership=event.content["membership"],
snapshot=snapshot, context=context,
do_auth=do_auth, do_auth=do_auth,
) )
@ -501,10 +484,10 @@ class RoomMemberHandler(BaseHandler):
if prev_state and prev_state.membership == Membership.INVITE: if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id) room = yield self.store.get_room(room_id)
inviter = UserID.from_string( inviter = UserID.from_string(
prev_state.sender, self.hs prev_state.sender
) )
is_remote_invite_join = not inviter.is_mine and not room is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain room_host = inviter.domain
else: else:
is_remote_invite_join = False is_remote_invite_join = False
@ -526,25 +509,17 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids) defer.returnValue(room_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, snapshot, def _do_local_membership_update(self, event, membership, context,
do_auth): do_auth):
yield run_on_reactor() yield run_on_reactor()
# If we're inviting someone, then we should also send it to that target_user = self.hs.parse_userid(event.state_key)
# HS.
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
if membership == Membership.INVITE and not target_user.is_mine:
do_invite_host = target_user.domain
else:
do_invite_host = None
yield self._on_new_room_event( yield self.handle_new_client_event(
event, event,
snapshot, context,
extra_users=[target_user], extra_users=[target_user],
suppress_auth=(not do_auth), suppress_auth=(not do_auth),
do_invite_host=do_invite_host,
) )

View File

@ -43,22 +43,50 @@ class TypingNotificationHandler(BaseHandler):
self.federation.register_edu_handler("m.typing", self._recv_edu) self.federation.register_edu_handler("m.typing", self._recv_edu)
self._member_typing_until = {} hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
self._member_typing_timer = {} # deferreds to manage theabove
# map room IDs to serial numbers
self._room_serials = {}
self._latest_room_serial = 0
# map room IDs to sets of users currently typing
self._room_typing = {}
def tearDown(self):
"""Cancels all the pending timers.
Normally this shouldn't be needed, but it's required from unit tests
to avoid a "Reactor was unclean" warning."""
for t in self._member_typing_timer.values():
self.clock.cancel_call_later(t)
@defer.inlineCallbacks @defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout): def started_typing(self, target_user, auth_user, room_id, timeout):
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != auth_user:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string())
logger.debug(
"%s has started typing in %s", target_user.to_string(), room_id
)
until = self.clock.time_msec() + timeout until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user=target_user)
was_present = member in self._member_typing_until was_present = member in self._member_typing_until
if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member])
self._member_typing_until[member] = until self._member_typing_until[member] = until
self._member_typing_timer[member] = self.clock.call_later(
timeout / 1000, lambda: self._stopped_typing(member)
)
if was_present: if was_present:
# No point sending another notification # No point sending another notification
@ -72,24 +100,45 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id): def stopped_typing(self, target_user, auth_user, room_id):
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != auth_user:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string())
logger.debug(
"%s has stopped typing in %s", target_user.to_string(), room_id
)
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user=target_user)
yield self._stopped_typing(member)
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
if self.hs.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
yield self._stopped_typing(member)
@defer.inlineCallbacks
def _stopped_typing(self, member):
if member not in self._member_typing_until: if member not in self._member_typing_until:
# No point # No point
defer.returnValue(None) defer.returnValue(None)
yield self._push_update( yield self._push_update(
room_id=room_id, room_id=member.room_id,
user=target_user, user=member.user,
typing=False, typing=False,
) )
del self._member_typing_until[member]
self.clock.cancel_call_later(self._member_typing_timer[member])
del self._member_typing_timer[member]
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user, typing): def _push_update(self, room_id, user, typing):
localusers = set() localusers = set()
@ -97,16 +146,14 @@ class TypingNotificationHandler(BaseHandler):
rm_handler = self.homeserver.get_handlers().room_member_handler rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into( yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers, remotedomains=remotedomains, room_id, localusers=localusers, remotedomains=remotedomains
ignore_user=user
) )
for u in localusers: if localusers:
self.push_update_to_clients( self._push_update_local(
room_id=room_id, room_id=room_id,
observer_user=u, user=user,
observed_user=user, typing=typing
typing=typing,
) )
deferreds = [] deferreds = []
@ -135,29 +182,67 @@ class TypingNotificationHandler(BaseHandler):
room_id, localusers=localusers room_id, localusers=localusers
) )
for u in localusers: if localusers:
self.push_update_to_clients( self._push_update_local(
room_id=room_id, room_id=room_id,
observer_user=u, user=user,
observed_user=user,
typing=content["typing"] typing=content["typing"]
) )
def push_update_to_clients(self, room_id, observer_user, observed_user, def _push_update_local(self, room_id, user, typing):
typing): if room_id not in self._room_serials:
# TODO(paul) steal this from presence.py self._room_serials[room_id] = 0
pass self._room_typing[room_id] = set()
room_set = self._room_typing[room_id]
if typing:
room_set.add(user)
elif user in room_set:
room_set.remove(user)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
self.notifier.on_new_user_event(rooms=[room_id])
class TypingNotificationEventSource(object): class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self._handler = None
def handler(self):
# Avoid cyclic dependency in handler setup
if not self._handler:
self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler
def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id]
return {
"type": "m.typing",
"room_id": room_id,
"content": {
"user_ids": [u.to_string() for u in typing],
},
}
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
return ([], from_key) from_key = int(from_key)
handler = self.handler()
events = []
for room_id in handler._room_serials:
if handler._room_serials[room_id] <= from_key:
continue
# TODO: check if user is in room
events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial)
def get_current_key(self): def get_current_key(self):
return 0 return self.handler()._latest_room_serial
def get_pagination_rows(self, user, pagination_config, key): def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_key) return ([], pagination_config.from_key)

View File

@ -14,10 +14,11 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
@ -25,7 +26,7 @@ from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError from synapse.api.errors import CodeMessageException, SynapseError, Codes
from syutil.crypto.jsonsign import sign_json from syutil.crypto.jsonsign import sign_json
@ -89,8 +90,8 @@ class MatrixFederationHttpClient(object):
("", "", path_bytes, param_bytes, query_bytes, "",) ("", "", path_bytes, param_bytes, query_bytes, "",)
) )
logger.debug("Sending request to %s: %s %s", logger.info("Sending request to %s: %s %s",
destination, method, url_bytes) destination, method, url_bytes)
logger.debug( logger.debug(
"Types: %s", "Types: %s",
@ -101,6 +102,8 @@ class MatrixFederationHttpClient(object):
] ]
) )
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
retries_left = 5 retries_left = 5
endpoint = self._getEndpoint(reactor, destination) endpoint = self._getEndpoint(reactor, destination)
@ -127,11 +130,20 @@ class MatrixFederationHttpClient(object):
break break
except Exception as e: except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError): if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn("DNS Lookup failed to %s with %s", destination, logger.warn(
e) "DNS Lookup failed to %s with %s",
destination,
e
)
raise SynapseError(400, "Domain specified not found.") raise SynapseError(400, "Domain specified not found.")
logger.exception("Got error in _create_request") logger.warn(
"Sending request failed to %s: %s %s : %s",
destination,
method,
url_bytes,
e
)
_print_ex(e) _print_ex(e)
if retries_left: if retries_left:
@ -140,15 +152,21 @@ class MatrixFederationHttpClient(object):
else: else:
raise raise
logger.info(
"Received response %d %s for %s: %s %s",
response.code,
response.phrase,
destination,
method,
url_bytes
)
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent? # We need to update the transactions table to say it was sent?
pass pass
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
logger.error(
"Got response %d %s", response.code, response.phrase
)
raise CodeMessageException( raise CodeMessageException(
response.code, response.phrase response.code, response.phrase
) )
@ -227,7 +245,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" Get's some json from the given host homeserver and path """ GETs some json from the given host homeserver and path
Args: Args:
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
@ -235,9 +253,6 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path. path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to args (dict): A dictionary used to create query strings, defaults to
None. None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* HTTP response.
@ -272,6 +287,52 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
retry_on_dns_fail=True, max_size=None):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
path (str): The HTTP path to GET.
output_stream (file): File to write the response body to.
args (dict): Optional dictionary used to create the query string.
Returns:
A (int,dict) tuple of the file length and a dict of the response
headers.
"""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, basestring):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._create_request(
destination.encode("ascii"),
"GET",
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail
)
headers = dict(response.headers.getAllRawHeaders())
try:
length = yield _readBodyToFile(response, output_stream, max_size)
except:
logger.exception("Failed to download body")
raise
defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination): def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint( return matrix_federation_endpoint(
reactor, destination, timeout=10, reactor, destination, timeout=10,
@ -279,12 +340,44 @@ class MatrixFederationHttpClient(object):
) )
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
))
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
if reason.check(ResponseDone):
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
def _print_ex(e): def _print_ex(e):
if hasattr(e, "reasons") and e.reasons: if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons: for ex in e.reasons:
_print_ex(ex) _print_ex(ex)
else: else:
logger.exception(e) logger.warn(e)
class _JsonProducer(object): class _JsonProducer(object):

View File

@ -29,6 +29,7 @@ from twisted.web.util import redirectTo
import collections import collections
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,9 +123,14 @@ class JsonResource(HttpServer, resource.Resource):
# We found a match! Trigger callback and then return the # We found a match! Trigger callback and then return the
# returned response. We pass both the request and any # returned response. We pass both the request and any
# matched groups from the regex to the callback. # matched groups from the regex to the callback.
args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
code, response = yield path_entry.callback( code, response = yield path_entry.callback(
request, request,
*m.groups() *args
) )
self._send_response(request, code, response) self._send_response(request, code, response)
@ -166,14 +172,10 @@ class JsonResource(HttpServer, resource.Resource):
request) request)
return return
if not self._request_user_agent_is_curl(request):
json_bytes = encode_canonical_json(response_json_object)
else:
json_bytes = encode_pretty_printed_json(response_json_object)
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
respond_with_json_bytes(request, code, json_bytes, send_cors=True, respond_with_json(request, code, response_json_object, send_cors=True,
response_code_message=response_code_message) response_code_message=response_code_message,
pretty_print=self._request_user_agent_is_curl)
@staticmethod @staticmethod
def _request_user_agent_is_curl(request): def _request_user_agent_is_curl(request):
@ -202,6 +204,17 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request) return resource.Resource.getChild(self, name, request)
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False):
if not pretty_print:
json_bytes = encode_pretty_printed_json(json_object)
else:
json_bytes = encode_canonical_json(json_object)
return respond_with_json_bytes(request, code, json_bytes, send_cors,
response_code_message=response_code_message)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False, def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
response_code_message=None): response_code_message=None):
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.

View File

View File

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .server import respond_with_json_bytes from synapse.http.server import respond_with_json_bytes
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (

View File

View File

@ -0,0 +1,369 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .thumbnailer import Thumbnailer
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, CodeMessageException, cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
import os
import logging
logger = logging.getLogger(__name__)
class BaseMediaResource(Resource):
isLeaf = True
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = hs.get_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.downloads = {}
@staticmethod
def catch_errors(request_handler):
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
try:
yield request_handler(self, request)
except CodeMessageException as e:
logger.exception(e)
respond_with_json(
request, e.code, cs_exception(e), send_cors=True
)
except:
logger.exception(
"Failed handle request %s.%s on %r",
request_handler.__module__,
request_handler.__name__,
self,
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
return wrapped_request_handler
@staticmethod
def _parse_media_id(request):
try:
server_name, media_id = request.postpath
return (server_name, media_id)
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKKOWN,
)
@staticmethod
def _parse_integer(request, arg_name, default=None):
try:
if default is None:
return int(request.args[arg_name][0])
else:
return int(request.args.get(arg_name, [default])[0])
except:
raise SynapseError(
400,
"Missing integer argument %r" % (arg_name,),
Codes.UNKNOWN,
)
@staticmethod
def _parse_string(request, arg_name, default=None):
try:
if default is None:
return request.args[arg_name][0]
else:
return request.args.get(arg_name, [default])[0]
except:
raise SynapseError(
400,
"Missing string argument %r" % (arg_name,),
Codes.UNKNOWN,
)
def _respond_404(self, request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
def _get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return download
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=None,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": None,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
request.finish()
else:
self._respond_404()
def _get_thumbnail_requirements(self, media_type):
if media_type == "image/jpeg":
return (
(32, 32, "crop", "image/jpeg"),
(96, 96, "crop", "image/jpeg"),
(320, 240, "scale", "image/jpeg"),
(640, 480, "scale", "image/jpeg"),
)
elif (media_type == "image/png") or (media_type == "image/gif"):
return (
(32, 32, "crop", "image/png"),
(96, 96, "crop", "image/png"),
(320, 240, "scale", "image/png"),
(640, 480, "scale", "image/png"),
)
else:
return ()
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({
"width": m_width,
"height": m_height,
})

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class DownloadResource(BaseMediaResource):
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@BaseMediaResource.catch_errors
@defer.inlineCallbacks
def _async_render_GET(self, request):
try:
server_name, media_id = request.postpath
except:
self._respond_404(request)
return
if server_name == self.server_name:
yield self._respond_local_file(request, media_id)
else:
yield self._respond_remote_file(request, server_name, media_id)
@defer.inlineCallbacks
def _respond_local_file(self, request, media_id):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404()
return
media_type = media_info["media_type"]
file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(request, media_type, file_path)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id):
media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
filesystem_id = media_info["filesystem_id"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield self._respond_with_file(request, media_type, file_path)

View File

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class MediaFilePaths(object):
def __init__(self, base_path):
self.base_path = base_path
def default_thumbnail(self, default_top_level, default_sub_type, width,
height, content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
self.base_path, "default_thumbnails", default_top_level,
default_sub_type, file_name
)
def local_media_filepath(self, media_id):
return os.path.join(
self.base_path, "local_content",
media_id[0:2], media_id[2:4], media_id[4:]
)
def local_media_thumbnail(self, media_id, width, height, content_type,
method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
self.base_path, "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
def remote_media_filepath(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:]
)
def remote_media_thumbnail(self, server_name, file_id, width, height,
content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)

View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .filepath import MediaFilePaths
from twisted.web.resource import Resource
import logging
logger = logging.getLogger(__name__)
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET
the download::
=> POST /_matrix/media/v1/upload HTTP/1.1
Content-Type: <media-type>
<media>
<= HTTP/1.1 200 OK
Content-Type: application/json
{ "content_uri": "mxc://<server-name>/<media-id>" }
=> GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: <media-type>
Content-Disposition: attachment;filename=<upload-filename>
<media>
Clients can get thumbnails by supplying a desired width and height and
thumbnailing method::
=> GET /_matrix/media/v1/thumbnail/<server_name>
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: image/jpeg or image/png
<thumbnail>
The thumbnail methods are "crop" and "scale". "scale" trys to return an
image where either the width or the height is smaller than the requested
size. The client should then scale and letterbox the image if it needs to
fit within a given rectangle. "crop" trys to return an image where the
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
"""
def __init__(self, hs):
Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path)
self.putChild("upload", UploadResource(hs, filepaths))
self.putChild("download", DownloadResource(hs, filepaths))
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))

View File

@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class ThumbnailResource(BaseMediaResource):
isLeaf = True
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@BaseMediaResource.catch_errors
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id = self._parse_media_id(request)
width = self._parse_integer(request, "width")
height = self._parse_integer(request, "height")
method = self._parse_string(request, "method", "scale")
m_type = self._parse_string(request, "type", "image/png")
if server_name == self.server_name:
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
else:
yield self._respond_remote_thumbnail(
request, server_name, media_id,
width, height, method, m_type
)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_id = thumbnail_info["filesystem_id"]
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_default_thumbnail(self, request, media_info, width, height,
method, m_type):
media_type = media_info["media_type"]
top_level_type = media_type.split("/")[0]
sub_type = media_type.split("/")[-1].split(";")[0]
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, sub_type,
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, "_default",
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
"_default", "_default",
)
if not thumbnail_infos:
self._respond_404(request)
return
thumbnail_info = self._select_thumbnail(
width, height, "crop", m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method,
)
yield self.respond_with_file(request, t_type, file_path)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
d_w = desired_width
d_h = desired_height
if desired_method.lower() == "crop":
info_list = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
info_list.append((
aspect_quality, size_quality, type_quality,
length_quality, info
))
if info_list:
return min(info_list)[-1]
else:
info_list = []
info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
info_list.append((
size_quality, type_quality, length_quality, info
))
elif t_method == "scale":
info_list2.append((
size_quality, type_quality, length_quality, info
))
if info_list:
return min(info_list)[-1]
else:
return min(info_list2)[-1]

View File

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import PIL.Image as Image
from io import BytesIO
class Thumbnailer(object):
FORMATS = {
"image/jpeg": "JPEG",
"image/png": "PNG",
}
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
(w_in / h_in) = (w_out / h_out)
w_out = min(w_max, h_max * (w_in / h_in))
h_out = min(h_max, w_max * (h_in / w_in))
Args:
max_width: The largest possible width.
max_height: The larget possible height.
"""
if max_width * self.height < max_height * self.width:
return (max_width, (max_width * self.height) // self.width)
else:
return ((max_height * self.width) // self.height, max_height)
def scale(self, output_path, width, height, output_type):
"""Rescales the image to the given dimensions"""
scaled = self.image.resize((width, height), Image.BILINEAR)
return self.save_image(scaled, output_type, output_path)
def crop(self, output_path, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
w_scaled = max(w_out, h_out * (w_in / h_in))
h_scaled = max(h_out, w_out * (h_in / w_in))
Args:
max_width: The largest possible width.
max_height: The larget possible height.
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
scaled_image = self.image.resize(
(width, scaled_height), Image.BILINEAR
)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
scaled_image = self.image.resize(
(scaled_width, height), Image.BILINEAR
)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self.save_image(cropped, output_type, output_path)
def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type])
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
return len(output_bytes)

View File

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException
)
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from .base_resource import BaseMediaResource
import logging
logger = logging.getLogger(__name__)
class UploadResource(BaseMediaResource):
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@defer.inlineCallbacks
def _async_render_POST(self, request):
try:
auth_user = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
)
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
media_type = headers.getRawHeaders("Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
#if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(request.content.read())
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=None,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
content_uri = "mxc://%s/%s" % (self.server_name, media_id)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True
)
except CodeMessageException as e:
logger.exception(e)
respond_with_json(request, e.code, cs_exception(e), send_cors=True)
except:
logger.exception("Failed to store file")
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)

View File

@ -146,7 +146,11 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor() yield run_on_reactor()
# TODO(paul): This is horrible, having to manually list every event
# source here individually
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
typing_source = self.event_sources.sources["typing"]
listeners = set() listeners = set()
@ -158,19 +162,33 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def notify(listener): def notify(listener):
events, end_key = yield presence_source.get_new_events_for_user( presence_events, presence_end_key = (
listener.user, yield presence_source.get_new_events_for_user(
listener.from_token.presence_key, listener.user,
listener.limit, listener.from_token.presence_key,
listener.limit,
)
)
typing_events, typing_end_key = (
yield typing_source.get_new_events_for_user(
listener.user,
listener.from_token.typing_key,
listener.limit,
)
) )
if events: if presence_events or typing_events:
end_token = listener.from_token.copy_and_replace( end_token = listener.from_token.copy_and_replace(
"presence_key", end_key "presence_key", presence_end_key
).copy_and_replace(
"typing_key", typing_end_key
) )
listener.notify( listener.notify(
self, events, listener.from_token, end_token self,
presence_events + typing_events,
listener.from_token,
end_token
) )
def eb(failure): def eb(failure):

View File

@ -28,7 +28,7 @@ class RestServletFactory(object):
speaking, they serve as wrappers around events and the handlers that speaking, they serve as wrappers around events and the handlers that
process them. process them.
See synapse.api.events for information on synapse events. See synapse.events for information on synapse events.
""" """
def __init__(self, hs): def __init__(self, hs):

View File

@ -35,7 +35,7 @@ class WhoisRestServlet(RestServlet):
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin") raise AuthError(403, "You are not a server admin")
if not target_user.is_mine: if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user") raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user) ret = yield self.handlers.admin_handler.get_whois(target_user)

View File

@ -63,12 +63,10 @@ class RestServlet(object):
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_factory = hs.get_event_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()
self.validator = hs.get_event_validator()
def register(self, http_server): def register(self, http_server):
""" Register this servlet with the given HTTP server. """ """ Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"): if hasattr(self, "PATTERN"):

View File

@ -21,7 +21,6 @@ from base import RestServlet, client_path_pattern
import json import json
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,9 +35,7 @@ class ClientDirectoryServer(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
room_alias = self.hs.parse_roomalias( room_alias = self.hs.parse_roomalias(room_alias)
urllib.unquote(room_alias).decode("utf-8")
)
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias) res = yield dir_handler.get_association(room_alias)
@ -56,9 +53,7 @@ class ClientDirectoryServer(RestServlet):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
room_alias = self.hs.parse_roomalias( room_alias = self.hs.parse_roomalias(room_alias)
urllib.unquote(room_alias).decode("utf-8")
)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias.to_string())
@ -97,9 +92,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
room_alias = self.hs.parse_roomalias( room_alias = self.hs.parse_roomalias(room_alias)
urllib.unquote(room_alias).decode("utf-8")
)
yield dir_handler.delete_association( yield dir_handler.delete_association(
user.to_string(), room_alias user.to_string(), room_alias

View File

@ -47,8 +47,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def do_password_login(self, login_submission):
if not login_submission["user"].startswith('@'): if not login_submission["user"].startswith('@'):
login_submission["user"] = UserID.create_local( login_submission["user"] = UserID.create(
login_submission["user"], self.hs).to_string() login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler handler = self.handlers.login_handler
token = yield handler.login( token = yield handler.login(

View File

@ -22,7 +22,6 @@ from base import RestServlet, client_path_pattern
import json import json
import logging import logging
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,7 +32,6 @@ class PresenceStatusRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -44,7 +42,6 @@ class PresenceStatusRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
state = {} state = {}
@ -80,10 +77,9 @@ class PresenceListRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
if not user.is_mine: if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user: if auth_user != user:
@ -101,10 +97,9 @@ class PresenceListRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
if not user.is_mine: if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user: if auth_user != user:

View File

@ -19,7 +19,6 @@ from twisted.internet import defer
from base import RestServlet, client_path_pattern from base import RestServlet, client_path_pattern
import json import json
import urllib
class ProfileDisplaynameRestServlet(RestServlet): class ProfileDisplaynameRestServlet(RestServlet):
@ -27,7 +26,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.handlers.profile_handler.get_displayname(
@ -39,7 +37,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
try: try:
@ -62,7 +59,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.handlers.profile_handler.get_avatar_url(
@ -74,7 +70,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
try: try:
@ -97,7 +92,6 @@ class ProfileRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.handlers.profile_handler.get_displayname(

View File

@ -21,6 +21,8 @@ from synapse.api.constants import LoginType
from base import RestServlet, client_path_pattern from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from hashlib import sha1 from hashlib import sha1
import hmac import hmac
import json import json
@ -233,7 +235,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_password(self, request, register_json, session): def _do_password(self, request, register_json, session):
yield yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]): not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage! # captcha should've been done by this stage!

View File

@ -19,8 +19,7 @@ from twisted.internet import defer
from base import RestServlet, client_path_pattern from base import RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.events.room import RoomMemberEvent, RoomRedactionEvent from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Membership
import json import json
import logging import logging
@ -129,9 +128,9 @@ class RoomStateEventRestServlet(RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
user_id=user.to_string(), user_id=user.to_string(),
room_id=urllib.unquote(room_id), room_id=room_id,
event_type=urllib.unquote(event_type), event_type=event_type,
state_key=urllib.unquote(state_key), state_key=state_key,
) )
if not data: if not data:
@ -143,32 +142,23 @@ class RoomStateEventRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key): def on_PUT(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
event_type = urllib.unquote(event_type)
content = _parse_json(request) content = _parse_json(request)
event = self.event_factory.create_event( event_dict = {
etype=event_type, # already urldecoded "type": event_type,
content=content, "content": content,
room_id=urllib.unquote(room_id), "room_id": room_id,
user_id=user.to_string(), "sender": user.to_string(),
state_key=urllib.unquote(state_key) }
)
self.validator.validate(event) if state_key is not None:
event_dict["state_key"] = state_key
if event_type == RoomMemberEvent.TYPE: msg_handler = self.handlers.message_handler
# membership events are special yield msg_handler.create_and_send_event(event_dict)
handler = self.handlers.room_member_handler
yield handler.change_membership(event) defer.returnValue((200, {}))
defer.returnValue((200, {}))
else:
# store random bits of state
msg_handler = self.handlers.message_handler
yield msg_handler.store_room_data(
event=event
)
defer.returnValue((200, {}))
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
@ -184,17 +174,15 @@ class RoomSendEventRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
event = self.event_factory.create_event(
etype=urllib.unquote(event_type),
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
content=content
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event) event = yield msg_handler.create_and_send_event(
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
)
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -235,14 +223,10 @@ class JoinRoomAliasServlet(RestServlet):
identifier = None identifier = None
is_room_alias = False is_room_alias = False
try: try:
identifier = self.hs.parse_roomalias( identifier = self.hs.parse_roomalias(room_identifier)
urllib.unquote(room_identifier)
)
is_room_alias = True is_room_alias = True
except SynapseError: except SynapseError:
identifier = self.hs.parse_roomid( identifier = self.hs.parse_roomid(room_identifier)
urllib.unquote(room_identifier)
)
# TODO: Support for specifying the home server to join with? # TODO: Support for specifying the home server to join with?
@ -251,18 +235,17 @@ class JoinRoomAliasServlet(RestServlet):
ret_dict = yield handler.join_room_alias(user, identifier) ret_dict = yield handler.join_room_alias(user, identifier)
defer.returnValue((200, ret_dict)) defer.returnValue((200, ret_dict))
else: # room id else: # room id
event = self.event_factory.create_event( msg_handler = self.handlers.message_handler
etype=RoomMemberEvent.TYPE, yield msg_handler.create_and_send_event(
content={"membership": Membership.JOIN}, {
room_id=urllib.unquote(identifier.to_string()), "type": EventTypes.Member,
user_id=user.to_string(), "content": {"membership": Membership.JOIN},
state_key=user.to_string() "room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
}
) )
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -301,7 +284,7 @@ class RoomMemberListRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk( members = yield handler.get_room_members_as_pagination_chunk(
room_id=urllib.unquote(room_id), room_id=room_id,
user_id=user.to_string()) user_id=user.to_string())
for event in members["chunk"]: for event in members["chunk"]:
@ -327,13 +310,13 @@ class RoomMessageListRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request, pagination_config = PaginationConfig.from_request(
default_limit=10, request, default_limit=10,
) )
with_feedback = "feedback" in request.args with_feedback = "feedback" in request.args
handler = self.handlers.message_handler handler = self.handlers.message_handler
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=urllib.unquote(room_id), room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
feedback=with_feedback) feedback=with_feedback)
@ -351,7 +334,7 @@ class RoomStateRestServlet(RestServlet):
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=urllib.unquote(room_id), room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
) )
defer.returnValue((200, events)) defer.returnValue((200, events))
@ -366,7 +349,7 @@ class RoomInitialSyncRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=urllib.unquote(room_id), room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
) )
@ -378,8 +361,10 @@ class RoomTriggerBackfill(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
remote_server = urllib.unquote(request.args["remote"][0]) remote_server = urllib.unquote(
room_id = urllib.unquote(room_id) request.args["remote"][0]
).decode("UTF-8")
limit = int(request.args["limit"][0]) limit = int(request.args["limit"][0])
handler = self.handlers.federation_handler handler = self.handlers.federation_handler
@ -414,18 +399,17 @@ class RoomMembershipRestServlet(RestServlet):
if membership_action == "kick": if membership_action == "kick":
membership_action = "leave" membership_action = "leave"
event = self.event_factory.create_event( msg_handler = self.handlers.message_handler
etype=RoomMemberEvent.TYPE, yield msg_handler.create_and_send_event(
content={"membership": unicode(membership_action)}, {
room_id=urllib.unquote(room_id), "type": EventTypes.Member,
user_id=user.to_string(), "content": {"membership": unicode(membership_action)},
state_key=state_key "room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
}
) )
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -453,18 +437,16 @@ class RoomRedactEventRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request) user = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
event = self.event_factory.create_event(
etype=RoomRedactionEvent.TYPE,
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
content=content,
redacts=urllib.unquote(event_id),
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event) event = yield msg_handler.create_and_send_event(
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
}
)
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -483,6 +465,39 @@ class RoomRedactEventRestServlet(RestServlet):
defer.returnValue(response) defer.returnValue(response)
class RoomTypingRestServlet(RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = self.hs.parse_userid(urllib.unquote(user_id))
content = _parse_json(request)
typing_handler = self.handlers.typing_notification_handler
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
)
defer.returnValue((200, {}))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -538,3 +553,4 @@ def register_servlets(hs, http_server):
RoomStateRestServlet(hs).register(http_server) RoomStateRestServlet(hs).register(http_server)
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server) RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)

View File

@ -20,6 +20,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object): class HttpTransactionStore(object):
def __init__(self): def __init__(self):

View File

@ -20,9 +20,7 @@
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event from synapse.events.utils import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -36,6 +34,7 @@ from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
class BaseHomeServer(object): class BaseHomeServer(object):
@ -65,7 +64,6 @@ class BaseHomeServer(object):
'persistence_service', 'persistence_service',
'replication_layer', 'replication_layer',
'datastore', 'datastore',
'event_factory',
'handlers', 'handlers',
'auth', 'auth',
'rest_servlet_factory', 'rest_servlet_factory',
@ -78,10 +76,11 @@ class BaseHomeServer(object):
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
'resource_for_server_key', 'resource_for_server_key',
'resource_for_media_repository',
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',
'event_validator', 'event_builder_factory',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -133,22 +132,22 @@ class BaseHomeServer(object):
def parse_userid(self, s): def parse_userid(self, s):
"""Parse the string given by 's' as a User ID and return a UserID """Parse the string given by 's' as a User ID and return a UserID
object.""" object."""
return UserID.from_string(s, hs=self) return UserID.from_string(s)
def parse_roomalias(self, s): def parse_roomalias(self, s):
"""Parse the string given by 's' as a Room Alias and return a RoomAlias """Parse the string given by 's' as a Room Alias and return a RoomAlias
object.""" object."""
return RoomAlias.from_string(s, hs=self) return RoomAlias.from_string(s)
def parse_roomid(self, s): def parse_roomid(self, s):
"""Parse the string given by 's' as a Room ID and return a RoomID """Parse the string given by 's' as a Room ID and return a RoomID
object.""" object."""
return RoomID.from_string(s, hs=self) return RoomID.from_string(s)
def parse_eventid(self, s): def parse_eventid(self, s):
"""Parse the string given by 's' as a Event ID and return a EventID """Parse the string given by 's' as a Event ID and return a EventID
object.""" object."""
return EventID.from_string(s, hs=self) return EventID.from_string(s)
def serialize_event(self, e): def serialize_event(self, e):
return serialize_event(self, e) return serialize_event(self, e)
@ -165,6 +164,9 @@ class BaseHomeServer(object):
return ip_addr return ip_addr
def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname
# Build magic accessors for every dependency # Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES: for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname) BaseHomeServer._make_dependency_method(depname)
@ -192,9 +194,6 @@ class HomeServer(BaseHomeServer):
def build_datastore(self): def build_datastore(self):
return DataStore(self) return DataStore(self)
def build_event_factory(self):
return EventFactory(self)
def build_handlers(self): def build_handlers(self):
return Handlers(self) return Handlers(self)
@ -225,8 +224,11 @@ class HomeServer(BaseHomeServer):
def build_keyring(self): def build_keyring(self):
return Keyring(self) return Keyring(self)
def build_event_validator(self): def build_event_builder_factory(self):
return EventValidator(self) return EventBuilderFactory(
clock=self.get_clock(),
hostname=self.hostname,
)
def register_servlets(self): def register_servlets(self):
""" Register all servlets associated with this HomeServer. """ Register all servlets associated with this HomeServer.

View File

@ -18,11 +18,11 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.events.room import RoomPowerLevelsEvent from synapse.api.constants import EventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
import copy
import logging import logging
import hashlib import hashlib
@ -43,71 +43,6 @@ class StateHandler(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks
@log_function
def annotate_event_with_state(self, event, old_state=None):
""" Annotates the event with the current state events as of that event.
This method adds three new attributes to the event:
* `state_events`: The state up to and including the event. Encoded
as a dict mapping tuple (type, state_key) -> event.
* `old_state_events`: The state up to, but excluding, the event.
Encoded similarly as `state_events`.
* `state_group`: If there is an existing state group that can be
used, then return that. Otherwise return `None`. See state
storage for more information.
If the argument `old_state` is given (in the form of a list of
events), then they are used as a the values for `old_state_events` and
the value for `state_events` is generated from it. `state_group` is
set to None.
This needs to be called before persisting the event.
"""
yield run_on_reactor()
if old_state:
event.state_group = None
event.old_state_events = {
(s.type, s.state_key): s for s in old_state
}
event.state_events = event.old_state_events
if hasattr(event, "state_key"):
event.state_events[(event.type, event.state_key)] = event
defer.returnValue(False)
return
if hasattr(event, "outlier") and event.outlier:
event.state_group = None
event.old_state_events = None
event.state_events = None
defer.returnValue(False)
return
ids = [e for e, _ in event.prev_events]
ret = yield self.resolve_state_groups(ids)
state_group, new_state = ret
event.old_state_events = copy.deepcopy(new_state)
if hasattr(event, "state_key"):
key = (event.type, event.state_key)
if key in new_state:
event.replaces_state = new_state[key].event_id
new_state[key] = event
elif state_group:
event.state_group = state_group
event.state_events = new_state
defer.returnValue(False)
event.state_group = None
event.state_events = new_state
defer.returnValue(hasattr(event, "state_key"))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
""" Returns the current state for the room as a list. This is done by """ Returns the current state for the room as a list. This is done by
@ -135,9 +70,92 @@ class StateHandler(object):
defer.returnValue(res[1].values()) defer.returnValue(res[1].values())
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
If `event` has `auth_events` then this will also fill out the
`auth_events` field on `context` from the `current_state`.
Args:
event (EventBase)
Returns:
an EventContext
"""
context = EventContext()
yield run_on_reactor()
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
replaces = context.current_state[key]
if replaces.event_id != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id
context.prev_state_events = []
defer.returnValue(context)
if event.is_state():
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
)
group, curr_state, prev_state = ret
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
context.prev_state_events = prev_state
defer.returnValue(context)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, event_ids): def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -156,7 +174,14 @@ class StateHandler(object):
(e.type, e.state_key): e (e.type, e.state_key): e
for e in state_list for e in state_list
} }
defer.returnValue((name, state)) prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue((name, state, prev_states))
state = {} state = {}
for group, g_state in state_groups.items(): for group, g_state in state_groups.items():
@ -177,6 +202,14 @@ class StateHandler(object):
if len(v.values()) > 1 if len(v.values()) > 1
} }
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
)
prev_states = [s.event_id for s in prev_states_events]
else:
prev_states = []
try: try:
new_state = {} new_state = {}
new_state.update(unconflicted_state) new_state.update(unconflicted_state)
@ -186,11 +219,11 @@ class StateHandler(object):
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
defer.returnValue((None, new_state)) defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id): def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events: if hasattr(event, "old_state_events") and event.old_state_events:
key = (RoomPowerLevelsEvent.TYPE, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key) power_level_event = event.old_state_events.get(key)
level = None level = None
if power_level_event: if power_level_event:

View File

@ -15,12 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.events.room import (
RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
RoomRedactionEvent,
)
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from .directory import DirectoryStore from .directory import DirectoryStore
from .feedback import FeedbackStore from .feedback import FeedbackStore
@ -33,11 +29,13 @@ from .stream import StreamStore
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .media_repository import MediaRepositoryStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
@ -62,12 +60,13 @@ SCHEMAS = [
"state", "state",
"event_edges", "event_edges",
"event_signatures", "event_signatures",
"media_repository",
] ]
# Remember to update this number every time an incompatible change is made to # Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts. # database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 8 SCHEMA_VERSION = 10
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
@ -81,11 +80,12 @@ class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, ): EventFederationStore,
MediaRepositoryStore,
):
def __init__(self, hs): def __init__(self, hs):
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
self.event_factory = hs.get_event_factory()
self.hs = hs self.hs = hs
self.min_token_deferred = self._get_min_token() self.min_token_deferred = self._get_min_token()
@ -93,8 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, backfilled=False, is_new_state=True, def persist_event(self, event, context, backfilled=False,
current_state=None): is_new_state=True, current_state=None):
stream_ordering = None stream_ordering = None
if backfilled: if backfilled:
if not self.min_token_deferred.called: if not self.min_token_deferred.called:
@ -107,6 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
@ -117,50 +118,64 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, allow_none=False): def get_event(self, event_id, allow_none=False):
events_dict = yield self._simple_select_one( events = yield self._get_events([event_id])
"events",
{"event_id": event_id},
[
"event_id",
"type",
"room_id",
"content",
"unrecognized_keys",
"depth",
],
allow_none=allow_none,
)
if not events_dict: if not events:
defer.returnValue(None) if allow_none:
defer.returnValue(None)
else:
raise RuntimeError("Could not find event %s" % (event_id,))
event = yield self._parse_events([events_dict]) defer.returnValue(events[0])
defer.returnValue(event[0])
@log_function @log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None, def _persist_event_txn(self, txn, event, context, backfilled,
is_new_state=True, current_state=None): stream_ordering=None, is_new_state=True,
if event.type == RoomMemberEvent.TYPE: current_state=None):
if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event) self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE: elif event.type == EventTypes.Feedback:
self._store_feedback_txn(txn, event) self._store_feedback_txn(txn, event)
elif event.type == RoomNameEvent.TYPE: elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == RoomTopicEvent.TYPE: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event) self._store_room_topic_txn(txn, event)
elif event.type == RoomRedactionEvent.TYPE: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
outlier = False outlier = event.internal_metadata.is_outlier()
if hasattr(event, "outlier"):
outlier = event.outlier event_dict = {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
self._simple_insert_txn(
txn,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
or_replace=True,
)
vals = { vals = {
"topological_ordering": event.depth, "topological_ordering": event.depth,
"event_id": event.event_id, "event_id": event.event_id,
"type": event.type, "type": event.type,
"room_id": event.room_id, "room_id": event.room_id,
"content": json.dumps(event.content), "content": json.dumps(event.get_dict()["content"]),
"processed": True, "processed": True,
"outlier": outlier, "outlier": outlier,
"depth": event.depth, "depth": event.depth,
@ -171,7 +186,7 @@ class DataStore(RoomMemberStore, RoomStore,
unrec = { unrec = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [ if k not in vals.keys() and k not in [
"redacted", "redacted",
"redacted_because", "redacted_because",
@ -206,7 +221,8 @@ class DataStore(RoomMemberStore, RoomStore,
room_id=event.room_id, room_id=event.room_id,
) )
self._store_state_groups_txn(txn, event) if not outlier:
self._store_state_groups_txn(txn, event, context)
if current_state: if current_state:
txn.execute( txn.execute(
@ -300,16 +316,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, hash_alg, hash_bytes, txn, event.event_id, hash_alg, hash_bytes,
) )
if hasattr(event, "signatures"):
logger.debug("sigs: %s", event.signatures)
for name, sigs in event.signatures.items():
for key_id, signature_base64 in sigs.items():
signature_bytes = decode_base64(signature_base64)
self._store_event_signature_txn(
txn, event.event_id, name, key_id,
signature_bytes,
)
for prev_event_id, prev_hashes in event.prev_events: for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items(): for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
@ -411,86 +417,6 @@ class DataStore(RoomMemberStore, RoomStore,
], ],
) )
def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
user_id (synapse.types.UserId): The user to snapshot the room for.
state_type (str): Optional state type to snapshot.
state_key (str): Optional state key to snapshot.
Returns:
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
prev_events = self._get_latest_events_in_room(
txn,
event.room_id
)
prev_state = None
state_key = None
if hasattr(event, "state_key"):
state_key = event.state_key
prev_state = self._get_latest_state_in_room(
txn,
event.room_id,
type=event.type,
state_key=state_key,
)
return Snapshot(
store=self,
room_id=event.room_id,
user_id=event.user_id,
prev_events=prev_events,
prev_state=prev_state,
state_type=event.type,
state_key=state_key,
)
return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
"""Snapshot of the state of a room
Args:
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
state_key (str, optional): State key captured by the snapshot
prev_state_pdu (PduEntry, optional): pdu id of
the previous value of the state type and key in the room.
"""
def __init__(self, store, room_id, user_id, prev_events,
prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_events = prev_events
self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
def fill_out_prev_events(self, event):
if not hasattr(event, "prev_events"):
event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
if self.prev_events:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
if not hasattr(event, "prev_state") and self.prev_state is not None:
event.prev_state = self.prev_state
def schema_path(schema): def schema_path(schema):
""" Get a filesystem path for the named database schema """ Get a filesystem path for the named database schema
@ -518,6 +444,14 @@ def read_schema(schema):
return schema_file.read() return schema_file.read()
class PrepareDatabaseException(Exception):
pass
class UpgradeDatabaseException(PrepareDatabaseException):
pass
def prepare_database(db_conn): def prepare_database(db_conn):
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
don't have to worry about overwriting existing content. don't have to worry about overwriting existing content.
@ -542,6 +476,10 @@ def prepare_database(db_conn):
# Run every version since after the current version. # Run every version since after the current version.
for v in range(user_version + 1, SCHEMA_VERSION + 1): for v in range(user_version + 1, SCHEMA_VERSION + 1):
if v == 10:
raise UpgradeDatabaseException(
"No delta for version 10"
)
sql_script = read_schema("delta/v%d" % (v)) sql_script = read_schema("delta/v%d" % (v))
c.executescript(sql_script) c.executescript(sql_script)

View File

@ -15,15 +15,14 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from syutil.base64util import encode_base64
from twisted.internet import defer from twisted.internet import defer
import collections import collections
import copy
import json import json
import sys import sys
import time import time
@ -84,7 +83,6 @@ class SQLBaseStore(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self._db_pool = hs.get_db_pool() self._db_pool = hs.get_db_pool()
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock() self._clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
@ -436,42 +434,69 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func) return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict): def _get_events(self, event_ids):
d = copy.deepcopy({k: v for k, v in row_dict.items()}) return self.runInteraction(
"_get_events", self._get_events_txn, event_ids
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
replaces_state = d.pop("prev_state", None)
if replaces_state:
d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
del d["unrecognized_keys"]
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
return self.event_factory.create_event(
etype=d["type"],
**d
) )
def _get_events_txn(self, txn, event_ids): def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched? events = []
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
event_rows = []
for e_id in event_ids: for e_id in event_ids:
c = txn.execute(sql, (e_id,)) ev = self._get_event_txn(txn, e_id)
event_rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, event_rows) if ev:
events.append(ev)
return events
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=True):
sql = (
"SELECT internal_metadata, json, r.event_id FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
txn.execute(sql, (event_id,))
res = txn.fetchone()
if not res:
return None
internal_metadata, js, redacted = res
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
if check_redacted and redacted:
ev = prune_event(ev)
ev.unsigned["redacted_by"] = redacted
# Get the redaction event.
because = self._get_event_txn(
txn,
redacted,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
return ev
def _parse_events(self, rows): def _parse_events(self, rows):
return self.runInteraction( return self.runInteraction(
@ -479,80 +504,9 @@ class SQLBaseStore(object):
) )
def _parse_events_txn(self, txn, rows): def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows] event_ids = [r["event_id"] for r in rows]
select_event_sql = ( return self._get_events_txn(txn, event_ids)
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
ev.auth_events = self._get_auth_events(txn, ev.event_id)
if hasattr(ev, "state_key"):
ev.prev_state = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 1
]
if hasattr(ev, "replaces_state"):
# Load previous state_content.
# FIXME (erikj): Handle multiple prev_states.
cursor = txn.execute(
select_event_sql,
(ev.replaces_state,)
)
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
ev.redacted = self._has_been_redacted_txn(txn, ev)
if ev.redacted:
# Get the redaction event.
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
)
if del_evs:
ev = prune_event(ev)
events[i] = ev
ev.redacted_because = del_evs[0]
return events
def _has_been_redacted_txn(self, txn, event): def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?" sql = "SELECT event_id FROM redactions WHERE redacts = ?"
@ -650,7 +604,7 @@ class JoinHelper(object):
to dump the results into. to dump the results into.
Attributes: Attributes:
taples (list): List of `Table` classes tables (list): List of `Table` classes
EntryType (type) EntryType (type)
""" """

View File

@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore):
and backfilling from another server respectively. and backfilling from another server respectively.
""" """
def get_auth_chain(self, event_id): def get_auth_chain(self, event_ids):
return self.runInteraction( return self.runInteraction(
"get_auth_chain", "get_auth_chain",
self._get_auth_chain_txn, self._get_auth_chain_txn,
event_id event_ids
) )
def _get_auth_chain_txn(self, txn, event_id): def _get_auth_chain_txn(self, txn, event_ids):
results = self._get_auth_chain_ids_txn(txn, event_id) results = self._get_auth_chain_ids_txn(txn, event_ids)
sql = "SELECT * FROM events WHERE event_id = ?" return self._get_events_txn(txn, results)
rows = []
for ev_id in results:
c = txn.execute(sql, (ev_id,))
rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, rows) def get_auth_chain_ids(self, event_ids):
def get_auth_chain_ids(self, event_id):
return self.runInteraction( return self.runInteraction(
"get_auth_chain_ids", "get_auth_chain_ids",
self._get_auth_chain_ids_txn, self._get_auth_chain_ids_txn,
event_id event_ids
) )
def _get_auth_chain_ids_txn(self, txn, event_id): def _get_auth_chain_ids_txn(self, txn, event_ids):
results = set() results = set()
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE %s" "SELECT auth_id FROM event_auth WHERE %s"
) )
front = set([event_id]) front = set(event_ids)
while front: while front:
sql = base_sql % ( sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)), " OR ".join(["event_id=?"] * len(front)),
@ -177,14 +171,15 @@ class EventFederationStore(SQLBaseStore):
retcols=["prev_event_id", "is_state"], retcols=["prev_event_id", "is_state"],
) )
hashes = self._get_prev_event_hashes_txn(txn, event_id)
results = [] results = []
for d in res: for d in res:
hashes = self._get_event_reference_hashes_txn( edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
txn, edge_hash.update(hashes.get(d["prev_event_id"], {}))
d["prev_event_id"]
)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v)
for k, v in edge_hash.items()
if k == "sha256" if k == "sha256"
} }
results.append((d["prev_event_id"], prev_hashes, d["is_state"])) results.append((d["prev_event_id"], prev_hashes, d["is_state"]))

View File

@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
class MediaRepositoryStore(SQLBaseStore):
"""Persistence for attachments and avatars"""
def get_default_thumbnails(self, top_level_type, sub_type):
return []
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
None if the meia_id doesn't exist.
"""
return self._simple_select_one(
"local_media_repository",
{"media_id": media_id},
("media_type", "media_length", "upload_name", "created_ts"),
allow_none=True,
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
media_length, user_id):
return self._simple_insert(
"local_media_repository",
{
"media_id": media_id,
"media_type": media_type,
"created_ts": time_now_ms,
"upload_name": upload_name,
"media_length": media_length,
"user_id": user_id.to_string(),
}
)
def get_local_media_thumbnails(self, media_id):
return self._simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length",
)
)
def store_local_thumbnail(self, media_id, thumbnail_width,
thumbnail_height, thumbnail_type,
thumbnail_method, thumbnail_length):
return self._simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
}
)
def get_cached_remote_media(self, origin, media_id):
return self._simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
"filesystem_id",
),
allow_none=True,
)
def store_cached_remote_media(self, origin, media_id, media_type,
media_length, time_now_ms, upload_name,
filesystem_id):
return self._simple_insert(
"remote_media_cache",
{
"media_origin": origin,
"media_id": media_id,
"media_type": media_type,
"media_length": media_length,
"created_ts": time_now_ms,
"upload_name": upload_name,
"filesystem_id": filesystem_id,
}
)
def get_remote_media_thumbnails(self, origin, media_id):
return self._simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id",
)
)
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_method,
thumbnail_length):
return self._simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
"thumbnail_length": thumbnail_length,
"filesystem_id": filesystem_id,
}
)

View File

@ -0,0 +1,79 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
retry_last_ts INTEGER,
retry_interval INTEGER
);
CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
CONSTRAINT uniqueness UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
CONSTRAINT uniqueness UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_method TEXT, -- The method used to make the thumbnail
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
PRAGMA user_version = 9;

View File

@ -32,6 +32,19 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering); CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id); CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events( CREATE TABLE IF NOT EXISTS state_events(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,

View File

@ -0,0 +1,68 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS local_media_repository (
media_id TEXT, -- The id used to refer to the media.
media_type TEXT, -- The MIME-type of the media.
media_length INTEGER, -- Length of the media in bytes.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
user_id TEXT, -- The user who uploaded the file.
CONSTRAINT uniqueness UNIQUE (media_id)
);
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_method TEXT, -- The method used to make the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
CONSTRAINT uniqueness UNIQUE (
media_id, thumbnail_width, thumbnail_height, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);
CREATE TABLE IF NOT EXISTS remote_media_cache (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media on that server.
media_type TEXT, -- The MIME-type of the media.
created_ts INTEGER, -- When the content was uploaded in ms.
upload_name TEXT, -- The name the media was uploaded with.
media_length INTEGER, -- Length of the media in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
);
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
media_origin TEXT, -- The remote HS the media came from.
media_id TEXT, -- The id used to refer to the media.
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
thumbnail_method TEXT, -- The method used to make the thumbnail
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
filesystem_id TEXT, -- The name used to store the media on disk.
CONSTRAINT uniqueness UNIQUE (
media_origin, media_id, thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_type
)
);
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
ON local_media_repository_thumbnails (media_id);

View File

@ -29,7 +29,8 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
CREATE TABLE IF NOT EXISTS event_to_state_groups( CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL, event_id TEXT NOT NULL,
state_group INTEGER NOT NULL state_group INTEGER NOT NULL,
CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
); );
CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id); CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);

View File

@ -59,3 +59,9 @@ CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(tra
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination); CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
-- To track destination health
CREATE TABLE IF NOT EXISTS destinations(
destination TEXT PRIMARY KEY,
retry_last_ts INTEGER,
retry_interval INTEGER
);

View File

@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from syutil.base64util import encode_base64
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore):
f f
) )
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes(
event_ids
)
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
defer.returnValue(zip(event_ids, hashes))
def _get_event_reference_hashes_txn(self, txn, event_id): def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU. """Get all the hashes for a given PDU.
Args: Args:

View File

@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
self._store_state_groups_txn, event self._store_state_groups_txn, event
) )
def _store_state_groups_txn(self, txn, event): def _store_state_groups_txn(self, txn, event, context):
if event.state_events is None: if context.current_state is None:
return return
state_group = event.state_group state_events = context.current_state
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group: if not state_group:
state_group = self._simple_insert_txn( state_group = self._simple_insert_txn(
txn, txn,
@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
or_ignore=True, or_ignore=True,
) )
for state in event.state_events.values(): for state in state_events.values():
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_groups_state", table="state_groups_state",

View File

@ -17,6 +17,8 @@ from ._base import SQLBaseStore, Table
from collections import namedtuple from collections import namedtuple
from twisted.internet import defer
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,6 +28,10 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs. """A collection of queries for handling PDUs.
""" """
# a write-through cache of DestinationsTable.EntryType indexed by
# destination string
destination_retry_cache = {}
def get_received_txn_response(self, transaction_id, origin): def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have """For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response already responded to it. If so, return the response code and response
@ -114,7 +120,7 @@ class TransactionStore(SQLBaseStore):
def _prep_send_transaction(self, txn, transaction_id, destination, def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts): origin_server_ts):
# First we find out what the prev_txs should be. # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time, # Since we know that we are only sending one transaction at a time,
# we can simply take the last one. # we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % ( query = "%s ORDER BY id DESC LIMIT 1" % (
@ -205,6 +211,92 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall()) return ReceivedTransactionsTable.decode_results(txn.fetchall())
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
Args:
destination (str)
Returns:
None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme
"""
if destination in self.destination_retry_cache:
return defer.succeed(self.destination_retry_cache[destination])
return self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
def _get_destination_retry_timings(cls, txn, destination):
query = DestinationsTable.select_statement("destination = ?")
txn.execute(query, (destination,))
result = txn.fetchall()
if result:
result = DestinationsTable.decode_single_result(result)
if result.retry_last_ts > 0:
return result
else:
return None
def set_destination_retry_timings(self, destination,
retry_last_ts, retry_interval):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
Args:
destination (str)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
"""
self.destination_retry_cache[destination] = (
DestinationsTable.EntryType(
destination,
retry_last_ts,
retry_interval
)
)
# XXX: we could chose to not bother persisting this if our cache thinks
# this is a NOOP
return self.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings(cls, txn, destination,
retry_last_ts, retry_interval):
query = (
"INSERT OR REPLACE INTO %s "
"(destination, retry_last_ts, retry_interval) "
"VALUES (?, ?, ?) "
) % DestinationsTable.table_name
txn.execute(query, (destination, retry_last_ts, retry_interval))
def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction.
Returns:
list: A list of `DestinationsTable.EntryType`
"""
return self.runInteraction(
"get_destinations_needing_retry",
self._get_destinations_needing_retry
)
def _get_destinations_needing_retry(cls, txn):
where = "retry_last_ts > 0 and retry_next_ts < now()"
query = DestinationsTable.select_statement(where)
txn.execute(query)
return DestinationsTable.decode_results(txn.fetchall())
class ReceivedTransactionsTable(Table): class ReceivedTransactionsTable(Table):
table_name = "received_transactions" table_name = "received_transactions"
@ -247,3 +339,15 @@ class TransactionsToPduTable(Table):
] ]
EntryType = namedtuple("TransactionsToPduEntry", fields) EntryType = namedtuple("TransactionsToPduEntry", fields)
class DestinationsTable(Table):
table_name = "destinations"
fields = [
"destination",
"retry_last_ts",
"retry_interval",
]
EntryType = namedtuple("DestinationsEntry", fields)

View File

@ -19,7 +19,7 @@ from collections import namedtuple
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):
"""Common base class among ID/name strings that have a local part and a """Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil. domain name, prefixed with a sigil.
@ -28,15 +28,13 @@ class DomainSpecificString(
'localpart' : The local part of the name (without the leading sigil) 'localpart' : The local part of the name (without the leading sigil)
'domain' : The domain part of the name 'domain' : The domain part of the name
'is_mine' : Boolean indicating if the domain name is recognised by the
HomeServer as being its own
""" """
# Deny iteration because it will bite you if you try to create a singleton # Deny iteration because it will bite you if you try to create a singleton
# set by: # set by:
# users = set(user) # users = set(user)
def __iter__(self): def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__)) raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings and booleans, it is deeply # Because this class is a namedtuple of strings and booleans, it is deeply
# immutable. # immutable.
@ -47,7 +45,7 @@ class DomainSpecificString(
return self return self
@classmethod @classmethod
def from_string(cls, s, hs): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if s[0] != cls.SIGIL: if s[0] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
@ -66,22 +64,15 @@ class DomainSpecificString(
# This code will need changing if we want to support multiple domain # This code will need changing if we want to support multiple domain
# names on one HS # names on one HS
is_mine = domain == hs.hostname return cls(localpart=parts[0], domain=domain)
return cls(localpart=parts[0], domain=domain, is_mine=is_mine)
def to_string(self): def to_string(self):
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod @classmethod
def create_local(cls, localpart, hs): def create(cls, localpart, domain,):
"""Create a structure on the local domain""" return cls(localpart=localpart, domain=domain)
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
@classmethod
def create(cls, localpart, domain, hs):
is_mine = domain == hs.hostname
return cls(localpart=localpart, domain=domain, is_mine=is_mine)
class UserID(DomainSpecificString): class UserID(DomainSpecificString):

View File

@ -115,10 +115,10 @@ class Signal(object):
failure.value, failure.value,
failure.getTracebackObject())) failure.getTracebackObject()))
if not self.suppress_failures: if not self.suppress_failures:
raise failure failure.raiseException()
deferreds.append(d.addErrback(eb)) deferreds.append(d.addErrback(eb))
results = []
result = yield defer.DeferredList( for deferred in deferreds:
deferreds, fireOnOneErrback=not self.suppress_failures result = yield deferred
) results.append(result)
defer.returnValue(result) defer.returnValue(results)

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from frozendict import frozendict
def freeze(o):
if isinstance(o, dict) or isinstance(o, frozendict):
return frozendict({k: freeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return tuple([freeze(i) for i in o])
except TypeError:
pass
return o
def unfreeze(o):
if isinstance(o, frozendict) or isinstance(o, dict):
return dict({k: unfreeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return [unfreeze(i) for i in o]
except TypeError:
pass
return o

View File

@ -1,217 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events import SynapseEvent
from synapse.api.events.validator import EventValidator
from synapse.api.errors import SynapseError
from tests import unittest
class SynapseTemplateCheckTestCase(unittest.TestCase):
def setUp(self):
self.validator = EventValidator(None)
def tearDown(self):
pass
def test_top_level_keys(self):
template = {
"person": {},
"friends": ["string"]
}
content = {
"person": {"name": "bob"},
"friends": ["jill", "mike"]
}
event = MockSynapseEvent(template)
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
"friends": ["jill"],
"enemies": ["mike"]
}
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
# missing friends
"enemies": ["mike", "jill"]
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
def test_lists(self):
template = {
"person": {},
"friends": [{"name":"string"}]
}
content = {
"person": {"name": "bob"},
"friends": ["jill", "mike"] # should be in objects
}
event = MockSynapseEvent(template)
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {"name": "bob"},
"friends": [{"name": "jill"}, {"name": "mike"}]
}
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_lists(self):
template = {
"results": {
"families": [
{
"name": "string",
"members": [
{}
]
}
]
}
}
content = {
"results": {
"families": [
{
"name": "Smith",
"members": [
"Alice", "Bob" # wrong types
]
}
]
}
}
event = MockSynapseEvent(template)
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"results": {
"families": [
{
"name": "Smith",
"members": [
{"name": "Alice"}, {"name": "Bob"}
]
}
]
}
}
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_keys(self):
template = {
"person": {
"attributes": {
"hair": "string",
"eye": "string"
},
"age": 0,
"fav_books": ["string"]
}
}
event = MockSynapseEvent(template)
content = {
"person": {
"attributes": {
"hair": "brown",
"eye": "green",
"skin": "purple"
},
"age": 33,
"fav_books": ["lotr", "hobbit"],
"fav_music": ["abba", "beatles"]
}
}
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {
"attributes": {
"hair": "brown"
# missing eye
},
"age": 33,
"fav_books": ["lotr", "hobbit"],
"fav_music": ["abba", "beatles"]
}
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {
"attributes": {
"hair": "brown",
"eye": "green",
"skin": "purple"
},
"age": 33,
"fav_books": "nothing", # should be a list
}
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
class MockSynapseEvent(SynapseEvent):
def __init__(self, template):
self.template = template
def get_content_template(self):
return self.template

View File

@ -23,24 +23,20 @@ from ..utils import MockHttpResource, MockClock, MockKey
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.api.events import SynapseEvent from synapse.events import FrozenEvent
from synapse.storage.transactions import DestinationsTable
def make_pdu(prev_pdus=[], **kwargs): def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple.""" """Provide some default fields for making a PduTuple."""
pdu_fields = { pdu_fields = {
"is_state": False,
"unrecognized_keys": [],
"outlier": False,
"have_processed": True,
"state_key": None, "state_key": None,
"power_level": None, "prev_events": prev_pdus,
"prev_state_id": None,
"prev_state_origin": None,
} }
pdu_fields.update(kwargs) pdu_fields.update(kwargs)
return SynapseEvent(prev_pdus=prev_pdus, **pdu_fields) return FrozenEvent(pdu_fields)
class FederationTestCase(unittest.TestCase): class FederationTestCase(unittest.TestCase):
@ -55,10 +51,16 @@ class FederationTestCase(unittest.TestCase):
"delivered_txn", "delivered_txn",
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings",
"get_auth_chain",
]) ])
self.mock_persistence.get_received_txn_response.return_value = ( self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None) defer.succeed(None)
) )
self.mock_persistence.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0))
)
self.mock_persistence.get_auth_chain.return_value = []
self.mock_config = Mock() self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.clock = MockClock() self.clock = MockClock()
@ -171,7 +173,7 @@ class FederationTestCase(unittest.TestCase):
(200, "OK") (200, "OK")
) )
pdu = SynapseEvent( pdu = make_pdu(
event_id="abc123def456", event_id="abc123def456",
origin="red", origin="red",
user_id="@a:red", user_id="@a:red",
@ -180,10 +182,9 @@ class FederationTestCase(unittest.TestCase):
origin_server_ts=123456789001, origin_server_ts=123456789001,
depth=1, depth=1,
content={"text": "Here is the message"}, content={"text": "Here is the message"},
destinations=["remote"],
) )
yield self.federation.send_pdu(pdu) yield self.federation.send_pdu(pdu, ["remote"])
self.mock_http_client.put_json.assert_called_with( self.mock_http_client.put_json.assert_called_with(
"remote", "remote",

View File

@ -16,11 +16,8 @@
from twisted.internet import defer from twisted.internet import defer
from tests import unittest from tests import unittest
from synapse.api.events.room import ( from synapse.api.constants import EventTypes
MessageEvent, from synapse.events import FrozenEvent
)
from synapse.api.events import SynapseEvent
from synapse.handlers.federation import FederationHandler from synapse.handlers.federation import FederationHandler
from synapse.server import HomeServer from synapse.server import HomeServer
@ -37,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
self.state_handler = NonCallableMock(spec_set=[ self.state_handler = NonCallableMock(spec_set=[
"annotate_event_with_state", "compute_event_context",
]) ])
self.auth = NonCallableMock(spec_set=[ self.auth = NonCallableMock(spec_set=[
@ -53,6 +50,8 @@ class FederationTestCase(unittest.TestCase):
"persist_event", "persist_event",
"store_room", "store_room",
"get_room", "get_room",
"get_destination_retry_timings",
"set_destination_retry_timings",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -76,43 +75,47 @@ class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_msg(self): def test_msg(self):
pdu = SynapseEvent( pdu = FrozenEvent({
type=MessageEvent.TYPE, "type": EventTypes.Message,
room_id="foo", "room_id": "foo",
content={"msgtype": u"fooo"}, "content": {"msgtype": u"fooo"},
origin_server_ts=0, "origin_server_ts": 0,
event_id="$a:b", "event_id": "$a:b",
user_id="@a:b", "user_id":"@a:b",
origin="b", "origin": "b",
auth_events=[], "auth_events": [],
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"}, "hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
) })
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
def annotate(ev, old_state=None): def annotate(ev, old_state=None):
ev.old_state_events = [] context = Mock()
return defer.succeed(False) context.current_state = {}
self.state_handler.annotate_event_with_state.side_effect = annotate context.auth_events = {}
return defer.succeed(context)
self.state_handler.compute_event_context.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu( yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False "fo", pdu, False
) )
self.datastore.persist_event.assert_called_once_with( self.datastore.persist_event.assert_called_once_with(
ANY, is_new_state=False, backfilled=False, current_state=None ANY,
is_new_state=True,
backfilled=False,
current_state=None,
context=ANY,
) )
self.state_handler.annotate_event_with_state.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
ANY, ANY, old_state=None,
old_state=None,
) )
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
ANY, ANY, extra_users=[]
extra_users=[]
) )

View File

@ -30,7 +30,7 @@ from synapse.api.constants import PresenceState
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.presence import PresenceHandler, UserPresenceCache from synapse.handlers.presence import PresenceHandler, UserPresenceCache
from synapse.streams.config import SourcePaginationConfig from synapse.streams.config import SourcePaginationConfig
from synapse.storage.transactions import DestinationsTable
OFFLINE = PresenceState.OFFLINE OFFLINE = PresenceState.OFFLINE
UNAVAILABLE = PresenceState.UNAVAILABLE UNAVAILABLE = PresenceState.UNAVAILABLE
@ -528,6 +528,7 @@ class PresencePushTestCase(unittest.TestCase):
"delivered_txn", "delivered_txn",
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings",
]), ]),
handlers=None, handlers=None,
resource_for_client=Mock(), resource_for_client=Mock(),
@ -539,6 +540,9 @@ class PresencePushTestCase(unittest.TestCase):
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0))
)
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
@ -1037,6 +1041,7 @@ class PresencePollingTestCase(unittest.TestCase):
"delivered_txn", "delivered_txn",
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings",
]), ]),
handlers=None, handlers=None,
resource_for_client=Mock(), resource_for_client=Mock(),
@ -1048,6 +1053,9 @@ class PresencePollingTestCase(unittest.TestCase):
hs.handlers = JustPresenceHandlers(hs) hs.handlers = JustPresenceHandlers(hs)
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0))
)
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)

View File

@ -17,10 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from tests import unittest from tests import unittest
from synapse.api.events.room import ( from synapse.api.constants import EventTypes, Membership
RoomMemberEvent,
)
from synapse.api.constants import Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.server import HomeServer from synapse.server import HomeServer
@ -47,7 +44,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room_member", "get_room_member",
"get_room", "get_room",
"store_room", "store_room",
"snapshot_room", "get_latest_events_in_room",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -63,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"check_host_in_room", "check_host_in_room",
]), ]),
state_handler=NonCallableMock(spec_set=[ state_handler=NonCallableMock(spec_set=[
"annotate_event_with_state", "compute_event_context",
"get_current_state", "get_current_state",
]), ]),
config=self.mock_config, config=self.mock_config,
@ -91,9 +88,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.handlers.profile_handler = ProfileHandler(self.hs) self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler self.room_member_handler = self.handlers.room_member_handler
self.snapshot = Mock()
self.datastore.snapshot_room.return_value = self.snapshot
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -104,50 +98,70 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
target_user_id = "@red:blue" target_user_id = "@red:blue"
content = {"membership": Membership.INVITE} content = {"membership": Membership.INVITE}
event = self.hs.get_event_factory().create_event( builder = self.hs.get_event_builder_factory().new({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
user_id=user_id, "sender": user_id,
state_key=target_user_id, "state_key": target_user_id,
room_id=room_id, "room_id": room_id,
membership=Membership.INVITE, "content": content,
content=content, })
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
) )
self.auth.check_host_in_room.return_value = defer.succeed(True) def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
ctx.prev_state_events = []
store_id = "store_id_fooo" return defer.succeed(ctx)
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room_member.return_value = defer.succeed(None) self.state_handler.compute_event_context.side_effect = annotate
event.old_state_events = { def add_auth(_, ctx):
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member( ctx.auth_events = ctx.current_state[
user_id="@alice:green", (EventTypes.Member, "@bob:red")
room_id=room_id, ]
),
(RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
event.state_events = event.old_state_events return defer.succeed(True)
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event self.auth.add_auth_events.side_effect = add_auth
# Actual invocation def send_invite(domain, event):
yield self.room_member_handler.change_membership(event) return defer.succeed(event)
self.federation.handle_new_event.assert_called_once_with( self.federation.send_invite.side_effect = send_invite
event, self.snapshot,
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
) )
self.assertEquals( yield room_handler.change_membership(event, context)
set(["red", "green"]),
set(event.destinations) self.state_handler.compute_event_context.assert_called_once_with(
builder
)
self.auth.add_auth_events.assert_called_once_with(
builder, context
)
self.federation.send_invite.assert_called_once_with(
"blue", event,
) )
self.datastore.persist_event.assert_called_once_with( self.datastore.persist_event.assert_called_once_with(
event event, context=context,
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[self.hs.parse_userid(target_user_id)] event, extra_users=[self.hs.parse_userid(target_user_id)]
@ -162,57 +176,58 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user_id = "@bob:red" user_id = "@bob:red"
user = self.hs.parse_userid(user_id) user = self.hs.parse_userid(user_id)
event = self._create_member(
user_id=user_id,
room_id=room_id,
)
self.auth.check_host_in_room.return_value = defer.succeed(True)
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
prev_state = NonCallableMock()
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:red"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
join_signal_observer = Mock() join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer) self.distributor.observe("user_joined_room", join_signal_observer)
event.state_events = { builder = self.hs.get_event_builder_factory().new({
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member( "type": EventTypes.Member,
user_id="@alice:green", "sender": user_id,
room_id=room_id, "state_key": user_id,
), "room_id": room_id,
(RoomMemberEvent.TYPE, user_id): event, "content": {"membership": Membership.JOIN},
} })
event.old_state_events = { self.datastore.get_latest_events_in_room.return_value = (
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member( defer.succeed([])
user_id="@alice:green",
room_id=room_id,
),
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot
) )
self.assertEquals( def annotate(_):
set(["red", "green"]), ctx = Mock()
set(event.destinations) ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.INVITE
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, None, destinations=set()
) )
self.datastore.persist_event.assert_called_once_with( self.datastore.persist_event.assert_called_once_with(
event event, context=context
) )
self.notifier.on_new_room_event.assert_called_once_with( self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user] event, extra_users=[user]
@ -222,14 +237,82 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user=user, room_id=room_id user=user, room_id=room_id
) )
def _create_member(self, user_id, room_id): def _create_member(self, user_id, room_id, membership=Membership.JOIN):
return self.hs.get_event_factory().create_event( builder = self.hs.get_event_builder_factory().new({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
user_id=user_id, "sender": user_id,
state_key=user_id, "state_key": user_id,
room_id=room_id, "room_id": room_id,
membership=Membership.JOIN, "content": {"membership": membership},
content={"membership": Membership.JOIN}, })
return builder.build()
@defer.inlineCallbacks
def test_simple_leave(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = self.hs.parse_userid(user_id)
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.LEAVE},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
def annotate(_):
ctx = Mock()
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.JOIN
),
}
ctx.prev_state_events = []
return defer.succeed(ctx)
self.state_handler.compute_event_context.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
leave_signal_observer = Mock()
self.distributor.observe("user_left_room", leave_signal_observer)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, None, destinations=set(['red'])
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user]
)
leave_signal_observer.assert_called_with(
user=user, room_id=room_id
) )
@ -254,13 +337,9 @@ class RoomCreationTest(unittest.TestCase):
notifier=NonCallableMock(spec_set=["on_new_room_event"]), notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[ handlers=NonCallableMock(spec_set=[
"room_creation_handler", "room_creation_handler",
"room_member_handler", "message_handler",
"federation_handler",
]), ]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]), auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
state_handler=NonCallableMock(spec_set=[
"annotate_event_with_state",
]),
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
@ -271,30 +350,12 @@ class RoomCreationTest(unittest.TestCase):
"handle_new_event", "handle_new_event",
]) ])
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.hs = hs
self.handlers.federation_handler = self.federation self.handlers.room_creation_handler = RoomCreationHandler(hs)
self.handlers.room_creation_handler = RoomCreationHandler(self.hs)
self.room_creation_handler = self.handlers.room_creation_handler self.room_creation_handler = self.handlers.room_creation_handler
self.handlers.room_member_handler = NonCallableMock(spec_set=[ self.message_handler = self.handlers.message_handler
"change_membership"
])
self.room_member_handler = self.handlers.room_member_handler
def annotate(event):
event.state_events = {}
return defer.succeed(None)
self.state_handler.annotate_event_with_state.side_effect = annotate
def hosts(room):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -311,14 +372,37 @@ class RoomCreationTest(unittest.TestCase):
config=config, config=config,
) )
self.assertTrue(self.room_member_handler.change_membership.called) self.assertTrue(self.message_handler.create_and_send_event.called)
join_event = self.room_member_handler.change_membership.call_args[0][0]
self.assertEquals(RoomMemberEvent.TYPE, join_event.type) event_dicts = [
self.assertEquals(room_id, join_event.room_id) e[0][0]
self.assertEquals(user_id, join_event.user_id) for e in self.message_handler.create_and_send_event.call_args_list
self.assertEquals(user_id, join_event.state_key) ]
self.assertTrue(self.state_handler.annotate_event_with_state.called) self.assertTrue(len(event_dicts) > 3)
self.assertTrue(self.federation.handle_new_event.called) self.assertDictContainsSubset(
{
"type": EventTypes.Create,
"sender": user_id,
"room_id": room_id,
},
event_dicts[0]
)
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
self.assertDictContainsSubset(
{
"type": EventTypes.Member,
"sender": user_id,
"room_id": room_id,
"state_key": user_id,
},
event_dicts[1]
)
self.assertEqual(
Membership.JOIN,
event_dicts[1]["content"]["membership"]
)

View File

@ -22,9 +22,12 @@ import json
from ..utils import MockHttpResource, MockClock, DeferredMockCallable, MockKey from ..utils import MockHttpResource, MockClock, DeferredMockCallable, MockKey
from synapse.api.errors import AuthError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.handlers.typing import TypingNotificationHandler from synapse.handlers.typing import TypingNotificationHandler
from synapse.storage.transactions import DestinationsTable
def _expect_edu(destination, edu_type, content, origin="test"): def _expect_edu(destination, edu_type, content, origin="test"):
return { return {
@ -63,7 +66,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_config = Mock() self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
mock_notifier = Mock(spec=["on_new_user_event"])
self.on_new_user_event = mock_notifier.on_new_user_event
self.auth = Mock(spec=[])
hs = HomeServer("test", hs = HomeServer("test",
auth=self.auth,
clock=self.clock, clock=self.clock,
db_pool=None, db_pool=None,
datastore=Mock(spec=[ datastore=Mock(spec=[
@ -72,8 +81,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
"delivered_txn", "delivered_txn",
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings",
]), ]),
handlers=None, handlers=None,
notifier=mock_notifier,
resource_for_client=Mock(), resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource, resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client, http_client=self.mock_http_client,
@ -82,13 +93,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
hs.handlers = JustTypingNotificationHandlers(hs) hs.handlers = JustTypingNotificationHandlers(hs)
self.mock_update_client = Mock()
self.mock_update_client.return_value = defer.succeed(None)
self.handler = hs.get_handlers().typing_notification_handler self.handler = hs.get_handlers().typing_notification_handler
self.handler.push_update_to_clients = self.mock_update_client
self.event_source = hs.get_event_sources().sources["typing"]
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
self.datastore.get_destination_retry_timings.return_value = (
defer.succeed(DestinationsTable.EntryType("", 0, 0))
)
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
@ -125,7 +137,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
if ignore_user is not None and member == ignore_user: if ignore_user is not None and member == ignore_user:
continue continue
if member.is_mine: if hs.is_mine(member):
if localusers is not None: if localusers is not None:
localusers.add(member) localusers.add(member)
else: else:
@ -134,6 +146,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.room_member_handler.fetch_room_distributions_into = ( self.room_member_handler.fetch_room_distributions_into = (
fetch_room_distributions_into) fetch_room_distributions_into)
def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
self.auth.check_joined_room = check_joined_room
# Some local users to test with # Some local users to test with
self.u_apple = hs.parse_userid("@apple:test") self.u_apple = hs.parse_userid("@apple:test")
self.u_banana = hs.parse_userid("@banana:test") self.u_banana = hs.parse_userid("@banana:test")
@ -145,6 +163,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
def test_started_typing_local(self): def test_started_typing_local(self):
self.room_members = [self.u_apple, self.u_banana] self.room_members = [self.u_apple, self.u_banana]
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.started_typing( yield self.handler.started_typing(
target_user=self.u_apple, target_user=self.u_apple,
auth_user=self.u_apple, auth_user=self.u_apple,
@ -152,13 +172,22 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=20000, timeout=20000,
) )
self.mock_update_client.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(observer_user=self.u_banana, call(rooms=[self.room_id]),
observed_user=self.u_apple,
room_id=self.room_id,
typing=True),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [self.u_apple.to_string()],
}},
]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_started_typing_remote_send(self): def test_started_typing_remote_send(self):
self.room_members = [self.u_apple, self.u_onion] self.room_members = [self.u_apple, self.u_onion]
@ -192,6 +221,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
def test_started_typing_remote_recv(self): def test_started_typing_remote_recv(self):
self.room_members = [self.u_apple, self.u_onion] self.room_members = [self.u_apple, self.u_onion]
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.mock_federation_resource.trigger("PUT", yield self.mock_federation_resource.trigger("PUT",
"/_matrix/federation/v1/send/1000000/", "/_matrix/federation/v1/send/1000000/",
_make_edu_json("farm", "m.typing", _make_edu_json("farm", "m.typing",
@ -203,13 +234,22 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
) )
self.mock_update_client.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(observer_user=self.u_apple, call(rooms=[self.room_id]),
observed_user=self.u_onion,
room_id=self.room_id,
typing=True),
]) ])
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [self.u_onion.to_string()],
}},
]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stopped_typing(self): def test_stopped_typing(self):
self.room_members = [self.u_apple, self.u_banana, self.u_onion] self.room_members = [self.u_apple, self.u_banana, self.u_onion]
@ -232,9 +272,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
# Gut-wrenching # Gut-wrenching
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember
self.handler._member_typing_until[ member = RoomMember(self.room_id, self.u_apple)
RoomMember(self.room_id, self.u_apple) self.handler._member_typing_until[member] = 1002000
] = 1002000 self.handler._member_typing_timer[member] = (
self.clock.call_later(1002, lambda: 0)
)
self.handler._room_typing[self.room_id] = set((self.u_apple,))
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.stopped_typing( yield self.handler.stopped_typing(
target_user=self.u_apple, target_user=self.u_apple,
@ -242,11 +287,68 @@ class TypingNotificationsTestCase(unittest.TestCase):
room_id=self.room_id, room_id=self.room_id,
) )
self.mock_update_client.assert_has_calls([ self.on_new_user_event.assert_has_calls([
call(observer_user=self.u_banana, call(rooms=[self.room_id]),
observed_user=self.u_apple,
room_id=self.room_id,
typing=False),
]) ])
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [],
}},
]
)
@defer.inlineCallbacks
def test_typing_timeout(self):
self.room_members = [self.u_apple, self.u_banana]
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.started_typing(
target_user=self.u_apple,
auth_user=self.u_apple,
room_id=self.room_id,
timeout=10000,
)
self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]),
])
self.on_new_user_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [self.u_apple.to_string()],
}},
]
)
self.clock.advance_time(11)
self.on_new_user_event.assert_has_calls([
call(rooms=[self.room_id]),
])
self.assertEquals(self.event_source.get_current_key(), 2)
self.assertEquals(
self.event_source.get_new_events_for_user(self.u_apple, 1, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [],
}},
]
)

View File

@ -113,9 +113,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
def setUp(self): def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX) self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock() self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()] self.mock_config.signing_key = [MockKey()]
@ -127,7 +124,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
db_pool=db_pool, db_pool=db_pool,
http_client=None, http_client=None,
replication_layer=Mock(), replication_layer=Mock(),
persistence_service=persistence_service,
clock=Mock(spec=[ clock=Mock(spec=[
"call_later", "call_later",
"cancel_call_later", "cancel_call_later",

View File

@ -503,7 +503,7 @@ class RoomsMemberListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_member_list_mixed_memberships(self): def test_get_member_list_mixed_memberships(self):
room_creator = "@some_other_guy:blue" room_creator = "@some_other_guy:red"
room_id = yield self.create_room_as(room_creator) room_id = yield self.create_room_as(room_creator)
room_path = "/rooms/%s/members" % room_id room_path = "/rooms/%s/members" % room_id
yield self.invite(room=room_id, src=room_creator, yield self.invite(room=room_id, src=room_creator,
@ -1066,7 +1066,3 @@ class RoomInitialSyncTestCase(RestTestCase):
} }
self.assertTrue(self.user_id in presence_by_user) self.assertTrue(self.user_id in presence_by_user)
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"]) self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
# (code, response) = yield self.mock_resource.trigger("GET", path, None)
# self.assertEquals(200, code, msg=str(response))
# self.assert_dict(json.loads(content), response)

115
tests/rest/test_typing.py Normal file
View File

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests REST events for /rooms paths."""
# twisted imports
from twisted.internet import defer
import synapse.rest.room
from synapse.server import HomeServer
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
from .utils import RestTestCase
from mock import Mock, NonCallableMock
PATH_PREFIX = "/_matrix/client/api/v1"
class RoomTypingTestCase(RestTestCase):
""" Tests /rooms/$room_id/typing/$user_id REST API. """
user_id = "@sid:red"
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=db_pool,
http_client=None,
replication_layer=Mock(),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=self.mock_config,
)
self.hs = hs
self.event_source = hs.get_event_sources().sources["typing"]
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None):
return {
"user": hs.parse_userid(self.auth_user_id),
"admin": False,
"device_id": None,
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
self.room_id = yield self.create_room_as(self.user_id)
# Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_handlers().typing_notification_handler.tearDown()
@defer.inlineCallbacks
def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger("PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": true, "timeout": 30000}'
)
self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(
self.event_source.get_new_events_for_user(self.user_id, 0, None)[0],
[
{"type": "m.typing",
"room_id": self.room_id,
"content": {
"user_ids": [self.user_id],
}},
]
)
@defer.inlineCallbacks
def test_set_not_typing(self):
(code, _) = yield self.mock_resource.trigger("PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
'{"typing": false}'
)
self.assertEquals(200, code)

View File

@ -18,12 +18,11 @@ from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.events.room import (
RoomMemberEvent, MessageEvent, RoomRedactionEvent,
)
from tests.utils import SQLiteMemoryDbPool from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class RedactionTestCase(unittest.TestCase): class RedactionTestCase(unittest.TestCase):
@ -33,13 +32,21 @@ class RedactionTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"test", "test",
db_pool=db_pool, db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
) )
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test") self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test") self.u_bob = hs.parse_userid("@bob:test")
@ -49,35 +56,23 @@ class RedactionTestCase(unittest.TestCase):
self.depth = 1 self.depth = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, prev_state=None, def inject_room_member(self, room, user, membership, replaces_state=None,
extra_content={}): extra_content={}):
self.depth += 1 content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": content,
})
event = self.event_factory.create_event( event, context = yield self.message_handler._create_new_client_event(
etype=RoomMemberEvent.TYPE, builder
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
) )
event.content.update(extra_content) yield self.store.persist_event(event, context)
if prev_state:
event.prev_state = prev_state
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
defer.returnValue(event) defer.returnValue(event)
@ -85,46 +80,38 @@ class RedactionTestCase(unittest.TestCase):
def inject_message(self, room, user, body): def inject_message(self, room, user, body):
self.depth += 1 self.depth += 1
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=MessageEvent.TYPE, "type": EventTypes.Message,
user_id=user.to_string(), "sender": user.to_string(),
room_id=room.to_string(), "state_key": user.to_string(),
content={"body": body, "msgtype": u"message"}, "room_id": room.to_string(),
depth=self.depth, "content": {"body": body, "msgtype": u"message"},
prev_events=[], })
event, context = yield self.message_handler._create_new_client_event(
builder
) )
event.state_events = None yield self.store.persist_event(event, context)
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason): def inject_redaction(self, room, event_id, user, reason):
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=RoomRedactionEvent.TYPE, "type": EventTypes.Redaction,
user_id=user.to_string(), "sender": user.to_string(),
room_id=room.to_string(), "state_key": user.to_string(),
content={"reason": reason}, "room_id": room.to_string(),
depth=self.depth, "content": {"reason": reason},
redacts=event_id, "redacts": event_id,
prev_events=[], })
event, context = yield self.message_handler._create_new_client_event(
builder
) )
event.state_events = None yield self.store.persist_event(event, context)
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redact(self): def test_redact(self):
@ -152,14 +139,14 @@ class RedactionTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": MessageEvent.TYPE, "type": EventTypes.Message,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {"body": "t", "msgtype": "message"}, "content": {"body": "t", "msgtype": "message"},
}, },
event, event,
) )
self.assertFalse(hasattr(event, "redacted_because")) self.assertFalse("redacted_because" in event.unsigned)
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
@ -180,24 +167,26 @@ class RedactionTestCase(unittest.TestCase):
event = results[0] event = results[0]
self.assertEqual(msg_event.event_id, event.event_id)
self.assertTrue("redacted_because" in event.unsigned)
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": MessageEvent.TYPE, "type": EventTypes.Message,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {}, "content": {},
}, },
event, event,
) )
self.assertTrue(hasattr(event, "redacted_because"))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": RoomRedactionEvent.TYPE, "type": EventTypes.Redaction,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {"reason": reason}, "content": {"reason": reason},
}, },
event.redacted_because, event.unsigned["redacted_because"],
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -229,7 +218,7 @@ class RedactionTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": RoomMemberEvent.TYPE, "type": EventTypes.Member,
"user_id": self.u_bob.to_string(), "user_id": self.u_bob.to_string(),
"content": {"membership": Membership.JOIN, "blue": "red"}, "content": {"membership": Membership.JOIN, "blue": "red"},
}, },
@ -257,22 +246,22 @@ class RedactionTestCase(unittest.TestCase):
event = results[0] event = results[0]
self.assertTrue("redacted_because" in event.unsigned)
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": RoomMemberEvent.TYPE, "type": EventTypes.Member,
"user_id": self.u_bob.to_string(), "user_id": self.u_bob.to_string(),
"content": {"membership": Membership.JOIN}, "content": {"membership": Membership.JOIN},
}, },
event, event,
) )
self.assertTrue(hasattr(event, "redacted_because"))
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": RoomRedactionEvent.TYPE, "type": EventTypes.Redaction,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {"reason": reason}, "content": {"reason": reason},
}, },
event.redacted_because, event.unsigned["redacted_because"],
) )

View File

@ -18,9 +18,7 @@ from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.events.room import ( from synapse.api.constants import EventTypes
RoomNameEvent, RoomTopicEvent
)
from tests.utils import SQLiteMemoryDbPool from tests.utils import SQLiteMemoryDbPool
@ -131,7 +129,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
name = u"A-Room-Name" name = u"A-Room-Name"
yield self.inject_room_event( yield self.inject_room_event(
etype=RoomNameEvent.TYPE, etype=EventTypes.Name,
name=name, name=name,
content={"name": name}, content={"name": name},
depth=1, depth=1,
@ -154,7 +152,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
topic = u"A place for things" topic = u"A place for things"
yield self.inject_room_event( yield self.inject_room_event(
etype=RoomTopicEvent.TYPE, etype=EventTypes.Topic,
topic=topic, topic=topic,
content={"topic": topic}, content={"topic": topic},
depth=1, depth=1,

View File

@ -18,10 +18,11 @@ from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.events.room import RoomMemberEvent
from tests.utils import SQLiteMemoryDbPool from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class RoomMemberStoreTestCase(unittest.TestCase): class RoomMemberStoreTestCase(unittest.TestCase):
@ -31,14 +32,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
hs = HomeServer("test", self.mock_config = Mock()
db_pool=db_pool, self.mock_config.signing_key = [MockKey()]
)
hs = HomeServer(
"test",
db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
)
# We can't test the RoomMemberStore on its own without the other event # We can't test the RoomMemberStore on its own without the other event
# storage logic # storage logic
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test") self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test") self.u_bob = hs.parse_userid("@bob:test")
@ -49,27 +58,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = hs.parse_roomid("!abc123:test") self.room = hs.parse_roomid("!abc123:test")
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership): def inject_room_member(self, room, user, membership, replaces_state=None):
# Have to create a join event using the eventfactory builder = self.event_builder_factory.new({
event = self.event_factory.create_event( "type": EventTypes.Member,
etype=RoomMemberEvent.TYPE, "sender": user.to_string(),
user_id=user.to_string(), "state_key": user.to_string(),
state_key=user.to_string(), "room_id": room.to_string(),
room_id=room.to_string(), "content": {"membership": membership},
membership=membership, })
content={"membership": membership},
depth=1, event, context = yield self.message_handler._create_new_client_event(
prev_events=[], builder
) )
event.state_events = None yield self.store.persist_event(event, context)
event.hashes = {}
event.prev_state = {}
event.auth_events = {}
yield self.store.persist_event( defer.returnValue(event)
event
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_one_member(self): def test_one_member(self):

View File

@ -18,10 +18,11 @@ from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.api.constants import Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.events.room import RoomMemberEvent, MessageEvent
from tests.utils import SQLiteMemoryDbPool from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class StreamStoreTestCase(unittest.TestCase): class StreamStoreTestCase(unittest.TestCase):
@ -31,13 +32,21 @@ class StreamStoreTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool() db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare() yield db_pool.prepare()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"test", "test",
db_pool=db_pool, db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
) )
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test") self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test") self.u_bob = hs.parse_userid("@bob:test")
@ -48,33 +57,22 @@ class StreamStoreTestCase(unittest.TestCase):
self.depth = 1 self.depth = 1
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None): def inject_room_member(self, room, user, membership):
self.depth += 1 self.depth += 1
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=RoomMemberEvent.TYPE, "type": EventTypes.Member,
user_id=user.to_string(), "sender": user.to_string(),
state_key=user.to_string(), "state_key": user.to_string(),
room_id=room.to_string(), "room_id": room.to_string(),
membership=membership, "content": {"membership": membership},
content={"membership": membership}, })
depth=self.depth,
prev_events=[], event, context = yield self.message_handler._create_new_client_event(
builder
) )
event.state_events = None yield self.store.persist_event(event, context)
event.hashes = {}
event.prev_state = []
event.auth_events = []
if replaces_state:
event.prev_state = [(replaces_state, "hash")]
event.replaces_state = replaces_state
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
defer.returnValue(event) defer.returnValue(event)
@ -82,23 +80,19 @@ class StreamStoreTestCase(unittest.TestCase):
def inject_message(self, room, user, body): def inject_message(self, room, user, body):
self.depth += 1 self.depth += 1
event = self.event_factory.create_event( builder = self.event_builder_factory.new({
etype=MessageEvent.TYPE, "type": EventTypes.Message,
user_id=user.to_string(), "sender": user.to_string(),
room_id=room.to_string(), "state_key": user.to_string(),
content={"body": body, "msgtype": u"message"}, "room_id": room.to_string(),
depth=self.depth, "content": {"body": body, "msgtype": u"message"},
prev_events=[], })
event, context = yield self.message_handler._create_new_client_event(
builder
) )
event.state_events = None yield self.store.persist_event(event, context)
event.hashes = {}
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_event_stream_get_other(self): def test_event_stream_get_other(self):
@ -130,7 +124,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": MessageEvent.TYPE, "type": EventTypes.Message,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"}, "content": {"body": "test", "msgtype": "message"},
}, },
@ -167,7 +161,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes( self.assertObjectHasAttributes(
{ {
"type": MessageEvent.TYPE, "type": EventTypes.Message,
"user_id": self.u_alice.to_string(), "user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"}, "content": {"body": "test", "msgtype": "message"},
}, },
@ -220,7 +214,6 @@ class StreamStoreTestCase(unittest.TestCase):
event2 = yield self.inject_room_member( event2 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN, self.room1, self.u_alice, Membership.JOIN,
replaces_state=event1.event_id,
) )
end = yield self.store.get_room_events_max_id() end = yield self.store.get_room_events_max_id()
@ -238,6 +231,6 @@ class StreamStoreTestCase(unittest.TestCase):
event = results[0] event = results[0]
self.assertTrue( self.assertTrue(
hasattr(event, "prev_content"), "prev_content" in event.unsigned,
msg="No prev_content key" msg="No prev_content key"
) )

View File

@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from tests import unittest from . import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock, patch from mock import Mock, patch
from synapse.util.distributor import Distributor from synapse.util.distributor import Distributor
from synapse.util.async import run_on_reactor
class DistributorTestCase(unittest.TestCase): class DistributorTestCase(unittest.TestCase):
@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.dist = Distributor() self.dist = Distributor()
@defer.inlineCallbacks
def test_signal_dispatch(self): def test_signal_dispatch(self):
self.dist.declare("alert") self.dist.declare("alert")
@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
self.dist.observe("alert", observer) self.dist.observe("alert", observer)
d = self.dist.fire("alert", 1, 2, 3) d = self.dist.fire("alert", 1, 2, 3)
yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observer.assert_called_with(1, 2, 3) observer.assert_called_with(1, 2, 3)
@defer.inlineCallbacks
def test_signal_dispatch_deferred(self): def test_signal_dispatch_deferred(self):
self.dist.declare("whine") self.dist.declare("whine")
@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
self.assertFalse(d_outer.called) self.assertFalse(d_outer.called)
d_inner.callback(None) d_inner.callback(None)
yield d_outer
self.assertTrue(d_outer.called) self.assertTrue(d_outer.called)
@defer.inlineCallbacks
def test_signal_catch(self): def test_signal_catch(self):
self.dist.declare("alarm") self.dist.declare("alarm")
@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
spec=["warning"] spec=["warning"]
) as mock_logger: ) as mock_logger:
d = self.dist.fire("alarm", "Go") d = self.dist.fire("alarm", "Go")
yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observers[0].assert_called_once("Go") observers[0].assert_called_once("Go")
@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
self.dist.declare("whail") self.dist.declare("whail")
observer = Mock() class MyException(Exception):
observer.return_value = defer.fail( pass
Exception("Oopsie")
) @defer.inlineCallbacks
def observer():
yield run_on_reactor()
raise MyException("Oopsie")
self.dist.observe("whail", observer) self.dist.observe("whail", observer)
d = self.dist.fire("whail") d = self.dist.fire("whail")
yield self.assertFailure(d, Exception) yield self.assertFailure(d, MyException)
self.dist.suppress_failures = True
@defer.inlineCallbacks
def test_signal_prereg(self): def test_signal_prereg(self):
observer = Mock() observer = Mock()
self.dist.observe("flare", observer) self.dist.observe("flare", observer)
self.dist.declare("flare") self.dist.declare("flare")
self.dist.fire("flare", 4, 5) yield self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5) observer.assert_called_with(4, 5)

View File

@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase):
self.store = Mock( self.store = Mock(
spec_set=[ spec_set=[
"get_state_groups", "get_state_groups",
"add_event_hashes",
] ]
) )
hs = Mock(spec=["get_datastore"]) hs = Mock(spec=["get_datastore"])
@ -44,17 +45,20 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_event_with_state(event, old_state=old_state) context = yield self.state.compute_event_context(
event, old_state=old_state
)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual(set(old_state), set(event.old_state_events.values())) self.assertEqual(
self.assertDictEqual(event.old_state_events, event.state_events) set(old_state), set(context.current_state.values())
)
self.assertIsNone(event.state_group) self.assertIsNone(context.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
@ -66,21 +70,21 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""), self.create_event(type="test2", state_key=""),
] ]
yield self.state.annotate_event_with_state(event, old_state=old_state) context = yield self.state.compute_event_context(
event, old_state=old_state
)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set(old_state + [event]), set(old_state),
set(event.old_state_events.values()) set(context.current_state.values())
) )
self.assertDictEqual(event.old_state_events, event.state_events) self.assertIsNone(context.state_group)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
@ -99,30 +103,19 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()]) set([e.event_id for e in context.current_state.values()])
) )
self.assertDictEqual( self.assertEqual(group_name, context.state_group)
{
k: v.event_id
for k, v in event.old_state_events.items()
},
{
k: v.event_id
for k, v in event.state_events.items()
}
)
self.assertEqual(group_name, event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
@ -141,38 +134,19 @@ class StateTestCase(unittest.TestCase):
group_name: old_state, group_name: old_state,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
for k, v in event.old_state_events.items(): for k, v in context.current_state.items():
type, state_key = k type, state_key = k
self.assertEqual(type, v.type) self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key) self.assertEqual(state_key, v.state_key)
self.assertEqual( self.assertEqual(
set([e.event_id for e in old_state]), set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()]) set([e.event_id for e in context.current_state.values()])
) )
self.assertEqual( self.assertIsNone(context.state_group)
set([e.event_id for e in old_state] + [event.event_id]),
set([e.event_id for e in event.state_events.values()])
)
new_state = {
k: v.event_id
for k, v in event.state_events.items()
}
old_state = {
k: v.event_id
for k, v in event.old_state_events.items()
}
old_state[(event.type, event.state_key)] = event.event_id
self.assertDictEqual(
old_state,
new_state
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
@ -199,16 +173,11 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
self.assertEqual(len(event.old_state_events), 5) self.assertEqual(len(context.current_state), 5)
self.assertEqual( self.assertIsNone(context.state_group)
set([e.event_id for e in event.state_events.values()]),
set([e.event_id for e in event.old_state_events.values()])
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
@ -235,19 +204,11 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
yield self.state.annotate_event_with_state(event) context = yield self.state.compute_event_context(event)
self.assertEqual(len(event.old_state_events), 5) self.assertEqual(len(context.current_state), 5)
expected_new = event.old_state_events self.assertIsNone(context.state_group)
expected_new[(event.type, event.state_key)] = event
self.assertEqual(
set([e.event_id for e in expected_new.values()]),
set([e.event_id for e in event.state_events.values()]),
)
self.assertIsNone(event.state_group)
def create_event(self, name=None, type=None, state_key=None): def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1 self.event_id += 1
@ -266,6 +227,9 @@ class StateTestCase(unittest.TestCase):
event.state_key = state_key event.state_key = state_key
event.event_id = event_id event.event_id = event_id
event.is_state = lambda: (state_key is not None)
event.unsigned = {}
event.user_id = "@user_id:example.com" event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com" event.room_id = "!room_id:example.com"

70
tests/test_test_utils.py Normal file
View File

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
def test_advance_time(self):
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEquals(20, self.clock.time() - start_time)
def test_later(self):
invoked = [0, 0]
def _cb0():
invoked[0] = 1
self.clock.call_later(10, _cb0)
def _cb1():
invoked[1] = 1
self.clock.call_later(20, _cb1)
self.assertFalse(invoked[0])
self.clock.advance_time(15)
self.assertTrue(invoked[0])
self.assertFalse(invoked[1])
self.clock.advance_time(5)
self.assertTrue(invoked[1])
def test_cancel_later(self):
invoked = [0, 0]
def _cb0():
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
def _cb1():
invoked[1] = 1
t1 = self.clock.call_later(20, _cb1)
self.clock.cancel_call_later(t0)
self.clock.advance_time(30)
self.assertFalse(invoked[0])
self.assertTrue(invoked[1])

View File

@ -23,21 +23,21 @@ mock_homeserver = BaseHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase): class UserIDTestCase(unittest.TestCase):
def test_parse(self): def test_parse(self):
user = UserID.from_string("@1234abcd:my.domain", hs=mock_homeserver) user = UserID.from_string("@1234abcd:my.domain")
self.assertEquals("1234abcd", user.localpart) self.assertEquals("1234abcd", user.localpart)
self.assertEquals("my.domain", user.domain) self.assertEquals("my.domain", user.domain)
self.assertEquals(True, user.is_mine) self.assertEquals(True, mock_homeserver.is_mine(user))
def test_build(self): def test_build(self):
user = UserID("5678efgh", "my.domain", True) user = UserID("5678efgh", "my.domain")
self.assertEquals(user.to_string(), "@5678efgh:my.domain") self.assertEquals(user.to_string(), "@5678efgh:my.domain")
def test_compare(self): def test_compare(self):
userA = UserID.from_string("@userA:my.domain", hs=mock_homeserver) userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain", hs=mock_homeserver) userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain", hs=mock_homeserver) userB = UserID.from_string("@userB:my.domain")
self.assertTrue(userA == userAagain) self.assertTrue(userA == userAagain)
self.assertTrue(userA != userB) self.assertTrue(userA != userB)
@ -52,14 +52,14 @@ class UserIDTestCase(unittest.TestCase):
class RoomAliasTestCase(unittest.TestCase): class RoomAliasTestCase(unittest.TestCase):
def test_parse(self): def test_parse(self):
room = RoomAlias.from_string("#channel:my.domain", hs=mock_homeserver) room = RoomAlias.from_string("#channel:my.domain")
self.assertEquals("channel", room.localpart) self.assertEquals("channel", room.localpart)
self.assertEquals("my.domain", room.domain) self.assertEquals("my.domain", room.domain)
self.assertEquals(True, room.is_mine) self.assertEquals(True, mock_homeserver.is_mine(room))
def test_build(self): def test_build(self):
room = RoomAlias("channel", "my.domain", True) room = RoomAlias("channel", "my.domain")
self.assertEquals(room.to_string(), "#channel:my.domain") self.assertEquals(room.to_string(), "#channel:my.domain")

Some files were not shown because too many files have changed in this diff Show More