mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Merge branch 'release-v0.6.0'
This commit is contained in:
commit
45a6869cb4
2
.gitignore
vendored
2
.gitignore
vendored
@ -38,3 +38,5 @@ graph/*.dot
|
||||
**/webclient/test/environment-protractor.js
|
||||
|
||||
uploads
|
||||
|
||||
.idea/
|
||||
|
@ -1,3 +1,12 @@
|
||||
Changes in synapse 0.6.0 (2014-12-16)
|
||||
=====================================
|
||||
|
||||
* Add new API for media upload and download that supports thumbnailing.
|
||||
* Implement typing notifications.
|
||||
* Fix bugs where we sent events with invalid signatures due to bugs where
|
||||
we incorrectly persisted events.
|
||||
* Improve performance of database queries involving retrieving events.
|
||||
|
||||
Changes in synapse 0.5.4a (2014-12-13)
|
||||
======================================
|
||||
|
||||
|
31
README.rst
31
README.rst
@ -133,6 +133,37 @@ failing, e.g.::
|
||||
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||
will need to export CFLAGS=-Qunused-arguments.
|
||||
|
||||
Windows Install
|
||||
---------------
|
||||
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
|
||||
|
||||
- gcc
|
||||
- git
|
||||
- libffi-devel
|
||||
- openssl (and openssl-devel, python-openssl)
|
||||
- python
|
||||
- python-setuptools
|
||||
|
||||
The content repository requires additional packages and will be unable to process
|
||||
uploads without them:
|
||||
- libjpeg8
|
||||
- libjpeg8-devel
|
||||
- zlib
|
||||
If you choose to install Synapse without these packages, you will need to reinstall
|
||||
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
|
||||
pillow --user``
|
||||
|
||||
Troubleshooting:
|
||||
|
||||
- You may need to upgrade ``setuptools`` to get this to work correctly:
|
||||
``pip install setuptools --upgrade``.
|
||||
- You may encounter errors indicating that ``ffi.h`` is missing, even with
|
||||
``libffi-devel`` installed. If you do, copy the ``.h`` files:
|
||||
``cp /usr/lib/libffi-3.0.13/include/*.h /usr/include``
|
||||
- You may need to install libsodium from source in order to install PyNacl. If
|
||||
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
|
||||
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
|
||||
|
||||
Running Your Homeserver
|
||||
=======================
|
||||
|
||||
|
20
UPGRADE.rst
20
UPGRADE.rst
@ -1,3 +1,23 @@
|
||||
Upgrading to v0.6.0
|
||||
===================
|
||||
|
||||
To pull in new dependencies, run::
|
||||
|
||||
python setup.py develop --user
|
||||
|
||||
This update includes a change to the database schema. To upgrade you first need
|
||||
to upgrade the database by running::
|
||||
|
||||
python scripts/upgrade_db_to_v0.6.0.py <db> <server_name> <signing_key>
|
||||
|
||||
Where `<db>` is the location of the database, `<server_name>` is the
|
||||
server name as specified in the synapse configuration, and `<signing_key>` is
|
||||
the location of the signing key as specified in the synapse configuration.
|
||||
|
||||
This may take some time to complete. Failures of signatures and content hashes
|
||||
can safely be ignored.
|
||||
|
||||
|
||||
Upgrading to v0.5.1
|
||||
===================
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
Basically, PEP8
|
||||
|
||||
- Max line width: 80 chars.
|
||||
- NEVER tabs. 4 spaces to indent.
|
||||
- Max line width: 79 chars (with flexibility to overflow by a "few chars" if
|
||||
the overflowing content is not semantically significant and avoids an
|
||||
explosion of vertical whitespace).
|
||||
- Use camel case for class and type names
|
||||
- Use underscores for functions and variables.
|
||||
- Use double quotes.
|
||||
- Use parentheses instead of '\' for line continuation where ever possible (which is pretty much everywhere)
|
||||
- Use parentheses instead of '\\' for line continuation where ever possible
|
||||
(which is pretty much everywhere)
|
||||
- There should be max a single new line between:
|
||||
- statements
|
||||
- functions in a class
|
||||
@ -14,5 +18,32 @@ Basically, PEP8
|
||||
- a single space after a comma
|
||||
- a single space before and after for '=' when used as assignment
|
||||
- no spaces before and after for '=' for default values and keyword arguments.
|
||||
- Indenting must follow PEP8; either hanging indent or multiline-visual indent
|
||||
depending on the size and shape of the arguments and what makes more sense to
|
||||
the author. In other words, both this::
|
||||
|
||||
Comments should follow the google code style. This is so that we can generate documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
||||
print("I am a fish %s" % "moo")
|
||||
|
||||
and this::
|
||||
|
||||
print("I am a fish %s" %
|
||||
"moo")
|
||||
|
||||
and this::
|
||||
|
||||
print(
|
||||
"I am a fish %s" %
|
||||
"moo"
|
||||
)
|
||||
|
||||
...are valid, although given each one takes up 2x more vertical space than
|
||||
the previous, it's up to the author's discretion as to which layout makes most
|
||||
sense for their function invocation. (e.g. if they want to add comments
|
||||
per-argument, or put expressions in the arguments, or group related arguments
|
||||
together, or want to deliberately extend or preserve vertical/horizontal
|
||||
space)
|
||||
|
||||
Comments should follow the google code style. This is so that we can generate
|
||||
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
||||
|
||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||
|
25
docs/media_repository.rst
Normal file
25
docs/media_repository.rst
Normal file
@ -0,0 +1,25 @@
|
||||
Media Repository
|
||||
================
|
||||
|
||||
The media repository is where attachments and avatar photos are stored.
|
||||
It stores attachment content and thumbnails for media uploaded by local users.
|
||||
It caches attachment content and thumbnails for media uploaded by remote users.
|
||||
|
||||
Storage
|
||||
-------
|
||||
|
||||
Each item of media is assigned a ``media_id`` when it is uploaded.
|
||||
The ``media_id`` is a randomly chosen, URL safe 24 character string.
|
||||
Metadata such as the MIME type, upload time and length are stored in the
|
||||
sqlite3 database indexed by ``media_id``.
|
||||
Content is stored on the filesystem under a ``"local_content"`` directory.
|
||||
Thumbnails are stored under a ``"local_thumbnails"`` directory.
|
||||
The item with ``media_id`` ``"aabbccccccccdddddddddddd"`` is stored under
|
||||
``"local_content/aa/bb/ccccccccdddddddddddd"``. Its thumbnail with width
|
||||
``128`` and height ``96`` and type ``"image/jpeg"`` is stored under
|
||||
``"local_thumbnails/aa/bb/ccccccccdddddddddddd/128-96-image-jpeg"``
|
||||
Remote content is cached under ``"remote_content"`` directory. Each item of
|
||||
remote content is assigned a local "``filesystem_id``" to ensure that the
|
||||
directory structure ``"remote_content/server_name/aa/bb/ccccccccdddddddddddd"``
|
||||
is appropriate. Thumbnails for remote content are stored under
|
||||
``"remote_thumbnails/server_name/..."``
|
138
graph/graph2.py
Normal file
138
graph/graph2.py
Normal file
@ -0,0 +1,138 @@
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import sqlite3
|
||||
import pydot
|
||||
import cgi
|
||||
import json
|
||||
import datetime
|
||||
import argparse
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
|
||||
|
||||
def make_graph(db_name, room_id, file_prefix):
|
||||
conn = sqlite3.connect(db_name)
|
||||
|
||||
c = conn.execute(
|
||||
"SELECT json FROM event_json where room_id = ?",
|
||||
(room_id,)
|
||||
)
|
||||
|
||||
events = [FrozenEvent(json.loads(e[0])) for e in c.fetchall()]
|
||||
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
||||
node_map = {}
|
||||
state_groups = {}
|
||||
|
||||
graph = pydot.Dot(graph_name="Test")
|
||||
|
||||
for event in events:
|
||||
c = conn.execute(
|
||||
"SELECT state_group FROM event_to_state_groups "
|
||||
"WHERE event_id = ?",
|
||||
(event.event_id,)
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
state_group = res[0] if res else None
|
||||
|
||||
if state_group is not None:
|
||||
state_groups.setdefault(state_group, []).append(event.event_id)
|
||||
|
||||
t = datetime.datetime.fromtimestamp(
|
||||
float(event.origin_server_ts) / 1000
|
||||
).strftime('%Y-%m-%d %H:%M:%S,%f')
|
||||
|
||||
content = json.dumps(event.get_dict()["content"])
|
||||
|
||||
label = (
|
||||
"<"
|
||||
"<b>%(name)s </b><br/>"
|
||||
"Type: <b>%(type)s </b><br/>"
|
||||
"State key: <b>%(state_key)s </b><br/>"
|
||||
"Content: <b>%(content)s </b><br/>"
|
||||
"Time: <b>%(time)s </b><br/>"
|
||||
"Depth: <b>%(depth)s </b><br/>"
|
||||
"State group: %(state_group)s<br/>"
|
||||
">"
|
||||
) % {
|
||||
"name": event.event_id,
|
||||
"type": event.type,
|
||||
"state_key": event.get("state_key", None),
|
||||
"content": cgi.escape(content, quote=True),
|
||||
"time": t,
|
||||
"depth": event.depth,
|
||||
"state_group": state_group,
|
||||
}
|
||||
|
||||
node = pydot.Node(
|
||||
name=event.event_id,
|
||||
label=label,
|
||||
)
|
||||
|
||||
node_map[event.event_id] = node
|
||||
graph.add_node(node)
|
||||
|
||||
for event in events:
|
||||
for prev_id, _ in event.prev_events:
|
||||
try:
|
||||
end_node = node_map[prev_id]
|
||||
except:
|
||||
end_node = pydot.Node(
|
||||
name=prev_id,
|
||||
label="<<b>%s</b>>" % (prev_id,),
|
||||
)
|
||||
|
||||
node_map[prev_id] = end_node
|
||||
graph.add_node(end_node)
|
||||
|
||||
edge = pydot.Edge(node_map[event.event_id], end_node)
|
||||
graph.add_edge(edge)
|
||||
|
||||
for group, event_ids in state_groups.items():
|
||||
if len(event_ids) <= 1:
|
||||
continue
|
||||
|
||||
cluster = pydot.Cluster(
|
||||
str(group),
|
||||
label="<State Group: %s>" % (str(group),)
|
||||
)
|
||||
|
||||
for event_id in event_ids:
|
||||
cluster.add_node(node_map[event_id])
|
||||
|
||||
graph.add_subgraph(cluster)
|
||||
|
||||
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
|
||||
graph.write_svg("%s.svg" % file_prefix, prog='dot')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a PDU graph for a given room by talking "
|
||||
"to the given homeserver to get the list of PDUs. \n"
|
||||
"Requires pydot."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p", "--prefix", dest="prefix",
|
||||
help="String to prefix output files with"
|
||||
)
|
||||
parser.add_argument('db')
|
||||
parser.add_argument('room')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
make_graph(args.db, args.room, args.prefix)
|
@ -18,6 +18,9 @@ class dictobj(dict):
|
||||
def get_full_dict(self):
|
||||
return dict(self)
|
||||
|
||||
def get_pdu_json(self):
|
||||
return dict(self)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
143
scripts/federation_client.py
Normal file
143
scripts/federation_client.py
Normal file
@ -0,0 +1,143 @@
|
||||
import nacl.signing
|
||||
import json
|
||||
import base64
|
||||
import requests
|
||||
import sys
|
||||
import srvlookup
|
||||
|
||||
|
||||
def encode_base64(input_bytes):
|
||||
"""Encode bytes as a base64 string without any padding."""
|
||||
|
||||
input_len = len(input_bytes)
|
||||
output_len = 4 * ((input_len + 2) // 3) + (input_len + 2) % 3 - 2
|
||||
output_bytes = base64.b64encode(input_bytes)
|
||||
output_string = output_bytes[:output_len].decode("ascii")
|
||||
return output_string
|
||||
|
||||
|
||||
def decode_base64(input_string):
|
||||
"""Decode a base64 string to bytes inferring padding from the length of the
|
||||
string."""
|
||||
|
||||
input_bytes = input_string.encode("ascii")
|
||||
input_len = len(input_bytes)
|
||||
padding = b"=" * (3 - ((input_len + 3) % 4))
|
||||
output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
|
||||
output_bytes = base64.b64decode(input_bytes + padding)
|
||||
return output_bytes[:output_len]
|
||||
|
||||
|
||||
def encode_canonical_json(value):
|
||||
return json.dumps(
|
||||
value,
|
||||
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
|
||||
ensure_ascii=False,
|
||||
# Remove unecessary white space.
|
||||
separators=(',',':'),
|
||||
# Sort the keys of dictionaries.
|
||||
sort_keys=True,
|
||||
# Encode the resulting unicode as UTF-8 bytes.
|
||||
).encode("UTF-8")
|
||||
|
||||
|
||||
def sign_json(json_object, signing_key, signing_name):
|
||||
signatures = json_object.pop("signatures", {})
|
||||
unsigned = json_object.pop("unsigned", None)
|
||||
|
||||
signed = signing_key.sign(encode_canonical_json(json_object))
|
||||
signature_base64 = encode_base64(signed.signature)
|
||||
|
||||
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
|
||||
signatures.setdefault(signing_name, {})[key_id] = signature_base64
|
||||
|
||||
json_object["signatures"] = signatures
|
||||
if unsigned is not None:
|
||||
json_object["unsigned"] = unsigned
|
||||
|
||||
return json_object
|
||||
|
||||
|
||||
NACL_ED25519 = "ed25519"
|
||||
|
||||
def decode_signing_key_base64(algorithm, version, key_base64):
|
||||
"""Decode a base64 encoded signing key
|
||||
Args:
|
||||
algorithm (str): The algorithm the key is for (currently "ed25519").
|
||||
version (str): Identifies this key out of the keys for this entity.
|
||||
key_base64 (str): Base64 encoded bytes of the key.
|
||||
Returns:
|
||||
A SigningKey object.
|
||||
"""
|
||||
if algorithm == NACL_ED25519:
|
||||
key_bytes = decode_base64(key_base64)
|
||||
key = nacl.signing.SigningKey(key_bytes)
|
||||
key.version = version
|
||||
key.alg = NACL_ED25519
|
||||
return key
|
||||
else:
|
||||
raise ValueError("Unsupported algorithm %s" % (algorithm,))
|
||||
|
||||
|
||||
def read_signing_keys(stream):
|
||||
"""Reads a list of keys from a stream
|
||||
Args:
|
||||
stream : A stream to iterate for keys.
|
||||
Returns:
|
||||
list of SigningKey objects.
|
||||
"""
|
||||
keys = []
|
||||
for line in stream:
|
||||
algorithm, version, key_base64 = line.split()
|
||||
keys.append(decode_signing_key_base64(algorithm, version, key_base64))
|
||||
return keys
|
||||
|
||||
|
||||
def lookup(destination, path):
|
||||
if ":" in destination:
|
||||
return "https://%s%s" % (destination, path)
|
||||
else:
|
||||
srv = srvlookup.lookup("matrix", "tcp", destination)[0]
|
||||
return "https://%s:%d%s" % (srv.host, srv.port, path)
|
||||
|
||||
def get_json(origin_name, origin_key, destination, path):
|
||||
request_json = {
|
||||
"method": "GET",
|
||||
"uri": path,
|
||||
"origin": origin_name,
|
||||
"destination": destination,
|
||||
}
|
||||
|
||||
signed_json = sign_json(request_json, origin_key, origin_name)
|
||||
|
||||
authorization_headers = []
|
||||
|
||||
for key, sig in signed_json["signatures"][origin_name].items():
|
||||
authorization_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
origin_name, key, sig,
|
||||
)
|
||||
))
|
||||
|
||||
result = requests.get(
|
||||
lookup(destination, path),
|
||||
headers={"Authorization": authorization_headers[0]},
|
||||
verify=False,
|
||||
)
|
||||
return result.json()
|
||||
|
||||
|
||||
def main():
|
||||
origin_name, keyfile, destination, path = sys.argv[1:]
|
||||
|
||||
with open(keyfile) as f:
|
||||
key = read_signing_keys(f)[0]
|
||||
|
||||
result = get_json(
|
||||
origin_name, key, destination, "/_matrix/federation/v1/" + path
|
||||
)
|
||||
|
||||
json.dump(result, sys.stdout)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
331
scripts/upgrade_db_to_v0.6.0.py
Normal file
331
scripts/upgrade_db_to_v0.6.0.py
Normal file
@ -0,0 +1,331 @@
|
||||
|
||||
from synapse.storage import SCHEMA_VERSION, read_schema
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.signatures import SignatureStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
|
||||
from syutil.base64util import encode_base64, decode_base64
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
|
||||
from syutil.crypto.jsonsign import (
|
||||
verify_signed_json, SignatureVerifyException,
|
||||
)
|
||||
from syutil.crypto.signing_key import decode_verify_key_bytes
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import argparse
|
||||
# import dns.resolver
|
||||
import hashlib
|
||||
import httplib
|
||||
import json
|
||||
import sqlite3
|
||||
import syutil
|
||||
import urllib2
|
||||
|
||||
|
||||
delta_sql = """
|
||||
CREATE TABLE IF NOT EXISTS event_json(
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
internal_metadata NOT NULL,
|
||||
json BLOB NOT NULL,
|
||||
CONSTRAINT ev_j_uniq UNIQUE (event_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
|
||||
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
|
||||
|
||||
PRAGMA user_version = 10;
|
||||
"""
|
||||
|
||||
|
||||
class Store(object):
|
||||
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
|
||||
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
|
||||
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
|
||||
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
|
||||
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
|
||||
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
|
||||
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
|
||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
||||
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
|
||||
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
|
||||
|
||||
def _generate_event_json(self, txn, rows):
|
||||
events = []
|
||||
for row in rows:
|
||||
d = dict(row)
|
||||
|
||||
d.pop("stream_ordering", None)
|
||||
d.pop("topological_ordering", None)
|
||||
d.pop("processed", None)
|
||||
|
||||
if "origin_server_ts" not in d:
|
||||
d["origin_server_ts"] = d.pop("ts", 0)
|
||||
else:
|
||||
d.pop("ts", 0)
|
||||
|
||||
d.pop("prev_state", None)
|
||||
d.update(json.loads(d.pop("unrecognized_keys")))
|
||||
|
||||
d["sender"] = d.pop("user_id")
|
||||
|
||||
d["content"] = json.loads(d["content"])
|
||||
|
||||
if "age_ts" not in d:
|
||||
# For compatibility
|
||||
d["age_ts"] = d.get("origin_server_ts", 0)
|
||||
|
||||
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
|
||||
|
||||
outlier = d.pop("outlier", False)
|
||||
|
||||
# d.pop("membership", None)
|
||||
|
||||
d.pop("state_hash", None)
|
||||
|
||||
d.pop("replaces_state", None)
|
||||
|
||||
b = EventBuilder(d)
|
||||
b.internal_metadata.outlier = outlier
|
||||
|
||||
events.append(b)
|
||||
|
||||
for i, ev in enumerate(events):
|
||||
signatures = self._get_event_signatures_txn(
|
||||
txn, ev.event_id,
|
||||
)
|
||||
|
||||
ev.signatures = {
|
||||
n: {
|
||||
k: encode_base64(v) for k, v in s.items()
|
||||
}
|
||||
for n, s in signatures.items()
|
||||
}
|
||||
|
||||
hashes = self._get_event_content_hashes_txn(
|
||||
txn, ev.event_id,
|
||||
)
|
||||
|
||||
ev.hashes = {
|
||||
k: encode_base64(v) for k, v in hashes.items()
|
||||
}
|
||||
|
||||
prevs = self._get_prev_events_and_state(txn, ev.event_id)
|
||||
|
||||
ev.prev_events = [
|
||||
(e_id, h)
|
||||
for e_id, h, is_state in prevs
|
||||
if is_state == 0
|
||||
]
|
||||
|
||||
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
|
||||
|
||||
hashes = dict(ev.auth_events)
|
||||
|
||||
for e_id, hash in ev.prev_events:
|
||||
if e_id in hashes and not hash:
|
||||
hash.update(hashes[e_id])
|
||||
#
|
||||
# if hasattr(ev, "state_key"):
|
||||
# ev.prev_state = [
|
||||
# (e_id, h)
|
||||
# for e_id, h, is_state in prevs
|
||||
# if is_state == 1
|
||||
# ]
|
||||
|
||||
return [e.build() for e in events]
|
||||
|
||||
|
||||
store = Store()
|
||||
|
||||
|
||||
# def get_key(server_name):
|
||||
# print "Getting keys for: %s" % (server_name,)
|
||||
# targets = []
|
||||
# if ":" in server_name:
|
||||
# target, port = server_name.split(":")
|
||||
# targets.append((target, int(port)))
|
||||
# try:
|
||||
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
|
||||
# for srv in answers:
|
||||
# targets.append((srv.target, srv.port))
|
||||
# except dns.resolver.NXDOMAIN:
|
||||
# targets.append((server_name, 8448))
|
||||
# except:
|
||||
# print "Failed to lookup keys for %s" % (server_name,)
|
||||
# return {}
|
||||
#
|
||||
# for target, port in targets:
|
||||
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
|
||||
# try:
|
||||
# keys = json.load(urllib2.urlopen(url, timeout=2))
|
||||
# verify_keys = {}
|
||||
# for key_id, key_base64 in keys["verify_keys"].items():
|
||||
# verify_key = decode_verify_key_bytes(
|
||||
# key_id, decode_base64(key_base64)
|
||||
# )
|
||||
# verify_signed_json(keys, server_name, verify_key)
|
||||
# verify_keys[key_id] = verify_key
|
||||
# print "Got keys for: %s" % (server_name,)
|
||||
# return verify_keys
|
||||
# except urllib2.URLError:
|
||||
# pass
|
||||
# except urllib2.HTTPError:
|
||||
# pass
|
||||
# except httplib.HTTPException:
|
||||
# pass
|
||||
#
|
||||
# print "Failed to get keys for %s" % (server_name,)
|
||||
# return {}
|
||||
|
||||
|
||||
def reinsert_events(cursor, server_name, signing_key):
|
||||
print "Running delta: v10"
|
||||
|
||||
cursor.executescript(delta_sql)
|
||||
|
||||
cursor.execute(
|
||||
"SELECT * FROM events ORDER BY rowid ASC"
|
||||
)
|
||||
|
||||
print "Getting events..."
|
||||
|
||||
rows = store.cursor_to_dict(cursor)
|
||||
|
||||
events = store._generate_event_json(cursor, rows)
|
||||
|
||||
print "Got events from DB."
|
||||
|
||||
algorithms = {
|
||||
"sha256": hashlib.sha256,
|
||||
}
|
||||
|
||||
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
|
||||
verify_key = signing_key.verify_key
|
||||
verify_key.alg = signing_key.alg
|
||||
verify_key.version = signing_key.version
|
||||
|
||||
server_keys = {
|
||||
server_name: {
|
||||
key_id: verify_key
|
||||
}
|
||||
}
|
||||
|
||||
i = 0
|
||||
N = len(events)
|
||||
|
||||
for event in events:
|
||||
if i % 100 == 0:
|
||||
print "Processed: %d/%d events" % (i,N,)
|
||||
i += 1
|
||||
|
||||
# for alg_name in event.hashes:
|
||||
# if check_event_content_hash(event, algorithms[alg_name]):
|
||||
# pass
|
||||
# else:
|
||||
# pass
|
||||
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
|
||||
|
||||
have_own_correctly_signed = False
|
||||
for host, sigs in event.signatures.items():
|
||||
pruned = prune_event(event)
|
||||
|
||||
for key_id in sigs:
|
||||
if host not in server_keys:
|
||||
server_keys[host] = {} # get_key(host)
|
||||
if key_id in server_keys[host]:
|
||||
try:
|
||||
verify_signed_json(
|
||||
pruned.get_pdu_json(),
|
||||
host,
|
||||
server_keys[host][key_id]
|
||||
)
|
||||
|
||||
if host == server_name:
|
||||
have_own_correctly_signed = True
|
||||
except SignatureVerifyException:
|
||||
print "FAIL signature check %s %s" % (
|
||||
key_id, event.event_id
|
||||
)
|
||||
|
||||
# TODO: Re sign with our own server key
|
||||
if not have_own_correctly_signed:
|
||||
sigs = compute_event_signature(event, server_name, signing_key)
|
||||
event.signatures.update(sigs)
|
||||
|
||||
pruned = prune_event(event)
|
||||
|
||||
for key_id in event.signatures[server_name]:
|
||||
verify_signed_json(
|
||||
pruned.get_pdu_json(),
|
||||
server_name,
|
||||
server_keys[server_name][key_id]
|
||||
)
|
||||
|
||||
event_json = encode_canonical_json(
|
||||
event.get_dict()
|
||||
).decode("UTF-8")
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
).decode("UTF-8")
|
||||
|
||||
store._simple_insert_txn(
|
||||
cursor,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json,
|
||||
"json": event_json,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
|
||||
def main(database, server_name, signing_key):
|
||||
conn = sqlite3.connect(database)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Do other deltas:
|
||||
cursor.execute("PRAGMA user_version")
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row and row[0]:
|
||||
user_version = row[0]
|
||||
# Run every version since after the current version.
|
||||
for v in range(user_version + 1, 10):
|
||||
print "Running delta: %d" % (v,)
|
||||
sql_script = read_schema("delta/v%d" % (v,))
|
||||
cursor.executescript(sql_script)
|
||||
|
||||
reinsert_events(cursor, server_name, signing_key)
|
||||
|
||||
conn.commit()
|
||||
|
||||
print "Success!"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("database")
|
||||
parser.add_argument("server_name")
|
||||
parser.add_argument(
|
||||
"signing_key", type=argparse.FileType('r'),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
signing_key = syutil.crypto.signing_key.read_signing_keys(
|
||||
args.signing_key
|
||||
)
|
||||
|
||||
main(args.database, args.server_name, signing_key[0])
|
8
setup.py
8
setup.py
@ -32,7 +32,7 @@ setup(
|
||||
description="Reference Synapse Home Server",
|
||||
install_requires=[
|
||||
"syutil==0.0.2",
|
||||
"matrix_angular_sdk==0.5.1",
|
||||
"matrix_angular_sdk==0.6.0",
|
||||
"Twisted>=14.0.0",
|
||||
"service_identity>=1.0.0",
|
||||
"pyopenssl>=0.14",
|
||||
@ -41,11 +41,13 @@ setup(
|
||||
"pynacl",
|
||||
"daemonize",
|
||||
"py-bcrypt",
|
||||
"frozendict>=0.4",
|
||||
"pillow",
|
||||
],
|
||||
dependency_links=[
|
||||
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
|
||||
"https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0",
|
||||
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.5.1/#egg=matrix_angular_sdk-0.5.1",
|
||||
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
|
||||
],
|
||||
setup_requires=[
|
||||
"setuptools_trial",
|
||||
@ -59,6 +61,6 @@ setup(
|
||||
entry_points="""
|
||||
[console_scripts]
|
||||
synctl=synapse.app.synctl:main
|
||||
synapse-homeserver=synapse.app.homeserver:run
|
||||
synapse-homeserver=synapse.app.homeserver:main
|
||||
"""
|
||||
)
|
||||
|
@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a synapse home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.5.4a"
|
||||
__version__ = "0.6.0"
|
||||
|
@ -17,14 +17,10 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership, JoinRules
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
|
||||
RoomJoinRulesEvent, RoomCreateEvent, RoomAliasesEvent,
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
from syutil.base64util import encode_base64
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
import logging
|
||||
|
||||
@ -53,15 +49,17 @@ class Auth(object):
|
||||
logger.warn("Trusting event: %s", event.event_id)
|
||||
return True
|
||||
|
||||
if event.type == RoomCreateEvent.TYPE:
|
||||
if event.type == EventTypes.Create:
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
# FIXME: Temp hack
|
||||
if event.type == RoomAliasesEvent.TYPE:
|
||||
if event.type == EventTypes.Aliases:
|
||||
return True
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
logger.debug("Auth events: %s", auth_events)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
allowed = self.is_membership_change_allowed(
|
||||
event, auth_events
|
||||
)
|
||||
@ -74,10 +72,10 @@ class Auth(object):
|
||||
self.check_event_sender_in_room(event, auth_events)
|
||||
self._can_send_event(event, auth_events)
|
||||
|
||||
if event.type == RoomPowerLevelsEvent.TYPE:
|
||||
if event.type == EventTypes.PowerLevels:
|
||||
self._check_power_levels(event, auth_events)
|
||||
|
||||
if event.type == RoomRedactionEvent.TYPE:
|
||||
if event.type == EventTypes.Redaction:
|
||||
self._check_redaction(event, auth_events)
|
||||
|
||||
logger.debug("Allowing! %s", event)
|
||||
@ -93,7 +91,7 @@ class Auth(object):
|
||||
def check_joined_room(self, room_id, user_id):
|
||||
member = yield self.state.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type=RoomMemberEvent.TYPE,
|
||||
event_type=EventTypes.Member,
|
||||
state_key=user_id
|
||||
)
|
||||
self._check_joined_room(member, user_id, room_id)
|
||||
@ -104,7 +102,7 @@ class Auth(object):
|
||||
curr_state = yield self.state.get_current_state(room_id)
|
||||
|
||||
for event in curr_state:
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.type == EventTypes.Member:
|
||||
try:
|
||||
if self.hs.parse_userid(event.state_key).domain != host:
|
||||
continue
|
||||
@ -118,7 +116,7 @@ class Auth(object):
|
||||
defer.returnValue(False)
|
||||
|
||||
def check_event_sender_in_room(self, event, auth_events):
|
||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
member_event = auth_events.get(key)
|
||||
|
||||
return self._check_joined_room(
|
||||
@ -140,7 +138,7 @@ class Auth(object):
|
||||
# Check if this is the room creator joining:
|
||||
if len(event.prev_events) == 1 and Membership.JOIN == membership:
|
||||
# Get room creation event:
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
key = (EventTypes.Create, "", )
|
||||
create = auth_events.get(key)
|
||||
if create and event.prev_events[0][0] == create.event_id:
|
||||
if create.content["creator"] == event.state_key:
|
||||
@ -149,19 +147,19 @@ class Auth(object):
|
||||
target_user_id = event.state_key
|
||||
|
||||
# get info about the caller
|
||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
key = (EventTypes.Member, event.user_id, )
|
||||
caller = auth_events.get(key)
|
||||
|
||||
caller_in_room = caller and caller.membership == Membership.JOIN
|
||||
caller_invited = caller and caller.membership == Membership.INVITE
|
||||
|
||||
# get info about the target
|
||||
key = (RoomMemberEvent.TYPE, target_user_id, )
|
||||
key = (EventTypes.Member, target_user_id, )
|
||||
target = auth_events.get(key)
|
||||
|
||||
target_in_room = target and target.membership == Membership.JOIN
|
||||
|
||||
key = (RoomJoinRulesEvent.TYPE, "", )
|
||||
key = (EventTypes.JoinRules, "", )
|
||||
join_rule_event = auth_events.get(key)
|
||||
if join_rule_event:
|
||||
join_rule = join_rule_event.content.get(
|
||||
@ -256,7 +254,7 @@ class Auth(object):
|
||||
return True
|
||||
|
||||
def _get_power_level_from_event_state(self, event, user_id, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = auth_events.get(key)
|
||||
level = None
|
||||
if power_level_event:
|
||||
@ -264,7 +262,7 @@ class Auth(object):
|
||||
if not level:
|
||||
level = power_level_event.content.get("users_default", 0)
|
||||
else:
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
key = (EventTypes.Create, "", )
|
||||
create_event = auth_events.get(key)
|
||||
if (create_event is not None and
|
||||
create_event.content["creator"] == user_id):
|
||||
@ -273,7 +271,7 @@ class Auth(object):
|
||||
return level
|
||||
|
||||
def _get_ops_level_from_event_state(self, event, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = auth_events.get(key)
|
||||
|
||||
if power_level_event:
|
||||
@ -351,29 +349,31 @@ class Auth(object):
|
||||
return self.store.is_server_admin(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_auth_events(self, event):
|
||||
if event.type == RoomCreateEvent.TYPE:
|
||||
event.auth_events = []
|
||||
def add_auth_events(self, builder, context):
|
||||
yield run_on_reactor()
|
||||
|
||||
if builder.type == EventTypes.Create:
|
||||
builder.auth_events = []
|
||||
return
|
||||
|
||||
auth_events = []
|
||||
auth_ids = []
|
||||
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
power_level_event = event.old_state_events.get(key)
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = context.current_state.get(key)
|
||||
|
||||
if power_level_event:
|
||||
auth_events.append(power_level_event.event_id)
|
||||
auth_ids.append(power_level_event.event_id)
|
||||
|
||||
key = (RoomJoinRulesEvent.TYPE, "", )
|
||||
join_rule_event = event.old_state_events.get(key)
|
||||
key = (EventTypes.JoinRules, "", )
|
||||
join_rule_event = context.current_state.get(key)
|
||||
|
||||
key = (RoomMemberEvent.TYPE, event.user_id, )
|
||||
member_event = event.old_state_events.get(key)
|
||||
key = (EventTypes.Member, builder.user_id, )
|
||||
member_event = context.current_state.get(key)
|
||||
|
||||
key = (RoomCreateEvent.TYPE, "", )
|
||||
create_event = event.old_state_events.get(key)
|
||||
key = (EventTypes.Create, "", )
|
||||
create_event = context.current_state.get(key)
|
||||
if create_event:
|
||||
auth_events.append(create_event.event_id)
|
||||
auth_ids.append(create_event.event_id)
|
||||
|
||||
if join_rule_event:
|
||||
join_rule = join_rule_event.content.get("join_rule")
|
||||
@ -381,33 +381,37 @@ class Auth(object):
|
||||
else:
|
||||
is_public = False
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
e_type = event.content["membership"]
|
||||
if builder.type == EventTypes.Member:
|
||||
e_type = builder.content["membership"]
|
||||
if e_type in [Membership.JOIN, Membership.INVITE]:
|
||||
if join_rule_event:
|
||||
auth_events.append(join_rule_event.event_id)
|
||||
auth_ids.append(join_rule_event.event_id)
|
||||
|
||||
if e_type == Membership.JOIN:
|
||||
if member_event and not is_public:
|
||||
auth_events.append(member_event.event_id)
|
||||
auth_ids.append(member_event.event_id)
|
||||
else:
|
||||
if member_event:
|
||||
auth_ids.append(member_event.event_id)
|
||||
elif member_event:
|
||||
if member_event.content["membership"] == Membership.JOIN:
|
||||
auth_events.append(member_event.event_id)
|
||||
auth_ids.append(member_event.event_id)
|
||||
|
||||
hashes = yield self.store.get_event_reference_hashes(
|
||||
auth_events
|
||||
auth_events_entries = yield self.store.add_event_hashes(
|
||||
auth_ids
|
||||
)
|
||||
hashes = [
|
||||
{
|
||||
k: encode_base64(v) for k, v in h.items()
|
||||
if k == "sha256"
|
||||
}
|
||||
for h in hashes
|
||||
]
|
||||
event.auth_events = zip(auth_events, hashes)
|
||||
|
||||
builder.auth_events = auth_events_entries
|
||||
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
|
||||
@log_function
|
||||
def _can_send_event(self, event, auth_events):
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
send_level_event = auth_events.get(key)
|
||||
send_level = None
|
||||
if send_level_event:
|
||||
|
@ -59,3 +59,18 @@ class LoginType(object):
|
||||
EMAIL_URL = u"m.login.email.url"
|
||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
|
||||
|
||||
class EventTypes(object):
|
||||
Member = "m.room.member"
|
||||
Create = "m.room.create"
|
||||
JoinRules = "m.room.join_rules"
|
||||
PowerLevels = "m.room.power_levels"
|
||||
Aliases = "m.room.aliases"
|
||||
Redaction = "m.room.redaction"
|
||||
Feedback = "m.room.message.feedback"
|
||||
|
||||
# These are used for validation
|
||||
Message = "m.room.message"
|
||||
Topic = "m.room.topic"
|
||||
Name = "m.room.name"
|
||||
|
@ -34,6 +34,7 @@ class Codes(object):
|
||||
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||
TOO_LARGE = "M_TOO_LARGE"
|
||||
|
||||
|
||||
class CodeMessageException(Exception):
|
||||
|
@ -1,148 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.jsonobject import JsonEncodedObject
|
||||
|
||||
|
||||
def serialize_event(hs, e):
|
||||
# FIXME(erikj): To handle the case of presence events and the like
|
||||
if not isinstance(e, SynapseEvent):
|
||||
return e
|
||||
|
||||
# Should this strip out None's?
|
||||
d = {k: v for k, v in e.get_dict().items()}
|
||||
if "age_ts" in d:
|
||||
d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"]
|
||||
del d["age_ts"]
|
||||
|
||||
return d
|
||||
|
||||
|
||||
class SynapseEvent(JsonEncodedObject):
|
||||
|
||||
"""Base class for Synapse events. These are JSON objects which must abide
|
||||
by a certain well-defined structure.
|
||||
"""
|
||||
|
||||
# Attributes that are currently assumed by the federation side:
|
||||
# Mandatory:
|
||||
# - event_id
|
||||
# - room_id
|
||||
# - type
|
||||
# - is_state
|
||||
#
|
||||
# Optional:
|
||||
# - state_key (mandatory when is_state is True)
|
||||
# - prev_events (these can be filled out by the federation layer itself.)
|
||||
# - prev_state
|
||||
|
||||
valid_keys = [
|
||||
"event_id",
|
||||
"type",
|
||||
"room_id",
|
||||
"user_id", # sender/initiator
|
||||
"content", # HTTP body, JSON
|
||||
"state_key",
|
||||
"age_ts",
|
||||
"prev_content",
|
||||
"replaces_state",
|
||||
"redacted_because",
|
||||
"origin_server_ts",
|
||||
]
|
||||
|
||||
internal_keys = [
|
||||
"is_state",
|
||||
"depth",
|
||||
"destinations",
|
||||
"origin",
|
||||
"outlier",
|
||||
"redacted",
|
||||
"prev_events",
|
||||
"hashes",
|
||||
"signatures",
|
||||
"prev_state",
|
||||
"auth_events",
|
||||
"state_hash",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
"event_id",
|
||||
"room_id",
|
||||
"content",
|
||||
]
|
||||
|
||||
outlier = False
|
||||
|
||||
def __init__(self, raises=True, **kwargs):
|
||||
super(SynapseEvent, self).__init__(**kwargs)
|
||||
# if "content" in kwargs:
|
||||
# self.check_json(self.content, raises=raises)
|
||||
|
||||
def get_content_template(self):
|
||||
""" Retrieve the JSON template for this event as a dict.
|
||||
|
||||
The template must be a dict representing the JSON to match. Only
|
||||
required keys should be present. The values of the keys in the template
|
||||
are checked via type() to the values of the same keys in the actual
|
||||
event JSON.
|
||||
|
||||
NB: If loading content via json.loads, you MUST define strings as
|
||||
unicode.
|
||||
|
||||
For example:
|
||||
Content:
|
||||
{
|
||||
"name": u"bob",
|
||||
"age": 18,
|
||||
"friends": [u"mike", u"jill"]
|
||||
}
|
||||
Template:
|
||||
{
|
||||
"name": u"string",
|
||||
"age": 0,
|
||||
"friends": [u"string"]
|
||||
}
|
||||
The values "string" and 0 could be anything, so long as the types
|
||||
are the same as the content.
|
||||
"""
|
||||
raise NotImplementedError("get_content_template not implemented.")
|
||||
|
||||
def get_pdu_json(self, time_now=None):
|
||||
pdu_json = self.get_full_dict()
|
||||
pdu_json.pop("destinations", None)
|
||||
pdu_json.pop("outlier", None)
|
||||
pdu_json.pop("replaces_state", None)
|
||||
pdu_json.pop("redacted", None)
|
||||
pdu_json.pop("prev_content", None)
|
||||
state_hash = pdu_json.pop("state_hash", None)
|
||||
if state_hash is not None:
|
||||
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
|
||||
content = pdu_json.get("content", {})
|
||||
content.pop("prev", None)
|
||||
if time_now is not None and "age_ts" in pdu_json:
|
||||
age = time_now - pdu_json["age_ts"]
|
||||
pdu_json.setdefault("unsigned", {})["age"] = int(age)
|
||||
del pdu_json["age_ts"]
|
||||
user_id = pdu_json.pop("user_id")
|
||||
pdu_json["sender"] = user_id
|
||||
return pdu_json
|
||||
|
||||
|
||||
class SynapseStateEvent(SynapseEvent):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "state_key" not in kwargs:
|
||||
kwargs["state_key"] = ""
|
||||
super(SynapseStateEvent, self).__init__(**kwargs)
|
@ -1,90 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.events.room import (
|
||||
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
|
||||
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
|
||||
RoomPowerLevelsEvent, RoomJoinRulesEvent,
|
||||
RoomCreateEvent,
|
||||
RoomRedactionEvent,
|
||||
)
|
||||
|
||||
from synapse.types import EventID
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
||||
class EventFactory(object):
|
||||
|
||||
_event_classes = [
|
||||
RoomTopicEvent,
|
||||
RoomNameEvent,
|
||||
MessageEvent,
|
||||
RoomMemberEvent,
|
||||
FeedbackEvent,
|
||||
InviteJoinEvent,
|
||||
RoomConfigEvent,
|
||||
RoomPowerLevelsEvent,
|
||||
RoomJoinRulesEvent,
|
||||
RoomCreateEvent,
|
||||
RoomRedactionEvent,
|
||||
]
|
||||
|
||||
def __init__(self, hs):
|
||||
self._event_list = {} # dict of TYPE to event class
|
||||
for event_class in EventFactory._event_classes:
|
||||
self._event_list[event_class.TYPE] = event_class
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
||||
self.event_id_count = 0
|
||||
|
||||
def create_event_id(self):
|
||||
i = str(self.event_id_count)
|
||||
self.event_id_count += 1
|
||||
|
||||
local_part = str(int(self.clock.time())) + i + random_string(5)
|
||||
|
||||
e_id = EventID.create_local(local_part, self.hs)
|
||||
|
||||
return e_id.to_string()
|
||||
|
||||
def create_event(self, etype=None, **kwargs):
|
||||
kwargs["type"] = etype
|
||||
if "event_id" not in kwargs:
|
||||
kwargs["event_id"] = self.create_event_id()
|
||||
kwargs["origin"] = self.hs.hostname
|
||||
else:
|
||||
ev_id = self.hs.parse_eventid(kwargs["event_id"])
|
||||
kwargs["origin"] = ev_id.domain
|
||||
|
||||
if "origin_server_ts" not in kwargs:
|
||||
kwargs["origin_server_ts"] = int(self.clock.time_msec())
|
||||
|
||||
# The "age" key is a delta timestamp that should be converted into an
|
||||
# absolute timestamp the minute we see it.
|
||||
if "age" in kwargs:
|
||||
kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"])
|
||||
del kwargs["age"]
|
||||
elif "age_ts" not in kwargs:
|
||||
kwargs["age_ts"] = int(self.clock.time_msec())
|
||||
|
||||
if etype in self._event_list:
|
||||
handler = self._event_list[etype]
|
||||
else:
|
||||
handler = GenericEvent
|
||||
|
||||
return handler(**kwargs)
|
@ -1,170 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.constants import Feedback, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from . import SynapseEvent, SynapseStateEvent
|
||||
|
||||
|
||||
class GenericEvent(SynapseEvent):
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomTopicEvent(SynapseEvent):
|
||||
TYPE = "m.room.topic"
|
||||
|
||||
internal_keys = SynapseEvent.internal_keys + [
|
||||
"topic",
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["state_key"] = ""
|
||||
if "topic" in kwargs["content"]:
|
||||
kwargs["topic"] = kwargs["content"]["topic"]
|
||||
super(RoomTopicEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"topic": u"string"}
|
||||
|
||||
|
||||
class RoomNameEvent(SynapseEvent):
|
||||
TYPE = "m.room.name"
|
||||
|
||||
internal_keys = SynapseEvent.internal_keys + [
|
||||
"name",
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["state_key"] = ""
|
||||
if "name" in kwargs["content"]:
|
||||
kwargs["name"] = kwargs["content"]["name"]
|
||||
super(RoomNameEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"name": u"string"}
|
||||
|
||||
|
||||
class RoomMemberEvent(SynapseEvent):
|
||||
TYPE = "m.room.member"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
# target is the state_key
|
||||
"membership", # action
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
if "membership" not in kwargs:
|
||||
kwargs["membership"] = kwargs.get("content", {}).get("membership")
|
||||
if not kwargs["membership"] in Membership.LIST:
|
||||
raise SynapseError(400, "Bad membership value.")
|
||||
super(RoomMemberEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"membership": u"string"}
|
||||
|
||||
|
||||
class MessageEvent(SynapseEvent):
|
||||
TYPE = "m.room.message"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
"msg_id", # unique per room + user combo
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(MessageEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {"msgtype": u"string"}
|
||||
|
||||
|
||||
class FeedbackEvent(SynapseEvent):
|
||||
TYPE = "m.room.message.feedback"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(FeedbackEvent, self).__init__(**kwargs)
|
||||
if not kwargs["content"]["type"] in Feedback.LIST:
|
||||
raise SynapseError(400, "Bad feedback value.")
|
||||
|
||||
def get_content_template(self):
|
||||
return {
|
||||
"type": u"string",
|
||||
"target_event_id": u"string"
|
||||
}
|
||||
|
||||
|
||||
class InviteJoinEvent(SynapseEvent):
|
||||
TYPE = "m.room.invite_join"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + [
|
||||
# target_user_id is the state_key
|
||||
"target_host",
|
||||
]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(InviteJoinEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomConfigEvent(SynapseEvent):
|
||||
TYPE = "m.room.config"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["state_key"] = ""
|
||||
super(RoomConfigEvent, self).__init__(**kwargs)
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomCreateEvent(SynapseStateEvent):
|
||||
TYPE = "m.room.create"
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomJoinRulesEvent(SynapseStateEvent):
|
||||
TYPE = "m.room.join_rules"
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomPowerLevelsEvent(SynapseStateEvent):
|
||||
TYPE = "m.room.power_levels"
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomAliasesEvent(SynapseStateEvent):
|
||||
TYPE = "m.room.aliases"
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
||||
|
||||
|
||||
class RoomRedactionEvent(SynapseEvent):
|
||||
TYPE = "m.room.redaction"
|
||||
|
||||
valid_keys = SynapseEvent.valid_keys + ["redacts"]
|
||||
|
||||
def get_content_template(self):
|
||||
return {}
|
@ -1,87 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
|
||||
class EventValidator(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
def validate(self, event):
|
||||
"""Checks the given JSON content abides by the rules of the template.
|
||||
|
||||
Args:
|
||||
content : A JSON object to check.
|
||||
raises: True to raise a SynapseError if the check fails.
|
||||
Returns:
|
||||
True if the content passes the template. Returns False if the check
|
||||
fails and raises=False.
|
||||
Raises:
|
||||
SynapseError if the check fails and raises=True.
|
||||
"""
|
||||
# recursively call to inspect each layer
|
||||
err_msg = self._check_json_template(
|
||||
event.content,
|
||||
event.get_content_template()
|
||||
)
|
||||
if err_msg:
|
||||
raise SynapseError(400, err_msg, Codes.BAD_JSON)
|
||||
else:
|
||||
return True
|
||||
|
||||
def _check_json_template(self, content, template):
|
||||
"""Check content and template matches.
|
||||
|
||||
If the template is a dict, each key in the dict will be validated with
|
||||
the content, else it will just compare the types of content and
|
||||
template. This basic type check is required because this function will
|
||||
be recursively called and could be called with just strs or ints.
|
||||
|
||||
Args:
|
||||
content: The content to validate.
|
||||
template: The validation template.
|
||||
Returns:
|
||||
str: An error message if the validation fails, else None.
|
||||
"""
|
||||
if type(content) != type(template):
|
||||
return "Mismatched types: %s" % template
|
||||
|
||||
if type(template) == dict:
|
||||
for key in template:
|
||||
if key not in content:
|
||||
return "Missing %s key" % key
|
||||
|
||||
if type(content[key]) != type(template[key]):
|
||||
return "Key %s is of the wrong type (got %s, want %s)" % (
|
||||
key, type(content[key]), type(template[key]))
|
||||
|
||||
if type(content[key]) == dict:
|
||||
# we must go deeper
|
||||
msg = self._check_json_template(
|
||||
content[key],
|
||||
template[key]
|
||||
)
|
||||
if msg:
|
||||
return msg
|
||||
elif type(content[key]) == list:
|
||||
# make sure each item type in content matches the template
|
||||
for entry in content[key]:
|
||||
msg = self._check_json_template(
|
||||
entry,
|
||||
template[key][0]
|
||||
)
|
||||
if msg:
|
||||
return msg
|
@ -20,3 +20,4 @@ FEDERATION_PREFIX = "/_matrix/federation/v1"
|
||||
WEB_CLIENT_PREFIX = "/_matrix/client"
|
||||
CONTENT_REPO_PREFIX = "/_matrix/content"
|
||||
SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||
MEDIA_PREFIX = "/_matrix/media/v1"
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import prepare_database
|
||||
from synapse.storage import prepare_database, UpgradeDatabaseException
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@ -24,12 +24,13 @@ from twisted.web.resource import Resource
|
||||
from twisted.web.static import File
|
||||
from twisted.web.server import Site
|
||||
from synapse.http.server import JsonResource, RootRedirect
|
||||
from synapse.http.content_repository import ContentRepoResource
|
||||
from synapse.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.http.server_key_resource import LocalKey
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.api.urls import (
|
||||
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||
SERVER_KEY_PREFIX,
|
||||
SERVER_KEY_PREFIX, MEDIA_PREFIX
|
||||
)
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
@ -69,6 +70,9 @@ class SynapseHomeServer(HomeServer):
|
||||
self, self.upload_dir, self.auth, self.content_addr
|
||||
)
|
||||
|
||||
def build_resource_for_media_repository(self):
|
||||
return MediaRepositoryResource(self)
|
||||
|
||||
def build_resource_for_server_key(self):
|
||||
return LocalKey(self)
|
||||
|
||||
@ -99,6 +103,7 @@ class SynapseHomeServer(HomeServer):
|
||||
(FEDERATION_PREFIX, self.get_resource_for_federation()),
|
||||
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
|
||||
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
|
||||
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
|
||||
]
|
||||
if web_client:
|
||||
logger.info("Adding the web client.")
|
||||
@ -223,8 +228,15 @@ def setup():
|
||||
|
||||
logger.info("Preparing database: %s...", db_name)
|
||||
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_database(db_conn)
|
||||
try:
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_database(db_conn)
|
||||
except UpgradeDatabaseException:
|
||||
sys.stderr.write(
|
||||
"\nFailed to upgrade database.\n"
|
||||
"Have you followed any instructions in UPGRADES.rst?\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Database prepared in %s.", db_name)
|
||||
|
||||
|
@ -50,12 +50,26 @@ class Config(object):
|
||||
)
|
||||
return cls.abspath(file_path)
|
||||
|
||||
@staticmethod
|
||||
def ensure_directory(dir_path):
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
if not os.path.isdir(dir_path):
|
||||
raise ConfigError(
|
||||
"%s is not a directory" % (dir_path,)
|
||||
)
|
||||
return dir_path
|
||||
|
||||
@classmethod
|
||||
def read_file(cls, file_path, config_name):
|
||||
cls.check_file(file_path, config_name)
|
||||
with open(file_path) as file_stream:
|
||||
return file_stream.read()
|
||||
|
||||
@staticmethod
|
||||
def default_path(name):
|
||||
return os.path.abspath(os.path.join(os.path.curdir, name))
|
||||
|
||||
@staticmethod
|
||||
def read_config_file(file_path):
|
||||
with open(file_path) as file_stream:
|
||||
|
@ -36,7 +36,7 @@ class LoggingConfig(Config):
|
||||
help="The verbosity level."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'-f', '--log-file', dest="log_file", default=None,
|
||||
'-f', '--log-file', dest="log_file", default="homeserver.log",
|
||||
help="File to log to."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
|
@ -20,6 +20,8 @@ class ContentRepositoryConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(ContentRepositoryConfig, self).__init__(args)
|
||||
self.max_upload_size = self.parse_size(args.max_upload_size)
|
||||
self.max_image_pixels = self.parse_size(args.max_image_pixels)
|
||||
self.media_store_path = self.ensure_directory(args.media_store_path)
|
||||
|
||||
def parse_size(self, string):
|
||||
sizes = {"K": 1024, "M": 1024 * 1024}
|
||||
@ -37,3 +39,10 @@ class ContentRepositoryConfig(Config):
|
||||
db_group.add_argument(
|
||||
"--max-upload-size", default="1M"
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--media-store-path", default=cls.default_path("media_store")
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--max-image-pixels", default="32M",
|
||||
help="Maximum number of pixels that will be thumbnailed"
|
||||
)
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.api.events.utils import prune_event
|
||||
from synapse.events.utils import prune_event
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
from syutil.base64util import encode_base64, decode_base64
|
||||
from syutil.crypto.jsonsign import sign_json
|
||||
@ -29,17 +29,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
||||
"""Check whether the hash for this PDU matches the contents"""
|
||||
computed_hash = _compute_content_hash(event, hash_algorithm)
|
||||
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
|
||||
if computed_hash.name not in event.hashes:
|
||||
name, expected_hash = compute_content_hash(event, hash_algorithm)
|
||||
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
|
||||
if name not in event.hashes:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Algorithm %s not in hashes %s" % (
|
||||
computed_hash.name, list(event.hashes),
|
||||
name, list(event.hashes),
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
message_hash_base64 = event.hashes[computed_hash.name]
|
||||
message_hash_base64 = event.hashes[name]
|
||||
try:
|
||||
message_hash_bytes = decode_base64(message_hash_base64)
|
||||
except:
|
||||
@ -48,10 +48,10 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
||||
"Invalid base64: %s" % (message_hash_base64,),
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
return message_hash_bytes == computed_hash.digest()
|
||||
return message_hash_bytes == expected_hash
|
||||
|
||||
|
||||
def _compute_content_hash(event, hash_algorithm):
|
||||
def compute_content_hash(event, hash_algorithm):
|
||||
event_json = event.get_pdu_json()
|
||||
event_json.pop("age_ts", None)
|
||||
event_json.pop("unsigned", None)
|
||||
@ -59,8 +59,11 @@ def _compute_content_hash(event, hash_algorithm):
|
||||
event_json.pop("hashes", None)
|
||||
event_json.pop("outlier", None)
|
||||
event_json.pop("destinations", None)
|
||||
|
||||
event_json_bytes = encode_canonical_json(event_json)
|
||||
return hash_algorithm(event_json_bytes)
|
||||
|
||||
hashed = hash_algorithm(event_json_bytes)
|
||||
return (hashed.name, hashed.digest())
|
||||
|
||||
|
||||
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
|
||||
@ -79,27 +82,28 @@ def compute_event_signature(event, signature_name, signing_key):
|
||||
redact_json = tmp_event.get_pdu_json()
|
||||
redact_json.pop("age_ts", None)
|
||||
redact_json.pop("unsigned", None)
|
||||
logger.debug("Signing event: %s", redact_json)
|
||||
logger.debug("Signing event: %s", encode_canonical_json(redact_json))
|
||||
redact_json = sign_json(redact_json, signature_name, signing_key)
|
||||
logger.debug("Signed event: %s", encode_canonical_json(redact_json))
|
||||
return redact_json["signatures"]
|
||||
|
||||
|
||||
def add_hashes_and_signatures(event, signature_name, signing_key,
|
||||
hash_algorithm=hashlib.sha256):
|
||||
if hasattr(event, "old_state_events"):
|
||||
state_json_bytes = encode_canonical_json(
|
||||
[e.event_id for e in event.old_state_events.values()]
|
||||
)
|
||||
hashed = hash_algorithm(state_json_bytes)
|
||||
event.state_hash = {
|
||||
hashed.name: encode_base64(hashed.digest())
|
||||
}
|
||||
# if hasattr(event, "old_state_events"):
|
||||
# state_json_bytes = encode_canonical_json(
|
||||
# [e.event_id for e in event.old_state_events.values()]
|
||||
# )
|
||||
# hashed = hash_algorithm(state_json_bytes)
|
||||
# event.state_hash = {
|
||||
# hashed.name: encode_base64(hashed.digest())
|
||||
# }
|
||||
|
||||
hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
|
||||
name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
|
||||
|
||||
if not hasattr(event, "hashes"):
|
||||
event.hashes = {}
|
||||
event.hashes[hashed.name] = encode_base64(hashed.digest())
|
||||
event.hashes[name] = encode_base64(digest)
|
||||
|
||||
event.signatures = compute_event_signature(
|
||||
event,
|
||||
|
149
synapse/events/__init__.py
Normal file
149
synapse/events/__init__.py
Normal file
@ -0,0 +1,149 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.frozenutils import freeze, unfreeze
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
class _EventInternalMetadata(object):
|
||||
def __init__(self, internal_metadata_dict):
|
||||
self.__dict__ = copy.deepcopy(internal_metadata_dict)
|
||||
|
||||
def get_dict(self):
|
||||
return dict(self.__dict__)
|
||||
|
||||
def is_outlier(self):
|
||||
return hasattr(self, "outlier") and self.outlier
|
||||
|
||||
|
||||
def _event_dict_property(key):
|
||||
def getter(self):
|
||||
return self._event_dict[key]
|
||||
|
||||
def setter(self, v):
|
||||
self._event_dict[key] = v
|
||||
|
||||
def delete(self):
|
||||
del self._event_dict[key]
|
||||
|
||||
return property(
|
||||
getter,
|
||||
setter,
|
||||
delete,
|
||||
)
|
||||
|
||||
|
||||
class EventBase(object):
|
||||
def __init__(self, event_dict, signatures={}, unsigned={},
|
||||
internal_metadata_dict={}):
|
||||
self.signatures = copy.deepcopy(signatures)
|
||||
self.unsigned = copy.deepcopy(unsigned)
|
||||
|
||||
self._event_dict = copy.deepcopy(event_dict)
|
||||
|
||||
self.internal_metadata = _EventInternalMetadata(
|
||||
internal_metadata_dict
|
||||
)
|
||||
|
||||
auth_events = _event_dict_property("auth_events")
|
||||
depth = _event_dict_property("depth")
|
||||
content = _event_dict_property("content")
|
||||
event_id = _event_dict_property("event_id")
|
||||
hashes = _event_dict_property("hashes")
|
||||
origin = _event_dict_property("origin")
|
||||
origin_server_ts = _event_dict_property("origin_server_ts")
|
||||
prev_events = _event_dict_property("prev_events")
|
||||
prev_state = _event_dict_property("prev_state")
|
||||
redacts = _event_dict_property("redacts")
|
||||
room_id = _event_dict_property("room_id")
|
||||
sender = _event_dict_property("sender")
|
||||
state_key = _event_dict_property("state_key")
|
||||
type = _event_dict_property("type")
|
||||
user_id = _event_dict_property("sender")
|
||||
|
||||
@property
|
||||
def membership(self):
|
||||
return self.content["membership"]
|
||||
|
||||
def is_state(self):
|
||||
return hasattr(self, "state_key")
|
||||
|
||||
def get_dict(self):
|
||||
d = dict(self._event_dict)
|
||||
d.update({
|
||||
"signatures": self.signatures,
|
||||
"unsigned": self.unsigned,
|
||||
})
|
||||
|
||||
return d
|
||||
|
||||
def get(self, key, default):
|
||||
return self._event_dict.get(key, default)
|
||||
|
||||
def get_internal_metadata_dict(self):
|
||||
return self.internal_metadata.get_dict()
|
||||
|
||||
def get_pdu_json(self, time_now=None):
|
||||
pdu_json = self.get_dict()
|
||||
|
||||
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
|
||||
age = time_now - pdu_json["unsigned"]["age_ts"]
|
||||
pdu_json.setdefault("unsigned", {})["age"] = int(age)
|
||||
del pdu_json["unsigned"]["age_ts"]
|
||||
|
||||
return pdu_json
|
||||
|
||||
def __set__(self, instance, value):
|
||||
raise AttributeError("Unrecognized attribute %s" % (instance,))
|
||||
|
||||
|
||||
class FrozenEvent(EventBase):
|
||||
def __init__(self, event_dict, internal_metadata_dict={}):
|
||||
event_dict = copy.deepcopy(event_dict)
|
||||
|
||||
signatures = copy.deepcopy(event_dict.pop("signatures", {}))
|
||||
unsigned = copy.deepcopy(event_dict.pop("unsigned", {}))
|
||||
|
||||
frozen_dict = freeze(event_dict)
|
||||
|
||||
super(FrozenEvent, self).__init__(
|
||||
frozen_dict,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_event(event):
|
||||
e = FrozenEvent(
|
||||
event.get_pdu_json()
|
||||
)
|
||||
|
||||
e.internal_metadata = event.internal_metadata
|
||||
|
||||
return e
|
||||
|
||||
def get_dict(self):
|
||||
# We need to unfreeze what we return
|
||||
return unfreeze(super(FrozenEvent, self).get_dict())
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
|
||||
self.event_id, self.type, self.get("state_key", None),
|
||||
)
|
77
synapse/events/builder.py
Normal file
77
synapse/events/builder.py
Normal file
@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import EventBase, FrozenEvent
|
||||
|
||||
from synapse.types import EventID
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
class EventBuilder(EventBase):
|
||||
def __init__(self, key_values={}):
|
||||
signatures = copy.deepcopy(key_values.pop("signatures", {}))
|
||||
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
|
||||
|
||||
super(EventBuilder, self).__init__(
|
||||
key_values,
|
||||
signatures=signatures,
|
||||
unsigned=unsigned
|
||||
)
|
||||
|
||||
def update_event_key(self, key, value):
|
||||
self._event_dict[key] = value
|
||||
|
||||
def update_event_keys(self, other_dict):
|
||||
self._event_dict.update(other_dict)
|
||||
|
||||
def build(self):
|
||||
return FrozenEvent.from_event(self)
|
||||
|
||||
|
||||
class EventBuilderFactory(object):
|
||||
def __init__(self, clock, hostname):
|
||||
self.clock = clock
|
||||
self.hostname = hostname
|
||||
|
||||
self.event_id_count = 0
|
||||
|
||||
def create_event_id(self):
|
||||
i = str(self.event_id_count)
|
||||
self.event_id_count += 1
|
||||
|
||||
local_part = str(int(self.clock.time())) + i + random_string(5)
|
||||
|
||||
e_id = EventID.create(local_part, self.hostname)
|
||||
|
||||
return e_id.to_string()
|
||||
|
||||
def new(self, key_values={}):
|
||||
key_values["event_id"] = self.create_event_id()
|
||||
|
||||
time_now = int(self.clock.time_msec())
|
||||
|
||||
key_values.setdefault("origin", self.hostname)
|
||||
key_values.setdefault("origin_server_ts", time_now)
|
||||
|
||||
key_values.setdefault("unsigned", {})
|
||||
age = key_values["unsigned"].pop("age", 0)
|
||||
key_values["unsigned"].setdefault("age_ts", time_now - age)
|
||||
|
||||
key_values["signatures"] = {}
|
||||
|
||||
return EventBuilder(key_values=key_values,)
|
@ -13,3 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class EventContext(object):
|
||||
|
||||
def __init__(self, current_state=None, auth_events=None):
|
||||
self.current_state = current_state
|
||||
self.auth_events = auth_events
|
||||
self.state_group = None
|
@ -13,10 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .room import (
|
||||
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
|
||||
RoomAliasesEvent, RoomCreateEvent,
|
||||
)
|
||||
from synapse.api.constants import EventTypes
|
||||
from . import EventBase
|
||||
|
||||
|
||||
def prune_event(event):
|
||||
@ -31,7 +29,7 @@ def prune_event(event):
|
||||
|
||||
allowed_keys = [
|
||||
"event_id",
|
||||
"user_id",
|
||||
"sender",
|
||||
"room_id",
|
||||
"hashes",
|
||||
"signatures",
|
||||
@ -44,6 +42,7 @@ def prune_event(event):
|
||||
"auth_events",
|
||||
"origin",
|
||||
"origin_server_ts",
|
||||
"membership",
|
||||
]
|
||||
|
||||
new_content = {}
|
||||
@ -53,13 +52,13 @@ def prune_event(event):
|
||||
if field in event.content:
|
||||
new_content[field] = event.content[field]
|
||||
|
||||
if event_type == RoomMemberEvent.TYPE:
|
||||
if event_type == EventTypes.Member:
|
||||
add_fields("membership")
|
||||
elif event_type == RoomCreateEvent.TYPE:
|
||||
elif event_type == EventTypes.Create:
|
||||
add_fields("creator")
|
||||
elif event_type == RoomJoinRulesEvent.TYPE:
|
||||
elif event_type == EventTypes.JoinRules:
|
||||
add_fields("join_rule")
|
||||
elif event_type == RoomPowerLevelsEvent.TYPE:
|
||||
elif event_type == EventTypes.PowerLevels:
|
||||
add_fields(
|
||||
"users",
|
||||
"users_default",
|
||||
@ -71,15 +70,64 @@ def prune_event(event):
|
||||
"kick",
|
||||
"redact",
|
||||
)
|
||||
elif event_type == RoomAliasesEvent.TYPE:
|
||||
elif event_type == EventTypes.Aliases:
|
||||
add_fields("aliases")
|
||||
|
||||
allowed_fields = {
|
||||
k: v
|
||||
for k, v in event.get_full_dict().items()
|
||||
for k, v in event.get_dict().items()
|
||||
if k in allowed_keys
|
||||
}
|
||||
|
||||
allowed_fields["content"] = new_content
|
||||
|
||||
return type(event)(**allowed_fields)
|
||||
allowed_fields["unsigned"] = {}
|
||||
|
||||
if "age_ts" in event.unsigned:
|
||||
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
|
||||
|
||||
return type(event)(allowed_fields)
|
||||
|
||||
|
||||
def serialize_event(hs, e):
|
||||
# FIXME(erikj): To handle the case of presence events and the like
|
||||
if not isinstance(e, EventBase):
|
||||
return e
|
||||
|
||||
# Should this strip out None's?
|
||||
d = {k: v for k, v in e.get_dict().items()}
|
||||
if "age_ts" in d["unsigned"]:
|
||||
now = int(hs.get_clock().time_msec())
|
||||
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
|
||||
del d["unsigned"]["age_ts"]
|
||||
|
||||
d["user_id"] = d.pop("sender", None)
|
||||
|
||||
if "redacted_because" in e.unsigned:
|
||||
d["redacted_because"] = serialize_event(
|
||||
hs, e.unsigned["redacted_because"]
|
||||
)
|
||||
|
||||
del d["unsigned"]["redacted_because"]
|
||||
|
||||
if "redacted_by" in e.unsigned:
|
||||
d["redacted_by"] = e.unsigned["redacted_by"]
|
||||
del d["unsigned"]["redacted_by"]
|
||||
|
||||
if "replaces_state" in e.unsigned:
|
||||
d["replaces_state"] = e.unsigned["replaces_state"]
|
||||
del d["unsigned"]["replaces_state"]
|
||||
|
||||
if "prev_content" in e.unsigned:
|
||||
d["prev_content"] = e.unsigned["prev_content"]
|
||||
del d["unsigned"]["prev_content"]
|
||||
|
||||
del d["auth_events"]
|
||||
del d["prev_events"]
|
||||
del d["hashes"]
|
||||
del d["signatures"]
|
||||
d.pop("depth", None)
|
||||
d.pop("unsigned", None)
|
||||
d.pop("origin", None)
|
||||
|
||||
return d
|
92
synapse/events/validator.py
Normal file
92
synapse/events/validator.py
Normal file
@ -0,0 +1,92 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.types import EventID, RoomID, UserID
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
|
||||
class EventValidator(object):
|
||||
|
||||
def validate(self, event):
|
||||
EventID.from_string(event.event_id)
|
||||
RoomID.from_string(event.room_id)
|
||||
|
||||
required = [
|
||||
# "auth_events",
|
||||
"content",
|
||||
# "hashes",
|
||||
"origin",
|
||||
# "prev_events",
|
||||
"sender",
|
||||
"type",
|
||||
]
|
||||
|
||||
for k in required:
|
||||
if not hasattr(event, k):
|
||||
raise SynapseError(400, "Event does not have key %s" % (k,))
|
||||
|
||||
# Check that the following keys have string values
|
||||
strings = [
|
||||
"origin",
|
||||
"sender",
|
||||
"type",
|
||||
]
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
strings.append("state_key")
|
||||
|
||||
for s in strings:
|
||||
if not isinstance(getattr(event, s), basestring):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if "membership" not in event.content:
|
||||
raise SynapseError(400, "Content has not membership key")
|
||||
|
||||
if event.content["membership"] not in Membership.LIST:
|
||||
raise SynapseError(400, "Invalid membership key")
|
||||
|
||||
# Check that the following keys have dictionary values
|
||||
# TODO
|
||||
|
||||
# Check that the following keys have the correct format for DAGs
|
||||
# TODO
|
||||
|
||||
def validate_new(self, event):
|
||||
self.validate(event)
|
||||
|
||||
UserID.from_string(event.sender)
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
strings = [
|
||||
"body",
|
||||
"msgtype",
|
||||
]
|
||||
|
||||
self._ensure_strings(event.content, strings)
|
||||
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._ensure_strings(event.content, ["topic"])
|
||||
|
||||
elif event.type == EventTypes.Name:
|
||||
self._ensure_strings(event.content, ["name"])
|
||||
|
||||
def _ensure_strings(self, d, keys):
|
||||
for s in keys:
|
||||
if s not in d:
|
||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||
if not isinstance(d[s], basestring):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
@ -25,6 +25,7 @@ from .persistence import TransactionActions
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.events import FrozenEvent
|
||||
|
||||
import logging
|
||||
|
||||
@ -73,7 +74,7 @@ class ReplicationLayer(object):
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
def set_handler(self, handler):
|
||||
"""Sets the handler that the replication layer will use to communicate
|
||||
@ -112,7 +113,7 @@ class ReplicationLayer(object):
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
@log_function
|
||||
def send_pdu(self, pdu):
|
||||
def send_pdu(self, pdu, destinations):
|
||||
"""Informs the replication layer about a new PDU generated within the
|
||||
home server that should be transmitted to others.
|
||||
|
||||
@ -131,7 +132,7 @@ class ReplicationLayer(object):
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
|
||||
|
||||
# TODO, add errback, etc.
|
||||
self._transaction_queue.enqueue_pdu(pdu, order)
|
||||
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
|
||||
|
||||
logger.debug(
|
||||
"[%s] transaction_layer.enqueue_pdu... done",
|
||||
@ -255,31 +256,35 @@ class ReplicationLayer(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_state_for_context(self, destination, context, event_id=None):
|
||||
def get_state_for_context(self, destination, context, event_id):
|
||||
"""Requests all of the `current` state PDUs for a given context from
|
||||
a remote home server.
|
||||
|
||||
Args:
|
||||
destination (str): The remote homeserver to query for the state.
|
||||
context (str): The context we're interested in.
|
||||
event_id (str): The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
transaction_data = yield self.transport_layer.get_context_state(
|
||||
result = yield self.transport_layer.get_context_state(
|
||||
destination,
|
||||
context,
|
||||
event_id=event_id,
|
||||
)
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
pdus = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in transaction.pdus
|
||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
]
|
||||
|
||||
defer.returnValue(pdus)
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
defer.returnValue((pdus, auth_chain))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@ -334,7 +339,7 @@ class ReplicationLayer(object):
|
||||
defer.returnValue(response)
|
||||
return
|
||||
|
||||
logger.debug("[%s] Transacition is new", transaction.transaction_id)
|
||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
dl = []
|
||||
@ -382,10 +387,16 @@ class ReplicationLayer(object):
|
||||
context,
|
||||
event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
defer.returnValue((200, {
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@ -438,7 +449,9 @@ class ReplicationLayer(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_send_join_request(self, origin, content):
|
||||
logger.debug("on_send_join_request: content: %s", content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
defer.returnValue((200, {
|
||||
@ -557,13 +570,21 @@ class ReplicationLayer(object):
|
||||
origin, pdu.event_id, do_auth=False
|
||||
)
|
||||
|
||||
if existing and (not existing.outlier or pdu.outlier):
|
||||
already_seen = (
|
||||
existing and (
|
||||
not existing.internal_metadata.is_outlier()
|
||||
or pdu.internal_metadata.is_outlier()
|
||||
)
|
||||
)
|
||||
if already_seen:
|
||||
logger.debug("Already seen pdu %s", pdu.event_id)
|
||||
defer.returnValue({})
|
||||
return
|
||||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
|
||||
# We need to make sure we have all the auth events.
|
||||
# for e_id, _ in pdu.auth_events:
|
||||
# exists = yield self._get_persisted_pdu(
|
||||
@ -595,7 +616,7 @@ class ReplicationLayer(object):
|
||||
# )
|
||||
|
||||
# Get missing pdus if necessary.
|
||||
if not pdu.outlier:
|
||||
if not pdu.internal_metadata.is_outlier():
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = yield self.handler.get_min_depth_for_context(
|
||||
pdu.room_id
|
||||
@ -636,7 +657,7 @@ class ReplicationLayer(object):
|
||||
"_handle_new_pdu getting state for %s",
|
||||
pdu.room_id
|
||||
)
|
||||
state = yield self.get_state_for_context(
|
||||
state, auth_chain = yield self.get_state_for_context(
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
|
||||
@ -646,6 +667,7 @@ class ReplicationLayer(object):
|
||||
pdu,
|
||||
backfilled=backfilled,
|
||||
state=state,
|
||||
auth_chain=auth_chain,
|
||||
)
|
||||
else:
|
||||
ret = None
|
||||
@ -658,19 +680,14 @@ class ReplicationLayer(object):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
||||
#TODO: Check we have all the PDU keys here
|
||||
pdu_json.setdefault("hashes", {})
|
||||
pdu_json.setdefault("signatures", {})
|
||||
sender = pdu_json.pop("sender", None)
|
||||
if sender is not None:
|
||||
pdu_json["user_id"] = sender
|
||||
state_hash = pdu_json.get("unsigned", {}).pop("state_hash", None)
|
||||
if state_hash is not None:
|
||||
pdu_json["state_hash"] = state_hash
|
||||
return self.event_factory.create_event(
|
||||
pdu_json["type"], outlier=outlier, **pdu_json
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
||||
|
||||
class _TransactionQueue(object):
|
||||
"""This class makes sure we only have one transaction in flight at
|
||||
@ -685,6 +702,7 @@ class _TransactionQueue(object):
|
||||
self.transport_layer = transport_layer
|
||||
|
||||
self._clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
# Is a mapping from destinations -> deferreds. Used to keep track
|
||||
# of which destinations have transactions in flight and when they are
|
||||
@ -705,15 +723,13 @@ class _TransactionQueue(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def enqueue_pdu(self, pdu, order):
|
||||
def enqueue_pdu(self, pdu, destinations, order):
|
||||
# We loop through all destinations to see whether we already have
|
||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||
# table and we'll get back to it later.
|
||||
|
||||
destinations = set([
|
||||
d for d in pdu.destinations
|
||||
if d != self.server_name
|
||||
])
|
||||
destinations = set(destinations)
|
||||
destinations.discard(self.server_name)
|
||||
|
||||
logger.debug("Sending to: %s", str(destinations))
|
||||
|
||||
@ -728,8 +744,14 @@ class _TransactionQueue(object):
|
||||
(pdu, deferred, order)
|
||||
)
|
||||
|
||||
def eb(failure):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
else:
|
||||
logger.warn("Failed to send pdu", failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination)
|
||||
self._attempt_new_transaction(destination).addErrback(eb)
|
||||
|
||||
deferreds.append(deferred)
|
||||
|
||||
@ -739,6 +761,9 @@ class _TransactionQueue(object):
|
||||
def enqueue_edu(self, edu):
|
||||
destination = edu.destination
|
||||
|
||||
if destination == self.server_name:
|
||||
return
|
||||
|
||||
deferred = defer.Deferred()
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(
|
||||
(edu, deferred)
|
||||
@ -748,7 +773,7 @@ class _TransactionQueue(object):
|
||||
if not deferred.called:
|
||||
deferred.errback(failure)
|
||||
else:
|
||||
logger.exception("Failed to send edu", failure)
|
||||
logger.warn("Failed to send edu", failure)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self._attempt_new_transaction(destination).addErrback(eb)
|
||||
@ -770,10 +795,33 @@ class _TransactionQueue(object):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def _attempt_new_transaction(self, destination):
|
||||
|
||||
(retry_last_ts, retry_interval) = (0, 0)
|
||||
retry_timings = yield self.store.get_destination_retry_timings(
|
||||
destination
|
||||
)
|
||||
if retry_timings:
|
||||
(retry_last_ts, retry_interval) = (
|
||||
retry_timings.retry_last_ts, retry_timings.retry_interval
|
||||
)
|
||||
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
|
||||
logger.info(
|
||||
"TX [%s] not ready for retry yet - "
|
||||
"dropping transaction for now",
|
||||
destination,
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.info("TX [%s] is ready for retry", destination)
|
||||
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
# request at which point pending_pdus_by_dest just keeps growing.
|
||||
# we need application-layer timeouts of some flavour of these
|
||||
# requests
|
||||
return
|
||||
|
||||
# list of (pending_pdu, deferred, order)
|
||||
# list of (pending_pdu, deferred, order)
|
||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||
@ -781,7 +829,14 @@ class _TransactionQueue(object):
|
||||
if not pending_pdus and not pending_edus and not pending_failures:
|
||||
return
|
||||
|
||||
logger.debug("TX [%s] Attempting new transaction", destination)
|
||||
logger.debug(
|
||||
"TX [%s] Attempting new transaction "
|
||||
"(pdus: %d, edus: %d, failures: %d)",
|
||||
destination,
|
||||
len(pending_pdus),
|
||||
len(pending_edus),
|
||||
len(pending_failures)
|
||||
)
|
||||
|
||||
# Sort based on the order field
|
||||
pending_pdus.sort(key=lambda t: t[2])
|
||||
@ -814,7 +869,11 @@ class _TransactionQueue(object):
|
||||
yield self.transaction_actions.prepare_to_send(transaction)
|
||||
|
||||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.debug("TX [%s] Sending transaction...", destination)
|
||||
logger.info(
|
||||
"TX [%s] Sending transaction [%s]",
|
||||
destination,
|
||||
transaction.transaction_id,
|
||||
)
|
||||
|
||||
# Actually send the transaction
|
||||
|
||||
@ -835,6 +894,8 @@ class _TransactionQueue(object):
|
||||
transaction, json_data_cb
|
||||
)
|
||||
|
||||
logger.info("TX [%s] got %d response", destination, code)
|
||||
|
||||
logger.debug("TX [%s] Sent transaction", destination)
|
||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||
|
||||
@ -847,8 +908,14 @@ class _TransactionQueue(object):
|
||||
|
||||
for deferred in deferreds:
|
||||
if code == 200:
|
||||
if retry_last_ts:
|
||||
# this host is alive! reset retry schedule
|
||||
yield self.store.set_destination_retry_timings(
|
||||
destination, 0, 0
|
||||
)
|
||||
deferred.callback(None)
|
||||
else:
|
||||
self.set_retrying(destination, retry_interval)
|
||||
deferred.errback(RuntimeError("Got status %d" % code))
|
||||
|
||||
# Ensures we don't continue until all callbacks on that
|
||||
@ -861,11 +928,15 @@ class _TransactionQueue(object):
|
||||
logger.debug("TX [%s] Yielded to callbacks", destination)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("TX Problem in _attempt_transaction")
|
||||
|
||||
# We capture this here as there as nothing actually listens
|
||||
# for this finishing functions deferred.
|
||||
logger.exception(e)
|
||||
logger.warn(
|
||||
"TX [%s] Problem in _attempt_transaction: %s",
|
||||
destination,
|
||||
e,
|
||||
)
|
||||
|
||||
self.set_retrying(destination, retry_interval)
|
||||
|
||||
for deferred in deferreds:
|
||||
if not deferred.called:
|
||||
@ -877,3 +948,22 @@ class _TransactionQueue(object):
|
||||
|
||||
# Check to see if there is anything else to send.
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_retrying(self, destination, retry_interval):
|
||||
# track that this destination is having problems and we should
|
||||
# give it a chance to recover before trying it again
|
||||
|
||||
if retry_interval:
|
||||
retry_interval *= 2
|
||||
# plateau at hourly retries for now
|
||||
if retry_interval >= 60 * 60 * 1000:
|
||||
retry_interval = 60 * 60 * 1000
|
||||
else:
|
||||
retry_interval = 2000 # try again at first after 2 seconds
|
||||
|
||||
yield self.store.set_destination_retry_timings(
|
||||
destination,
|
||||
int(self._clock.time_msec()),
|
||||
retry_interval
|
||||
)
|
||||
|
@ -155,7 +155,7 @@ class TransportLayer(object):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_transaction(self, transaction, json_data_callback=None):
|
||||
""" Sends the given Transaction to it's destination
|
||||
""" Sends the given Transaction to its destination
|
||||
|
||||
Args:
|
||||
transaction (Transaction)
|
||||
|
@ -15,11 +15,10 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.api.errors import LimitExceededError, SynapseError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.api.events.room import RoomMemberEvent
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
|
||||
import logging
|
||||
|
||||
@ -31,10 +30,8 @@ class BaseHandler(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.room_lock = hs.get_room_lock_manager()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
@ -44,6 +41,8 @@ class BaseHandler(object):
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
def ratelimit(self, user_id):
|
||||
time_now = self.clock.time()
|
||||
allowed, time_allowed = self.ratelimiter.send_message(
|
||||
@ -57,62 +56,95 @@ class BaseHandler(object):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
|
||||
extra_users=[], suppress_auth=False,
|
||||
do_invite_host=None):
|
||||
def _create_new_client_event(self, builder):
|
||||
yield run_on_reactor()
|
||||
|
||||
snapshot.fill_out_prev_events(event)
|
||||
|
||||
yield self.state_handler.annotate_event_with_state(event)
|
||||
|
||||
yield self.auth.add_auth_events(event)
|
||||
|
||||
logger.debug("Signing event...")
|
||||
|
||||
add_hashes_and_signatures(
|
||||
event, self.server_name, self.signing_key
|
||||
latest_ret = yield self.store.get_latest_events_in_room(
|
||||
builder.room_id,
|
||||
)
|
||||
|
||||
logger.debug("Signed event.")
|
||||
if latest_ret:
|
||||
depth = max([d for _, _, d in latest_ret]) + 1
|
||||
else:
|
||||
depth = 1
|
||||
|
||||
prev_events = [(e, h) for e, h, _ in latest_ret]
|
||||
|
||||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
state_handler = self.state_handler
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = context.prev_state_events
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
add_hashes_and_signatures(
|
||||
builder, self.server_name, self.signing_key
|
||||
)
|
||||
|
||||
event = builder.build()
|
||||
|
||||
logger.debug(
|
||||
"Created event %s with auth_events: %s, current state: %s",
|
||||
event.event_id, context.auth_events, context.current_state,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
(event, context,)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(self, event, context, extra_destinations=[],
|
||||
extra_users=[], suppress_auth=False):
|
||||
yield run_on_reactor()
|
||||
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if not suppress_auth:
|
||||
logger.debug("Authing...")
|
||||
self.auth.check(event, auth_events=event.old_state_events)
|
||||
logger.debug("Authed")
|
||||
else:
|
||||
logger.debug("Suppressed auth.")
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
|
||||
if do_invite_host:
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
invite_event = yield federation_handler.send_invite(
|
||||
do_invite_host,
|
||||
event
|
||||
)
|
||||
yield self.store.persist_event(event, context=context)
|
||||
|
||||
# FIXME: We need to check if the remote changed anything else
|
||||
event.signatures = invite_event.signatures
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
yield self.store.persist_event(event)
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.INVITE:
|
||||
invitee = self.hs.parse_userid(event.state_key)
|
||||
if not self.hs.is_mine(invitee):
|
||||
# TODO: Can we add signature from remote server in a nicer
|
||||
# way? If we have been invited by a remote server, we need
|
||||
# to get them to sign the event.
|
||||
returned_invite = yield federation_handler.send_invite(
|
||||
invitee.domain,
|
||||
event,
|
||||
)
|
||||
|
||||
# TODO: Make sure the signatures actually are correct.
|
||||
event.signatures.update(
|
||||
returned_invite.signatures
|
||||
)
|
||||
|
||||
destinations = set(extra_destinations)
|
||||
# Send a PDU to all hosts who have joined the room.
|
||||
|
||||
for k, s in event.state_events.items():
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == RoomMemberEvent.TYPE:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(
|
||||
self.hs.parse_userid(s.state_key).domain
|
||||
)
|
||||
except:
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
event.destinations = list(destinations)
|
||||
|
||||
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
yield federation_handler.handle_new_event(event, snapshot)
|
||||
yield federation_handler.handle_new_event(
|
||||
event,
|
||||
None,
|
||||
destinations=destinations,
|
||||
)
|
||||
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import SynapseError, Codes, CodeMessageException
|
||||
from synapse.api.events.room import RoomAliasesEvent
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
import logging
|
||||
|
||||
@ -40,7 +40,7 @@ class DirectoryHandler(BaseHandler):
|
||||
|
||||
# TODO(erikj): Do auth.
|
||||
|
||||
if not room_alias.is_mine:
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
# TODO(erikj): Change this.
|
||||
|
||||
@ -64,7 +64,7 @@ class DirectoryHandler(BaseHandler):
|
||||
def delete_association(self, user_id, room_alias):
|
||||
# TODO Check if server admin
|
||||
|
||||
if not room_alias.is_mine:
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
|
||||
room_id = yield self.store.delete_room_alias(room_alias)
|
||||
@ -75,7 +75,7 @@ class DirectoryHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def get_association(self, room_alias):
|
||||
room_id = None
|
||||
if room_alias.is_mine:
|
||||
if self.hs.is_mine(room_alias):
|
||||
result = yield self.store.get_association_from_room_alias(
|
||||
room_alias
|
||||
)
|
||||
@ -123,7 +123,7 @@ class DirectoryHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def on_directory_query(self, args):
|
||||
room_alias = self.hs.parse_roomalias(args["room_alias"])
|
||||
if not room_alias.is_mine:
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(
|
||||
400, "Room Alias is not hosted on this Home Server"
|
||||
)
|
||||
@ -148,16 +148,11 @@ class DirectoryHandler(BaseHandler):
|
||||
def send_room_alias_update_event(self, user_id, room_id):
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomAliasesEvent.TYPE,
|
||||
state_key=self.hs.hostname,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
content={"aliases": aliases},
|
||||
)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
yield self._on_new_room_event(
|
||||
event, snapshot, extra_users=[user_id], suppress_auth=True
|
||||
)
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_event({
|
||||
"type": EventTypes.Aliases,
|
||||
"state_key": self.hs.hostname,
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"content": {"aliases": aliases},
|
||||
})
|
||||
|
@ -17,12 +17,11 @@
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.events.utils import prune_event
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.api.errors import (
|
||||
AuthError, FederationError, SynapseError, StoreError,
|
||||
)
|
||||
from synapse.api.events.room import RoomMemberEvent, RoomCreateEvent
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.crypto.event_signing import (
|
||||
@ -76,7 +75,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_event(self, event, snapshot):
|
||||
def handle_new_event(self, event, snapshot, destinations):
|
||||
""" Takes in an event from the client to server side, that has already
|
||||
been authed and handled by the state module, and sends it to any
|
||||
remote home servers that may be interested.
|
||||
@ -92,16 +91,12 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
yield run_on_reactor()
|
||||
|
||||
pdu = event
|
||||
|
||||
if not hasattr(pdu, "destinations") or not pdu.destinations:
|
||||
pdu.destinations = []
|
||||
|
||||
yield self.replication_layer.send_pdu(pdu)
|
||||
self.replication_layer.send_pdu(event, destinations)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None):
|
||||
def on_receive_pdu(self, origin, pdu, backfilled, state=None,
|
||||
auth_chain=None):
|
||||
""" Called by the ReplicationLayer when we have a new pdu. We need to
|
||||
do auth checks and put it through the StateHandler.
|
||||
"""
|
||||
@ -140,7 +135,7 @@ class FederationHandler(BaseHandler):
|
||||
if not check_event_content_hash(event):
|
||||
logger.warn(
|
||||
"Event content has been tampered, redacting %s, %s",
|
||||
event.event_id, encode_canonical_json(event.get_full_dict())
|
||||
event.event_id, encode_canonical_json(event.get_dict())
|
||||
)
|
||||
event = redacted_event
|
||||
|
||||
@ -153,43 +148,44 @@ class FederationHandler(BaseHandler):
|
||||
event.room_id,
|
||||
self.server_name
|
||||
)
|
||||
if not is_in_room and not event.outlier:
|
||||
if not is_in_room and not event.internal_metadata.outlier:
|
||||
logger.debug("Got event for room we're not in.")
|
||||
|
||||
replication_layer = self.replication_layer
|
||||
auth_chain = yield replication_layer.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
|
||||
for e in auth_chain:
|
||||
e.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse auth event %s",
|
||||
e.event_id,
|
||||
)
|
||||
replication = self.replication_layer
|
||||
|
||||
if not state:
|
||||
state = yield replication_layer.get_state_for_context(
|
||||
state, auth_chain = yield replication.get_state_for_context(
|
||||
origin, context=event.room_id, event_id=event.event_id,
|
||||
)
|
||||
|
||||
if not auth_chain:
|
||||
auth_chain = yield replication.get_event_auth(
|
||||
origin,
|
||||
context=event.room_id,
|
||||
event_id=event.event_id,
|
||||
)
|
||||
|
||||
for e in auth_chain:
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_auth_from=origin)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to handle auth event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
current_state = state
|
||||
|
||||
if state:
|
||||
for e in state:
|
||||
e.outlier = True
|
||||
logging.info("A :) %r", e)
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse state event %s",
|
||||
"Failed to handle state event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
@ -208,6 +204,13 @@ class FederationHandler(BaseHandler):
|
||||
affected=event.event_id,
|
||||
)
|
||||
|
||||
# if we're receiving valid events from an origin,
|
||||
# it's probably a good idea to mark it as not in retry-state
|
||||
# for sending (although this is a bit of a leap)
|
||||
retry_timings = yield self.store.get_destination_retry_timings(origin)
|
||||
if (retry_timings and retry_timings.retry_last_ts):
|
||||
self.store.set_destination_retry_timings(origin, 0, 0)
|
||||
|
||||
room = yield self.store.get_room(event.room_id)
|
||||
|
||||
if not room:
|
||||
@ -222,7 +225,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if not backfilled:
|
||||
extra_users = []
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
target_user = self.hs.parse_userid(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
@ -231,7 +234,7 @@ class FederationHandler(BaseHandler):
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
user = self.hs.parse_userid(event.state_key)
|
||||
yield self.distributor.fire(
|
||||
@ -258,11 +261,15 @@ class FederationHandler(BaseHandler):
|
||||
event = pdu
|
||||
|
||||
# FIXME (erikj): Not sure this actually works :/
|
||||
yield self.state_handler.annotate_event_with_state(event)
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
events.append(event)
|
||||
events.append((event, context))
|
||||
|
||||
yield self.store.persist_event(event, backfilled=True)
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=True
|
||||
)
|
||||
|
||||
defer.returnValue(events)
|
||||
|
||||
@ -279,13 +286,11 @@ class FederationHandler(BaseHandler):
|
||||
pdu=event
|
||||
)
|
||||
|
||||
|
||||
|
||||
defer.returnValue(pdu)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_event_auth(self, event_id):
|
||||
auth = yield self.store.get_auth_chain(event_id)
|
||||
auth = yield self.store.get_auth_chain([event_id])
|
||||
|
||||
for event in auth:
|
||||
event.signatures.update(
|
||||
@ -325,42 +330,55 @@ class FederationHandler(BaseHandler):
|
||||
event = pdu
|
||||
|
||||
# We should assert some things.
|
||||
assert(event.type == RoomMemberEvent.TYPE)
|
||||
# FIXME: Do this in a nicer way
|
||||
assert(event.type == EventTypes.Member)
|
||||
assert(event.user_id == joinee)
|
||||
assert(event.state_key == joinee)
|
||||
assert(event.room_id == room_id)
|
||||
|
||||
event.outlier = False
|
||||
event.internal_metadata.outlier = False
|
||||
|
||||
self.room_queues[room_id] = []
|
||||
|
||||
builder = self.event_builder_factory.new(
|
||||
event.get_pdu_json()
|
||||
)
|
||||
|
||||
handled_events = set()
|
||||
|
||||
try:
|
||||
event.event_id = self.event_factory.create_event_id()
|
||||
event.origin = self.hs.hostname
|
||||
event.content = content
|
||||
builder.event_id = self.event_builder_factory.create_event_id()
|
||||
builder.origin = self.hs.hostname
|
||||
builder.content = content
|
||||
|
||||
if not hasattr(event, "signatures"):
|
||||
event.signatures = {}
|
||||
builder.signatures = {}
|
||||
|
||||
add_hashes_and_signatures(
|
||||
event,
|
||||
builder,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0],
|
||||
)
|
||||
|
||||
new_event = builder.build()
|
||||
|
||||
ret = yield self.replication_layer.send_join(
|
||||
target_host,
|
||||
event
|
||||
new_event
|
||||
)
|
||||
|
||||
state = ret["state"]
|
||||
auth_chain = ret["auth_chain"]
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
handled_events.update([s.event_id for s in state])
|
||||
handled_events.update([a.event_id for a in auth_chain])
|
||||
handled_events.add(new_event.event_id)
|
||||
|
||||
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||
logger.debug("do_invite_join state: %s", state)
|
||||
|
||||
logger.debug("do_invite_join event: %s", event)
|
||||
logger.debug("do_invite_join event: %s", new_event)
|
||||
|
||||
try:
|
||||
yield self.store.store_room(
|
||||
@ -373,37 +391,36 @@ class FederationHandler(BaseHandler):
|
||||
pass
|
||||
|
||||
for e in auth_chain:
|
||||
e.outlier = True
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(e, fetch_missing=False)
|
||||
yield self._handle_new_event(e)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse auth event %s",
|
||||
"Failed to handle auth event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
for e in state:
|
||||
# FIXME: Auth these.
|
||||
e.outlier = True
|
||||
e.internal_metadata.outlier = True
|
||||
try:
|
||||
yield self._handle_new_event(
|
||||
e,
|
||||
fetch_missing=True
|
||||
e, fetch_auth_from=target_host
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse state event %s",
|
||||
"Failed to handle state event %s",
|
||||
e.event_id,
|
||||
)
|
||||
|
||||
yield self._handle_new_event(
|
||||
event,
|
||||
new_event,
|
||||
state=state,
|
||||
current_state=state,
|
||||
)
|
||||
|
||||
yield self.notifier.on_new_room_event(
|
||||
event, extra_users=[joinee]
|
||||
new_event, extra_users=[joinee]
|
||||
)
|
||||
|
||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||
@ -412,6 +429,9 @@ class FederationHandler(BaseHandler):
|
||||
del self.room_queues[room_id]
|
||||
|
||||
for p, origin in room_queue:
|
||||
if p.event_id in handled_events:
|
||||
continue
|
||||
|
||||
try:
|
||||
self.on_receive_pdu(origin, p, backfilled=False)
|
||||
except:
|
||||
@ -421,25 +441,24 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_make_join_request(self, context, user_id):
|
||||
def on_make_join_request(self, room_id, user_id):
|
||||
""" We've received a /make_join/ request, so we create a partial
|
||||
join event for the room and return that. We don *not* persist or
|
||||
process it until the other server has signed it and sent it back.
|
||||
"""
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
content={"membership": Membership.JOIN},
|
||||
room_id=context,
|
||||
user_id=user_id,
|
||||
state_key=user_id,
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
})
|
||||
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
snapshot.fill_out_prev_events(event)
|
||||
|
||||
yield self.state_handler.annotate_event_with_state(event)
|
||||
yield self.auth.add_auth_events(event)
|
||||
self.auth.check(event, auth_events=event.old_state_events)
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
|
||||
pdu = event
|
||||
|
||||
@ -453,12 +472,24 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
event.outlier = False
|
||||
logger.debug(
|
||||
"on_send_join_request: Got event: %s, signatures: %s",
|
||||
event.event_id,
|
||||
event.signatures,
|
||||
)
|
||||
|
||||
yield self._handle_new_event(event)
|
||||
event.internal_metadata.outlier = False
|
||||
|
||||
context = yield self._handle_new_event(event)
|
||||
|
||||
logger.debug(
|
||||
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
||||
event.event_id,
|
||||
event.signatures,
|
||||
)
|
||||
|
||||
extra_users = []
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.type == EventTypes.Member:
|
||||
target_user_id = event.state_key
|
||||
target_user = self.hs.parse_userid(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
@ -467,7 +498,7 @@ class FederationHandler(BaseHandler):
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
user = self.hs.parse_userid(event.state_key)
|
||||
yield self.distributor.fire(
|
||||
@ -478,9 +509,9 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
destinations = set()
|
||||
|
||||
for k, s in event.state_events.items():
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == RoomMemberEvent.TYPE:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(
|
||||
self.hs.parse_userid(s.state_key).domain
|
||||
@ -490,14 +521,21 @@ class FederationHandler(BaseHandler):
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
new_pdu.destinations = list(destinations)
|
||||
logger.debug(
|
||||
"on_send_join_request: Sending event: %s, signatures: %s",
|
||||
event.event_id,
|
||||
event.signatures,
|
||||
)
|
||||
|
||||
yield self.replication_layer.send_pdu(new_pdu)
|
||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||
|
||||
auth_chain = yield self.store.get_auth_chain(event.event_id)
|
||||
state_ids = [e.event_id for e in context.current_state.values()]
|
||||
auth_chain = yield self.store.get_auth_chain(set(
|
||||
[event.event_id] + state_ids
|
||||
))
|
||||
|
||||
defer.returnValue({
|
||||
"state": event.state_events.values(),
|
||||
"state": context.current_state.values(),
|
||||
"auth_chain": auth_chain,
|
||||
})
|
||||
|
||||
@ -509,7 +547,7 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
event.outlier = True
|
||||
event.internal_metadata.outlier = True
|
||||
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
@ -519,10 +557,11 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
)
|
||||
|
||||
yield self.state_handler.annotate_event_with_state(event)
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=False,
|
||||
)
|
||||
|
||||
@ -552,13 +591,13 @@ class FederationHandler(BaseHandler):
|
||||
}
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
if hasattr(event, "state_key"):
|
||||
if event and event.is_state():
|
||||
# Get previous state
|
||||
if hasattr(event, "replaces_state") and event.replaces_state:
|
||||
prev_event = yield self.store.get_event(
|
||||
event.replaces_state
|
||||
)
|
||||
results[(event.type, event.state_key)] = prev_event
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev_id = event.unsigned["replaces_state"]
|
||||
if prev_id != event.event_id:
|
||||
prev_event = yield self.store.get_event(prev_id)
|
||||
results[(event.type, event.state_key)] = prev_event
|
||||
else:
|
||||
del results[(event.type, event.state_key)]
|
||||
|
||||
@ -643,75 +682,88 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_new_event(self, event, state=None, backfilled=False,
|
||||
current_state=None, fetch_missing=True):
|
||||
is_new_state = yield self.state_handler.annotate_event_with_state(
|
||||
event,
|
||||
old_state=state
|
||||
current_state=None, fetch_auth_from=None):
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before annotate: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
if event.old_state_events:
|
||||
known_ids = set(
|
||||
[s.event_id for s in event.old_state_events.values()]
|
||||
)
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(
|
||||
e_id,
|
||||
allow_none=True,
|
||||
context = yield self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before auth fetch: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
is_new_state = not event.internal_metadata.is_outlier()
|
||||
|
||||
known_ids = set(
|
||||
[s.event_id for s in context.auth_events.values()]
|
||||
)
|
||||
|
||||
for e_id, _ in event.auth_events:
|
||||
if e_id not in known_ids:
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e and fetch_auth_from is not None:
|
||||
# Grab the auth_chain over federation if we are missing
|
||||
# auth events.
|
||||
auth_chain = yield self.replication_layer.get_event_auth(
|
||||
fetch_auth_from, event.event_id, event.room_id
|
||||
)
|
||||
|
||||
if not e:
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
# not the ones who are wrong.
|
||||
logger.info(
|
||||
"Rejecting %s as %s not in %s",
|
||||
event.event_id, e_id, known_ids,
|
||||
)
|
||||
raise AuthError(403, "Auth events are stale")
|
||||
|
||||
auth_events = event.old_state_events
|
||||
else:
|
||||
# We need to get the auth events from somewhere.
|
||||
|
||||
# TODO: Don't just hit the DBs?
|
||||
|
||||
auth_events = {}
|
||||
for e_id, _ in event.auth_events:
|
||||
e = yield self.store.get_event(
|
||||
e_id,
|
||||
allow_none=True,
|
||||
)
|
||||
for auth_event in auth_chain:
|
||||
yield self._handle_new_event(auth_event)
|
||||
e = yield self.store.get_event(e_id, allow_none=True)
|
||||
|
||||
if not e:
|
||||
e = yield self.replication_layer.get_pdu(
|
||||
event.origin, e_id, outlier=True
|
||||
# TODO: Do some conflict res to make sure that we're
|
||||
# not the ones who are wrong.
|
||||
logger.info(
|
||||
"Rejecting %s as %s not in db or %s",
|
||||
event.event_id, e_id, known_ids,
|
||||
)
|
||||
# FIXME: How does raising AuthError work with federation?
|
||||
raise AuthError(403, "Cannot find auth event")
|
||||
|
||||
if e and fetch_missing:
|
||||
try:
|
||||
yield self.on_receive_pdu(event.origin, e, False)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to parse auth event %s",
|
||||
e_id,
|
||||
)
|
||||
context.auth_events[(e.type, e.state_key)] = e
|
||||
|
||||
if not e:
|
||||
logger.warn("Can't find auth event %s.", e_id)
|
||||
logger.debug(
|
||||
"_handle_new_event: Before hack: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
auth_events[(e.type, e.state_key)] = e
|
||||
if event.type == EventTypes.Member and not event.auth_events:
|
||||
if len(event.prev_events) == 1:
|
||||
c = yield self.store.get_event(event.prev_events[0][0])
|
||||
if c.type == EventTypes.Create:
|
||||
context.auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE and not event.auth_events:
|
||||
if len(event.prev_events) == 1:
|
||||
c = yield self.store.get_event(event.prev_events[0][0])
|
||||
if c.type == RoomCreateEvent.TYPE:
|
||||
auth_events[(c.type, c.state_key)] = c
|
||||
logger.debug(
|
||||
"_handle_new_event: Before auth check: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
self.auth.check(event, auth_events=auth_events)
|
||||
self.auth.check(event, auth_events=context.auth_events)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: Before persist_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
is_new_state=(is_new_state and not backfilled),
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"_handle_new_event: After persist_event: %s, sigs: %s",
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
defer.returnValue(context)
|
||||
|
@ -15,10 +15,11 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import RoomError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.events.validator import EventValidator
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
@ -32,7 +33,7 @@ class MessageHandler(BaseHandler):
|
||||
super(MessageHandler, self).__init__(hs)
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.validator = EventValidator()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_message(self, msg_id=None, room_id=None, sender_id=None,
|
||||
@ -63,35 +64,6 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_message(self, event=None, suppress_auth=False):
|
||||
""" Send a message.
|
||||
|
||||
Args:
|
||||
event : The message event to store.
|
||||
suppress_auth (bool) : True to suppress auth for this message. This
|
||||
is primarily so the home server can inject messages into rooms at
|
||||
will.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
|
||||
self.ratelimit(event.user_id)
|
||||
# TODO(paul): Why does 'event' not have a 'user' object?
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
assert user.is_mine, "User must be our own: %s" % (user,)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
yield self._on_new_room_event(
|
||||
event, snapshot, suppress_auth=suppress_auth
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_handlers().presence_handler.bump_presence_active_time(
|
||||
user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
feedback=False):
|
||||
@ -134,19 +106,53 @@ class MessageHandler(BaseHandler):
|
||||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room_data(self, event=None):
|
||||
""" Stores data for a room.
|
||||
def create_and_send_event(self, event_dict):
|
||||
""" Given a dict from a client, create and handle a new event.
|
||||
|
||||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Persists and notifies local clients and federation.
|
||||
|
||||
Args:
|
||||
event : The room path event
|
||||
stamp_event (bool) : True to stamp event content with server keys.
|
||||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
event_dict (dict): An entire event
|
||||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
self.validator.validate_new(builder)
|
||||
|
||||
yield self._on_new_room_event(event, snapshot)
|
||||
self.ratelimit(builder.user_id)
|
||||
# TODO(paul): Why does 'event' not have a 'user' object?
|
||||
user = self.hs.parse_userid(builder.user_id)
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
if membership == Membership.JOIN:
|
||||
joinee = self.hs.parse_userid(builder.state_key)
|
||||
# If event doesn't include a display name, add one.
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data",
|
||||
joinee,
|
||||
builder.content
|
||||
)
|
||||
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.change_membership(event, context)
|
||||
else:
|
||||
yield self.handle_new_client_event(
|
||||
event=event,
|
||||
context=context,
|
||||
)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
@ -180,13 +186,6 @@ class MessageHandler(BaseHandler):
|
||||
defer.returnValue(fb)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_feedback(self, event):
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
# store message in db
|
||||
yield self._on_new_room_event(event, snapshot)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_events(self, user_id, room_id):
|
||||
"""Retrieve all state events for a given room.
|
||||
|
@ -147,7 +147,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_presence_visible(self, observer_user, observed_user):
|
||||
assert(observed_user.is_mine)
|
||||
assert(self.hs.is_mine(observed_user))
|
||||
|
||||
if observer_user == observed_user:
|
||||
defer.returnValue(True)
|
||||
@ -165,7 +165,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state(self, target_user, auth_user, as_event=False):
|
||||
if target_user.is_mine:
|
||||
if self.hs.is_mine(target_user):
|
||||
visible = yield self.is_presence_visible(
|
||||
observer_user=auth_user,
|
||||
observed_user=target_user
|
||||
@ -212,7 +212,7 @@ class PresenceHandler(BaseHandler):
|
||||
# TODO (erikj): Turn this back on. Why did we end up sending EDUs
|
||||
# everywhere?
|
||||
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
@ -291,7 +291,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_joined_room(self, user, room_id):
|
||||
if user.is_mine:
|
||||
if self.hs.is_mine(user):
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
# No actual update but we need to bump the serial anyway for the
|
||||
@ -309,7 +309,7 @@ class PresenceHandler(BaseHandler):
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
curr_users = yield rm_handler.get_room_members(room_id)
|
||||
|
||||
for local_user in [c for c in curr_users if c.is_mine]:
|
||||
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
|
||||
self.push_update_to_local_and_remote(
|
||||
observed_user=local_user,
|
||||
users_to_push=[user],
|
||||
@ -318,14 +318,14 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, observer_user, observed_user):
|
||||
if not observer_user.is_mine:
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.add_presence_list_pending(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
if observed_user.is_mine:
|
||||
if self.hs.is_mine(observed_user):
|
||||
yield self.invite_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.federation.send_edu(
|
||||
@ -339,7 +339,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _should_accept_invite(self, observed_user, observer_user):
|
||||
if not observed_user.is_mine:
|
||||
if not self.hs.is_mine(observed_user):
|
||||
defer.returnValue(False)
|
||||
|
||||
row = yield self.store.has_presence_state(observed_user.localpart)
|
||||
@ -359,7 +359,7 @@ class PresenceHandler(BaseHandler):
|
||||
observed_user.localpart, observer_user.to_string()
|
||||
)
|
||||
|
||||
if observer_user.is_mine:
|
||||
if self.hs.is_mine(observer_user):
|
||||
if accept:
|
||||
yield self.accept_presence(observed_user, observer_user)
|
||||
else:
|
||||
@ -396,7 +396,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def drop(self, observed_user, observer_user):
|
||||
if not observer_user.is_mine:
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.del_presence_list(
|
||||
@ -410,7 +410,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
if not observer_user.is_mine:
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
presence = yield self.store.get_presence_list(
|
||||
@ -465,7 +465,7 @@ class PresenceHandler(BaseHandler):
|
||||
)
|
||||
|
||||
for target_user in target_users:
|
||||
if target_user.is_mine:
|
||||
if self.hs.is_mine(target_user):
|
||||
self._start_polling_local(user, target_user)
|
||||
|
||||
# We want to tell the person that just came online
|
||||
@ -477,7 +477,7 @@ class PresenceHandler(BaseHandler):
|
||||
)
|
||||
|
||||
deferreds = []
|
||||
remote_users = [u for u in target_users if not u.is_mine]
|
||||
remote_users = [u for u in target_users if not self.hs.is_mine(u)]
|
||||
remoteusers_by_domain = partition(remote_users, lambda u: u.domain)
|
||||
# Only poll for people in our get_presence_list
|
||||
for domain in remoteusers_by_domain:
|
||||
@ -520,7 +520,7 @@ class PresenceHandler(BaseHandler):
|
||||
def stop_polling_presence(self, user, target_user=None):
|
||||
logger.debug("Stop polling for presence from %s", user)
|
||||
|
||||
if not target_user or target_user.is_mine:
|
||||
if not target_user or self.hs.is_mine(target_user):
|
||||
self._stop_polling_local(user, target_user=target_user)
|
||||
|
||||
deferreds = []
|
||||
@ -579,7 +579,7 @@ class PresenceHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def push_presence(self, user, statuscache):
|
||||
assert(user.is_mine)
|
||||
assert(self.hs.is_mine(user))
|
||||
|
||||
logger.debug("Pushing presence update from %s", user)
|
||||
|
||||
@ -659,10 +659,6 @@ class PresenceHandler(BaseHandler):
|
||||
if room_ids:
|
||||
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
|
||||
|
||||
if not observers and not room_ids:
|
||||
logger.debug(" | no interested observers or room IDs")
|
||||
continue
|
||||
|
||||
state = dict(push)
|
||||
del state["user_id"]
|
||||
|
||||
@ -683,6 +679,10 @@ class PresenceHandler(BaseHandler):
|
||||
self._user_cachemap_latest_serial += 1
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
if not observers and not room_ids:
|
||||
logger.debug(" | no interested observers or room IDs")
|
||||
continue
|
||||
|
||||
self.push_update_to_clients(
|
||||
observed_user=user,
|
||||
users_to_push=observers,
|
||||
@ -696,7 +696,7 @@ class PresenceHandler(BaseHandler):
|
||||
for poll in content.get("poll", []):
|
||||
user = self.hs.parse_userid(poll)
|
||||
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
continue
|
||||
|
||||
# TODO(paul) permissions checks
|
||||
@ -711,7 +711,7 @@ class PresenceHandler(BaseHandler):
|
||||
for unpoll in content.get("unpoll", []):
|
||||
user = self.hs.parse_userid(unpoll)
|
||||
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
continue
|
||||
|
||||
if user in self._remote_sendmap:
|
||||
@ -730,7 +730,7 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
localusers, remoteusers = partitionbool(
|
||||
users_to_push,
|
||||
lambda u: u.is_mine
|
||||
lambda u: self.hs.is_mine(u)
|
||||
)
|
||||
|
||||
localusers = set(localusers)
|
||||
@ -788,7 +788,7 @@ class PresenceEventSource(object):
|
||||
[u.to_string() for u in observer_user, observed_user])):
|
||||
defer.returnValue(True)
|
||||
|
||||
if observed_user.is_mine:
|
||||
if self.hs.is_mine(observed_user):
|
||||
pushmap = presence._local_pushmap
|
||||
|
||||
defer.returnValue(
|
||||
@ -804,6 +804,7 @@ class PresenceEventSource(object):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
from_key = int(from_key)
|
||||
|
||||
@ -816,7 +817,8 @@ class PresenceEventSource(object):
|
||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
||||
for observed_user in cachemap.keys():
|
||||
cached = cachemap[observed_user]
|
||||
if not (from_key < cached.serial):
|
||||
|
||||
if cached.serial <= from_key:
|
||||
continue
|
||||
|
||||
if (yield self.is_visible(observer_user, observed_user)):
|
||||
|
@ -51,7 +51,7 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if target_user.is_mine:
|
||||
if self.hs.is_mine(target_user):
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
@ -81,7 +81,7 @@ class ProfileHandler(BaseHandler):
|
||||
def set_displayname(self, target_user, auth_user, new_displayname):
|
||||
"""target_user is the user whose displayname is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
@ -101,7 +101,7 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_avatar_url(self, target_user):
|
||||
if target_user.is_mine:
|
||||
if self.hs.is_mine(target_user):
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
@ -130,7 +130,7 @@ class ProfileHandler(BaseHandler):
|
||||
def set_avatar_url(self, target_user, auth_user, new_avatar_url):
|
||||
"""target_user is the user whose avatar_url is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
@ -150,7 +150,7 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def collect_presencelike_data(self, user, state):
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
defer.returnValue(None)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
@ -170,7 +170,7 @@ class ProfileHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def on_profile_query(self, args):
|
||||
user = self.hs.parse_userid(args["user_id"])
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
just_field = args.get("field", None)
|
||||
@ -191,7 +191,7 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_join_states(self, user):
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
return
|
||||
|
||||
joins = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
@ -200,8 +200,6 @@ class ProfileHandler(BaseHandler):
|
||||
)
|
||||
|
||||
for j in joins:
|
||||
snapshot = yield self.store.snapshot_room(j)
|
||||
|
||||
content = {
|
||||
"membership": j.content["membership"],
|
||||
}
|
||||
@ -210,14 +208,11 @@ class ProfileHandler(BaseHandler):
|
||||
"collect_presencelike_data", user, content
|
||||
)
|
||||
|
||||
new_event = self.event_factory.create_event(
|
||||
etype=j.type,
|
||||
room_id=j.room_id,
|
||||
state_key=j.state_key,
|
||||
content=content,
|
||||
user_id=j.state_key,
|
||||
)
|
||||
|
||||
yield self._on_new_room_event(
|
||||
new_event, snapshot, suppress_auth=True
|
||||
)
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_event({
|
||||
"type": j.type,
|
||||
"room_id": j.room_id,
|
||||
"state_key": j.state_key,
|
||||
"content": content,
|
||||
"sender": j.state_key,
|
||||
})
|
||||
|
@ -22,6 +22,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
|
||||
@ -54,12 +55,13 @@ class RegistrationHandler(BaseHandler):
|
||||
Raises:
|
||||
RegistrationError if there was a problem registering.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||
|
||||
if localpart:
|
||||
user = UserID(localpart, self.hs.hostname, True)
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
token = self._generate_token(user_id)
|
||||
@ -78,7 +80,7 @@ class RegistrationHandler(BaseHandler):
|
||||
while not user_id and not token:
|
||||
try:
|
||||
localpart = self._generate_user_id()
|
||||
user = UserID(localpart, self.hs.hostname, True)
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
token = self._generate_token(user_id)
|
||||
|
@ -17,12 +17,8 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID, RoomAlias, RoomID
|
||||
from synapse.api.constants import Membership, JoinRules
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
|
||||
RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
|
||||
)
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import run_on_reactor
|
||||
from ._base import BaseHandler
|
||||
@ -52,9 +48,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
self.ratelimit(user_id)
|
||||
|
||||
if "room_alias_name" in config:
|
||||
room_alias = RoomAlias.create_local(
|
||||
room_alias = RoomAlias.create(
|
||||
config["room_alias_name"],
|
||||
self.hs
|
||||
self.hs.hostname,
|
||||
)
|
||||
mapping = yield self.store.get_association_from_room_alias(
|
||||
room_alias
|
||||
@ -76,8 +72,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
if room_id:
|
||||
# Ensure room_id is the correct type
|
||||
room_id_obj = RoomID.from_string(room_id, self.hs)
|
||||
if not room_id_obj.is_mine:
|
||||
room_id_obj = RoomID.from_string(room_id)
|
||||
if not self.hs.is_mine(room_id_obj):
|
||||
raise SynapseError(400, "Room id must be local")
|
||||
|
||||
yield self.store.store_room(
|
||||
@ -93,7 +89,10 @@ class RoomCreationHandler(BaseHandler):
|
||||
while attempts < 5:
|
||||
try:
|
||||
random_string = stringutils.random_string(18)
|
||||
gen_room_id = RoomID.create_local(random_string, self.hs)
|
||||
gen_room_id = RoomID.create(
|
||||
random_string,
|
||||
self.hs.hostname,
|
||||
)
|
||||
yield self.store.store_room(
|
||||
room_id=gen_room_id.to_string(),
|
||||
room_creator_user_id=user_id,
|
||||
@ -120,59 +119,37 @@ class RoomCreationHandler(BaseHandler):
|
||||
user, room_id, is_public=is_public
|
||||
)
|
||||
|
||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_event(event):
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
logger.debug("Event: %s", event)
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
yield room_member_handler.change_membership(
|
||||
event,
|
||||
do_auth=True
|
||||
)
|
||||
else:
|
||||
yield self._on_new_room_event(
|
||||
event, snapshot, extra_users=[user], suppress_auth=True
|
||||
)
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
for event in creation_events:
|
||||
yield handle_event(event)
|
||||
yield msg_handler.create_and_send_event(event)
|
||||
|
||||
if "name" in config:
|
||||
name = config["name"]
|
||||
name_event = self.event_factory.create_event(
|
||||
etype=RoomNameEvent.TYPE,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
content={"name": name},
|
||||
)
|
||||
|
||||
yield handle_event(name_event)
|
||||
yield msg_handler.create_and_send_event({
|
||||
"type": EventTypes.Name,
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"content": {"name": name},
|
||||
})
|
||||
|
||||
if "topic" in config:
|
||||
topic = config["topic"]
|
||||
topic_event = self.event_factory.create_event(
|
||||
etype=RoomTopicEvent.TYPE,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
content={"topic": topic},
|
||||
)
|
||||
yield msg_handler.create_and_send_event({
|
||||
"type": EventTypes.Topic,
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"content": {"topic": topic},
|
||||
})
|
||||
|
||||
yield handle_event(topic_event)
|
||||
|
||||
content = {"membership": Membership.INVITE}
|
||||
for invitee in invite_list:
|
||||
invite_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
state_key=invitee,
|
||||
room_id=room_id,
|
||||
user_id=user_id,
|
||||
content=content
|
||||
)
|
||||
yield handle_event(invite_event)
|
||||
yield msg_handler.create_and_send_event({
|
||||
"type": EventTypes.Member,
|
||||
"state_key": invitee,
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"content": {"membership": Membership.INVITE},
|
||||
})
|
||||
|
||||
result = {"room_id": room_id}
|
||||
|
||||
@ -189,40 +166,44 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
event_keys = {
|
||||
"room_id": room_id,
|
||||
"user_id": creator_id,
|
||||
"sender": creator_id,
|
||||
"state_key": "",
|
||||
}
|
||||
|
||||
def create(etype, **content):
|
||||
return self.event_factory.create_event(
|
||||
etype=etype,
|
||||
content=content,
|
||||
**event_keys
|
||||
)
|
||||
def create(etype, content, **kwargs):
|
||||
e = {
|
||||
"type": etype,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
e.update(event_keys)
|
||||
e.update(kwargs)
|
||||
|
||||
return e
|
||||
|
||||
creation_event = create(
|
||||
etype=RoomCreateEvent.TYPE,
|
||||
creator=creator.to_string(),
|
||||
etype=EventTypes.Create,
|
||||
content={"creator": creator.to_string()},
|
||||
)
|
||||
|
||||
join_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
join_event = create(
|
||||
etype=EventTypes.Member,
|
||||
state_key=creator_id,
|
||||
content={
|
||||
"membership": Membership.JOIN,
|
||||
},
|
||||
**event_keys
|
||||
)
|
||||
|
||||
power_levels_event = self.event_factory.create_event(
|
||||
etype=RoomPowerLevelsEvent.TYPE,
|
||||
power_levels_event = create(
|
||||
etype=EventTypes.PowerLevels,
|
||||
content={
|
||||
"users": {
|
||||
creator.to_string(): 100,
|
||||
},
|
||||
"users_default": 0,
|
||||
"events": {
|
||||
RoomNameEvent.TYPE: 100,
|
||||
RoomPowerLevelsEvent.TYPE: 100,
|
||||
EventTypes.Name: 100,
|
||||
EventTypes.PowerLevels: 100,
|
||||
},
|
||||
"events_default": 0,
|
||||
"state_default": 50,
|
||||
@ -230,13 +211,12 @@ class RoomCreationHandler(BaseHandler):
|
||||
"kick": 50,
|
||||
"redact": 50
|
||||
},
|
||||
**event_keys
|
||||
)
|
||||
|
||||
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
|
||||
join_rules_event = create(
|
||||
etype=RoomJoinRulesEvent.TYPE,
|
||||
join_rule=join_rule,
|
||||
etype=EventTypes.JoinRules,
|
||||
content={"join_rule": join_rule},
|
||||
)
|
||||
|
||||
return [
|
||||
@ -260,6 +240,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("user_joined_room")
|
||||
self.distributor.declare("user_left_room")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_members(self, room_id, membership=Membership.JOIN):
|
||||
@ -287,7 +268,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if member.is_mine:
|
||||
if self.hs.is_mine(member):
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
@ -348,7 +329,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
defer.returnValue(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def change_membership(self, event=None, do_auth=True):
|
||||
def change_membership(self, event, context, do_auth=True):
|
||||
""" Change the membership status of a user in a room.
|
||||
|
||||
Args:
|
||||
@ -358,11 +339,9 @@ class RoomMemberHandler(BaseHandler):
|
||||
"""
|
||||
target_user_id = event.state_key
|
||||
|
||||
snapshot = yield self.store.snapshot_room(event)
|
||||
|
||||
## TODO(markjh): get prev state from snapshot.
|
||||
prev_state = yield self.store.get_room_member(
|
||||
target_user_id, event.room_id
|
||||
prev_state = context.current_state.get(
|
||||
(EventTypes.Member, target_user_id),
|
||||
None
|
||||
)
|
||||
|
||||
room_id = event.room_id
|
||||
@ -371,10 +350,11 @@ class RoomMemberHandler(BaseHandler):
|
||||
# if this HS is not currently in the room, i.e. we have to do the
|
||||
# invite/join dance.
|
||||
if event.membership == Membership.JOIN:
|
||||
yield self._do_join(event, snapshot, do_auth=do_auth)
|
||||
yield self._do_join(event, context, do_auth=do_auth)
|
||||
else:
|
||||
# This is not a JOIN, so we can handle it normally.
|
||||
|
||||
# FIXME: This isn't idempotency.
|
||||
if prev_state and prev_state.membership == event.membership:
|
||||
# double same action, treat this event as a NOOP.
|
||||
defer.returnValue({})
|
||||
@ -383,10 +363,16 @@ class RoomMemberHandler(BaseHandler):
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
snapshot=snapshot,
|
||||
context=context,
|
||||
do_auth=do_auth,
|
||||
)
|
||||
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
user = self.hs.parse_userid(event.user_id)
|
||||
self.distributor.fire(
|
||||
"user_left_room", user=user, room_id=event.room_id
|
||||
)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -404,33 +390,32 @@ class RoomMemberHandler(BaseHandler):
|
||||
|
||||
host = hosts[0]
|
||||
|
||||
content.update({"membership": Membership.JOIN})
|
||||
new_event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
state_key=joinee.to_string(),
|
||||
room_id=room_id,
|
||||
user_id=joinee.to_string(),
|
||||
membership=Membership.JOIN,
|
||||
content=content,
|
||||
# If event doesn't include a display name, add one.
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", joinee, content
|
||||
)
|
||||
|
||||
snapshot = yield self.store.snapshot_room(new_event)
|
||||
content.update({"membership": Membership.JOIN})
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"state_key": joinee.to_string(),
|
||||
"room_id": room_id,
|
||||
"sender": joinee.to_string(),
|
||||
"membership": Membership.JOIN,
|
||||
"content": content,
|
||||
})
|
||||
event, context = yield self._create_new_client_event(builder)
|
||||
|
||||
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
|
||||
yield self._do_join(event, context, room_host=host, do_auth=True)
|
||||
|
||||
defer.returnValue({"room_id": room_id})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_join(self, event, snapshot, room_host=None, do_auth=True):
|
||||
def _do_join(self, event, context, room_host=None, do_auth=True):
|
||||
joinee = self.hs.parse_userid(event.state_key)
|
||||
# room_id = RoomID.from_string(event.room_id, self.hs)
|
||||
room_id = event.room_id
|
||||
|
||||
# If event doesn't include a display name, add one.
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", joinee, event.content
|
||||
)
|
||||
|
||||
# XXX: We don't do an auth check if we are doing an invite
|
||||
# join dance for now, since we're kinda implicitly checking
|
||||
# that we are allowed to join when we decide whether or not we
|
||||
@ -452,31 +437,29 @@ class RoomMemberHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if prev_state and prev_state.membership == Membership.INVITE:
|
||||
room = yield self.store.get_room(room_id)
|
||||
inviter = UserID.from_string(
|
||||
prev_state.user_id, self.hs
|
||||
)
|
||||
inviter = UserID.from_string(prev_state.user_id)
|
||||
|
||||
should_do_dance = not inviter.is_mine and not room
|
||||
should_do_dance = not self.hs.is_mine(inviter)
|
||||
room_host = inviter.domain
|
||||
else:
|
||||
should_do_dance = False
|
||||
|
||||
have_joined = False
|
||||
if should_do_dance:
|
||||
handler = self.hs.get_handlers().federation_handler
|
||||
have_joined = yield handler.do_invite_join(
|
||||
room_host, room_id, event.user_id, event.content, snapshot
|
||||
yield handler.do_invite_join(
|
||||
room_host,
|
||||
room_id,
|
||||
event.user_id,
|
||||
event.get_dict()["content"], # FIXME To get a non-frozen dict
|
||||
context
|
||||
)
|
||||
|
||||
# We want to do the _do_update inside the room lock.
|
||||
if not have_joined:
|
||||
else:
|
||||
logger.debug("Doing normal join")
|
||||
|
||||
yield self._do_local_membership_update(
|
||||
event,
|
||||
membership=event.content["membership"],
|
||||
snapshot=snapshot,
|
||||
context=context,
|
||||
do_auth=do_auth,
|
||||
)
|
||||
|
||||
@ -501,10 +484,10 @@ class RoomMemberHandler(BaseHandler):
|
||||
if prev_state and prev_state.membership == Membership.INVITE:
|
||||
room = yield self.store.get_room(room_id)
|
||||
inviter = UserID.from_string(
|
||||
prev_state.sender, self.hs
|
||||
prev_state.sender
|
||||
)
|
||||
|
||||
is_remote_invite_join = not inviter.is_mine and not room
|
||||
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
|
||||
room_host = inviter.domain
|
||||
else:
|
||||
is_remote_invite_join = False
|
||||
@ -526,25 +509,17 @@ class RoomMemberHandler(BaseHandler):
|
||||
defer.returnValue(room_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_local_membership_update(self, event, membership, snapshot,
|
||||
def _do_local_membership_update(self, event, membership, context,
|
||||
do_auth):
|
||||
yield run_on_reactor()
|
||||
|
||||
# If we're inviting someone, then we should also send it to that
|
||||
# HS.
|
||||
target_user_id = event.state_key
|
||||
target_user = self.hs.parse_userid(target_user_id)
|
||||
if membership == Membership.INVITE and not target_user.is_mine:
|
||||
do_invite_host = target_user.domain
|
||||
else:
|
||||
do_invite_host = None
|
||||
target_user = self.hs.parse_userid(event.state_key)
|
||||
|
||||
yield self._on_new_room_event(
|
||||
yield self.handle_new_client_event(
|
||||
event,
|
||||
snapshot,
|
||||
context,
|
||||
extra_users=[target_user],
|
||||
suppress_auth=(not do_auth),
|
||||
do_invite_host=do_invite_host,
|
||||
)
|
||||
|
||||
|
||||
|
@ -43,22 +43,50 @@ class TypingNotificationHandler(BaseHandler):
|
||||
|
||||
self.federation.register_edu_handler("m.typing", self._recv_edu)
|
||||
|
||||
self._member_typing_until = {}
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
self._member_typing_until = {} # clock time we expect to stop
|
||||
self._member_typing_timer = {} # deferreds to manage theabove
|
||||
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {}
|
||||
self._latest_room_serial = 0
|
||||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
|
||||
def tearDown(self):
|
||||
"""Cancels all the pending timers.
|
||||
Normally this shouldn't be needed, but it's required from unit tests
|
||||
to avoid a "Reactor was unclean" warning."""
|
||||
for t in self._member_typing_timer.values():
|
||||
self.clock.cancel_call_later(t)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def started_typing(self, target_user, auth_user, room_id, timeout):
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
|
||||
logger.debug(
|
||||
"%s has started typing in %s", target_user.to_string(), room_id
|
||||
)
|
||||
|
||||
until = self.clock.time_msec() + timeout
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
|
||||
was_present = member in self._member_typing_until
|
||||
|
||||
if member in self._member_typing_timer:
|
||||
self.clock.cancel_call_later(self._member_typing_timer[member])
|
||||
|
||||
self._member_typing_until[member] = until
|
||||
self._member_typing_timer[member] = self.clock.call_later(
|
||||
timeout / 1000, lambda: self._stopped_typing(member)
|
||||
)
|
||||
|
||||
if was_present:
|
||||
# No point sending another notification
|
||||
@ -72,24 +100,45 @@ class TypingNotificationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def stopped_typing(self, target_user, auth_user, room_id):
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's typing state")
|
||||
|
||||
yield self.auth.check_joined_room(room_id, target_user.to_string())
|
||||
|
||||
logger.debug(
|
||||
"%s has stopped typing in %s", target_user.to_string(), room_id
|
||||
)
|
||||
|
||||
member = RoomMember(room_id=room_id, user=target_user)
|
||||
|
||||
yield self._stopped_typing(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_left_room(self, user, room_id):
|
||||
if self.hs.is_mine(user):
|
||||
member = RoomMember(room_id=room_id, user=user)
|
||||
yield self._stopped_typing(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _stopped_typing(self, member):
|
||||
if member not in self._member_typing_until:
|
||||
# No point
|
||||
defer.returnValue(None)
|
||||
|
||||
yield self._push_update(
|
||||
room_id=room_id,
|
||||
user=target_user,
|
||||
room_id=member.room_id,
|
||||
user=member.user,
|
||||
typing=False,
|
||||
)
|
||||
|
||||
del self._member_typing_until[member]
|
||||
|
||||
self.clock.cancel_call_later(self._member_typing_timer[member])
|
||||
del self._member_typing_timer[member]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_update(self, room_id, user, typing):
|
||||
localusers = set()
|
||||
@ -97,16 +146,14 @@ class TypingNotificationHandler(BaseHandler):
|
||||
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
yield rm_handler.fetch_room_distributions_into(
|
||||
room_id, localusers=localusers, remotedomains=remotedomains,
|
||||
ignore_user=user
|
||||
room_id, localusers=localusers, remotedomains=remotedomains
|
||||
)
|
||||
|
||||
for u in localusers:
|
||||
self.push_update_to_clients(
|
||||
if localusers:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
observer_user=u,
|
||||
observed_user=user,
|
||||
typing=typing,
|
||||
user=user,
|
||||
typing=typing
|
||||
)
|
||||
|
||||
deferreds = []
|
||||
@ -135,29 +182,67 @@ class TypingNotificationHandler(BaseHandler):
|
||||
room_id, localusers=localusers
|
||||
)
|
||||
|
||||
for u in localusers:
|
||||
self.push_update_to_clients(
|
||||
if localusers:
|
||||
self._push_update_local(
|
||||
room_id=room_id,
|
||||
observer_user=u,
|
||||
observed_user=user,
|
||||
user=user,
|
||||
typing=content["typing"]
|
||||
)
|
||||
|
||||
def push_update_to_clients(self, room_id, observer_user, observed_user,
|
||||
typing):
|
||||
# TODO(paul) steal this from presence.py
|
||||
pass
|
||||
def _push_update_local(self, room_id, user, typing):
|
||||
if room_id not in self._room_serials:
|
||||
self._room_serials[room_id] = 0
|
||||
self._room_typing[room_id] = set()
|
||||
|
||||
room_set = self._room_typing[room_id]
|
||||
if typing:
|
||||
room_set.add(user)
|
||||
elif user in room_set:
|
||||
room_set.remove(user)
|
||||
|
||||
self._latest_room_serial += 1
|
||||
self._room_serials[room_id] = self._latest_room_serial
|
||||
|
||||
self.notifier.on_new_user_event(rooms=[room_id])
|
||||
|
||||
|
||||
class TypingNotificationEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self._handler = None
|
||||
|
||||
def handler(self):
|
||||
# Avoid cyclic dependency in handler setup
|
||||
if not self._handler:
|
||||
self._handler = self.hs.get_handlers().typing_notification_handler
|
||||
return self._handler
|
||||
|
||||
def _make_event_for(self, room_id):
|
||||
typing = self.handler()._room_typing[room_id]
|
||||
return {
|
||||
"type": "m.typing",
|
||||
"room_id": room_id,
|
||||
"content": {
|
||||
"user_ids": [u.to_string() for u in typing],
|
||||
},
|
||||
}
|
||||
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
return ([], from_key)
|
||||
from_key = int(from_key)
|
||||
handler = self.handler()
|
||||
|
||||
events = []
|
||||
for room_id in handler._room_serials:
|
||||
if handler._room_serials[room_id] <= from_key:
|
||||
continue
|
||||
|
||||
# TODO: check if user is in room
|
||||
events.append(self._make_event_for(room_id))
|
||||
|
||||
return (events, handler._latest_room_serial)
|
||||
|
||||
def get_current_key(self):
|
||||
return 0
|
||||
return self.handler()._latest_room_serial
|
||||
|
||||
def get_pagination_rows(self, user, pagination_config, key):
|
||||
return ([], pagination_config.from_key)
|
||||
|
@ -14,10 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer, reactor, protocol
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.client import readBody, _AgentBase, _URI
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.async import sleep
|
||||
@ -25,7 +26,7 @@ from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import CodeMessageException, SynapseError
|
||||
from synapse.api.errors import CodeMessageException, SynapseError, Codes
|
||||
|
||||
from syutil.crypto.jsonsign import sign_json
|
||||
|
||||
@ -89,8 +90,8 @@ class MatrixFederationHttpClient(object):
|
||||
("", "", path_bytes, param_bytes, query_bytes, "",)
|
||||
)
|
||||
|
||||
logger.debug("Sending request to %s: %s %s",
|
||||
destination, method, url_bytes)
|
||||
logger.info("Sending request to %s: %s %s",
|
||||
destination, method, url_bytes)
|
||||
|
||||
logger.debug(
|
||||
"Types: %s",
|
||||
@ -101,6 +102,8 @@ class MatrixFederationHttpClient(object):
|
||||
]
|
||||
)
|
||||
|
||||
# XXX: Would be much nicer to retry only at the transaction-layer
|
||||
# (once we have reliable transactions in place)
|
||||
retries_left = 5
|
||||
|
||||
endpoint = self._getEndpoint(reactor, destination)
|
||||
@ -127,11 +130,20 @@ class MatrixFederationHttpClient(object):
|
||||
break
|
||||
except Exception as e:
|
||||
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
|
||||
logger.warn("DNS Lookup failed to %s with %s", destination,
|
||||
e)
|
||||
logger.warn(
|
||||
"DNS Lookup failed to %s with %s",
|
||||
destination,
|
||||
e
|
||||
)
|
||||
raise SynapseError(400, "Domain specified not found.")
|
||||
|
||||
logger.exception("Got error in _create_request")
|
||||
logger.warn(
|
||||
"Sending request failed to %s: %s %s : %s",
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
e
|
||||
)
|
||||
_print_ex(e)
|
||||
|
||||
if retries_left:
|
||||
@ -140,15 +152,21 @@ class MatrixFederationHttpClient(object):
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Received response %d %s for %s: %s %s",
|
||||
response.code,
|
||||
response.phrase,
|
||||
destination,
|
||||
method,
|
||||
url_bytes
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
pass
|
||||
else:
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
logger.error(
|
||||
"Got response %d %s", response.code, response.phrase
|
||||
)
|
||||
raise CodeMessageException(
|
||||
response.code, response.phrase
|
||||
)
|
||||
@ -227,7 +245,7 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||
""" Get's some json from the given host homeserver and path
|
||||
""" GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
@ -235,9 +253,6 @@ class MatrixFederationHttpClient(object):
|
||||
path (str): The HTTP path.
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
**Note**: The value of each key is assumed to be an iterable
|
||||
and *not* a string.
|
||||
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
@ -272,6 +287,52 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, destination, path, output_stream, args={},
|
||||
retry_on_dns_fail=True, max_size=None):
|
||||
"""GETs a file from a given homeserver
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request to.
|
||||
path (str): The HTTP path to GET.
|
||||
output_stream (file): File to write the response body to.
|
||||
args (dict): Optional dictionary used to create the query string.
|
||||
Returns:
|
||||
A (int,dict) tuple of the file length and a dict of the response
|
||||
headers.
|
||||
"""
|
||||
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
self.sign_request(destination, method, url_bytes, headers_dict)
|
||||
return None
|
||||
|
||||
response = yield self._create_request(
|
||||
destination.encode("ascii"),
|
||||
"GET",
|
||||
path.encode("ascii"),
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail
|
||||
)
|
||||
|
||||
headers = dict(response.headers.getAllRawHeaders())
|
||||
|
||||
try:
|
||||
length = yield _readBodyToFile(response, output_stream, max_size)
|
||||
except:
|
||||
logger.exception("Failed to download body")
|
||||
raise
|
||||
|
||||
defer.returnValue((length, headers))
|
||||
|
||||
def _getEndpoint(self, reactor, destination):
|
||||
return matrix_federation_endpoint(
|
||||
reactor, destination, timeout=10,
|
||||
@ -279,12 +340,44 @@ class MatrixFederationHttpClient(object):
|
||||
)
|
||||
|
||||
|
||||
class _ReadBodyToFileProtocol(protocol.Protocol):
|
||||
def __init__(self, stream, deferred, max_size):
|
||||
self.stream = stream
|
||||
self.deferred = deferred
|
||||
self.length = 0
|
||||
self.max_size = max_size
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.stream.write(data)
|
||||
self.length += len(data)
|
||||
if self.max_size is not None and self.length >= self.max_size:
|
||||
self.deferred.errback(SynapseError(
|
||||
502,
|
||||
"Requested file is too large > %r bytes" % (self.max_size,),
|
||||
Codes.TOO_LARGE,
|
||||
))
|
||||
self.deferred = defer.Deferred()
|
||||
self.transport.loseConnection()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if reason.check(ResponseDone):
|
||||
self.deferred.callback(self.length)
|
||||
else:
|
||||
self.deferred.errback(reason)
|
||||
|
||||
|
||||
def _readBodyToFile(response, stream, max_size):
|
||||
d = defer.Deferred()
|
||||
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
|
||||
return d
|
||||
|
||||
|
||||
def _print_ex(e):
|
||||
if hasattr(e, "reasons") and e.reasons:
|
||||
for ex in e.reasons:
|
||||
_print_ex(ex)
|
||||
else:
|
||||
logger.exception(e)
|
||||
logger.warn(e)
|
||||
|
||||
|
||||
class _JsonProducer(object):
|
||||
|
@ -29,6 +29,7 @@ from twisted.web.util import redirectTo
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -122,9 +123,14 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
|
||||
args = [
|
||||
urllib.unquote(u).decode("UTF-8") for u in m.groups()
|
||||
]
|
||||
|
||||
code, response = yield path_entry.callback(
|
||||
request,
|
||||
*m.groups()
|
||||
*args
|
||||
)
|
||||
|
||||
self._send_response(request, code, response)
|
||||
@ -166,14 +172,10 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
request)
|
||||
return
|
||||
|
||||
if not self._request_user_agent_is_curl(request):
|
||||
json_bytes = encode_canonical_json(response_json_object)
|
||||
else:
|
||||
json_bytes = encode_pretty_printed_json(response_json_object)
|
||||
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json_bytes(request, code, json_bytes, send_cors=True,
|
||||
response_code_message=response_code_message)
|
||||
respond_with_json(request, code, response_json_object, send_cors=True,
|
||||
response_code_message=response_code_message,
|
||||
pretty_print=self._request_user_agent_is_curl)
|
||||
|
||||
@staticmethod
|
||||
def _request_user_agent_is_curl(request):
|
||||
@ -202,6 +204,17 @@ class RootRedirect(resource.Resource):
|
||||
return resource.Resource.getChild(self, name, request)
|
||||
|
||||
|
||||
def respond_with_json(request, code, json_object, send_cors=False,
|
||||
response_code_message=None, pretty_print=False):
|
||||
if not pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object)
|
||||
else:
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
|
||||
return respond_with_json_bytes(request, code, json_bytes, send_cors,
|
||||
response_code_message=response_code_message)
|
||||
|
||||
|
||||
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
|
||||
response_code_message=None):
|
||||
"""Sends encoded JSON in response to the given request.
|
||||
|
0
synapse/media/__init__.py
Normal file
0
synapse/media/__init__.py
Normal file
0
synapse/media/v0/__init__.py
Normal file
0
synapse/media/v0/__init__.py
Normal file
@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .server import respond_with_json_bytes
|
||||
from synapse.http.server import respond_with_json_bytes
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
0
synapse/media/v1/__init__.py
Normal file
0
synapse/media/v1/__init__.py
Normal file
369
synapse/media/v1/base_resource.py
Normal file
369
synapse/media/v1/base_resource.py
Normal file
@ -0,0 +1,369 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, CodeMessageException, cs_error, Codes, SynapseError
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseMediaResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, filepaths):
|
||||
Resource.__init__(self)
|
||||
self.auth = hs.get_auth()
|
||||
self.client = hs.get_http_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = filepaths
|
||||
self.downloads = {}
|
||||
|
||||
@staticmethod
|
||||
def catch_errors(request_handler):
|
||||
@defer.inlineCallbacks
|
||||
def wrapped_request_handler(self, request):
|
||||
try:
|
||||
yield request_handler(self, request)
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json(
|
||||
request, e.code, cs_exception(e), send_cors=True
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r",
|
||||
request_handler.__module__,
|
||||
request_handler.__name__,
|
||||
self,
|
||||
)
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
return wrapped_request_handler
|
||||
|
||||
@staticmethod
|
||||
def _parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Invalid media id token %r" % (request.postpath,),
|
||||
Codes.UNKKOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_integer(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return int(request.args[arg_name][0])
|
||||
else:
|
||||
return int(request.args.get(arg_name, [default])[0])
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing integer argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_string(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return request.args[arg_name][0]
|
||||
else:
|
||||
return request.args.get(arg_name, [default])[0]
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing string argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
def _respond_404(self, request):
|
||||
respond_with_json(
|
||||
request, 404,
|
||||
cs_error(
|
||||
"Not found %r" % (request.postpath,),
|
||||
code=Codes.NOT_FOUND,
|
||||
),
|
||||
send_cors=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
def _get_remote_media(self, server_name, media_id):
|
||||
key = (server_name, media_id)
|
||||
download = self.downloads.get(key)
|
||||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return download
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
media_info = yield self.store.get_cached_remote_media(
|
||||
server_name, media_id
|
||||
)
|
||||
if not media_info:
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_remote_file(self, server_name, media_id):
|
||||
file_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.remote_media_filepath(
|
||||
server_name, file_id
|
||||
)
|
||||
self._makedirs(fname)
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=None,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
except:
|
||||
os.remove(fname)
|
||||
raise
|
||||
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
"upload_name": None,
|
||||
"created_ts": time_now_ms,
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
yield self._generate_remote_thumbnails(
|
||||
server_name, media_id, media_info
|
||||
)
|
||||
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_with_file(self, request, media_type, file_path):
|
||||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(
|
||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
yield FileSender().beginFileTransfer(f, request)
|
||||
|
||||
request.finish()
|
||||
else:
|
||||
self._respond_404()
|
||||
|
||||
def _get_thumbnail_requirements(self, media_type):
|
||||
if media_type == "image/jpeg":
|
||||
return (
|
||||
(32, 32, "crop", "image/jpeg"),
|
||||
(96, 96, "crop", "image/jpeg"),
|
||||
(320, 240, "scale", "image/jpeg"),
|
||||
(640, 480, "scale", "image/jpeg"),
|
||||
)
|
||||
elif (media_type == "image/png") or (media_type == "image/gif"):
|
||||
return (
|
||||
(32, 32, "crop", "image/png"),
|
||||
(96, 96, "crop", "image/png"),
|
||||
(320, 240, "scale", "image/png"),
|
||||
(640, 480, "scale", "image/png"),
|
||||
)
|
||||
else:
|
||||
return ()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_local_thumbnails(self, media_id, media_info):
|
||||
media_type = media_info["media_type"]
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info["filesystem_id"]
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
"height": m_height,
|
||||
})
|
68
synapse/media/v1/download_resource.py
Normal file
68
synapse/media/v1/download_resource.py
Normal file
@ -0,0 +1,68 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadResource(BaseMediaResource):
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
except:
|
||||
self._respond_404(request)
|
||||
return
|
||||
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id)
|
||||
else:
|
||||
yield self._respond_remote_file(request, server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_file(self, request, media_id):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
if not media_info:
|
||||
self._respond_404()
|
||||
return
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
file_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
yield self._respond_with_file(request, media_type, file_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_file(self, request, server_name, media_id):
|
||||
media_info = yield self._get_remote_media(server_name, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
filesystem_id = media_info["filesystem_id"]
|
||||
|
||||
file_path = self.filepaths.remote_media_filepath(
|
||||
server_name, filesystem_id
|
||||
)
|
||||
|
||||
yield self._respond_with_file(request, media_type, file_path)
|
67
synapse/media/v1/filepath.py
Normal file
67
synapse/media/v1/filepath.py
Normal file
@ -0,0 +1,67 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class MediaFilePaths(object):
|
||||
|
||||
def __init__(self, base_path):
|
||||
self.base_path = base_path
|
||||
|
||||
def default_thumbnail(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "default_thumbnails", default_top_level,
|
||||
default_sub_type, file_name
|
||||
)
|
||||
|
||||
def local_media_filepath(self, media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "local_content",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
|
||||
def local_media_thumbnail(self, media_id, width, height, content_type,
|
||||
method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "local_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_filepath(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_content", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:]
|
||||
)
|
||||
|
||||
def remote_media_thumbnail(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
77
synapse/media/v1/media_repository.py
Normal file
77
synapse/media/v1/media_repository.py
Normal file
@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .upload_resource import UploadResource
|
||||
from .download_resource import DownloadResource
|
||||
from .thumbnail_resource import ThumbnailResource
|
||||
from .filepath import MediaFilePaths
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MediaRepositoryResource(Resource):
|
||||
"""File uploading and downloading.
|
||||
|
||||
Uploads are POSTed to a resource which returns a token which is used to GET
|
||||
the download::
|
||||
|
||||
=> POST /_matrix/media/v1/upload HTTP/1.1
|
||||
Content-Type: <media-type>
|
||||
|
||||
<media>
|
||||
|
||||
<= HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
|
||||
{ "content_uri": "mxc://<server-name>/<media-id>" }
|
||||
|
||||
=> GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1
|
||||
|
||||
<= HTTP/1.1 200 OK
|
||||
Content-Type: <media-type>
|
||||
Content-Disposition: attachment;filename=<upload-filename>
|
||||
|
||||
<media>
|
||||
|
||||
Clients can get thumbnails by supplying a desired width and height and
|
||||
thumbnailing method::
|
||||
|
||||
=> GET /_matrix/media/v1/thumbnail/<server_name>
|
||||
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
|
||||
|
||||
<= HTTP/1.1 200 OK
|
||||
Content-Type: image/jpeg or image/png
|
||||
|
||||
<thumbnail>
|
||||
|
||||
The thumbnail methods are "crop" and "scale". "scale" trys to return an
|
||||
image where either the width or the height is smaller than the requested
|
||||
size. The client should then scale and letterbox the image if it needs to
|
||||
fit within a given rectangle. "crop" trys to return an image where the
|
||||
width and height are close to the requested size and the aspect matches
|
||||
the requested size. The client should scale the image if it needs to fit
|
||||
within a given rectangle.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self)
|
||||
filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
self.putChild("upload", UploadResource(hs, filepaths))
|
||||
self.putChild("download", DownloadResource(hs, filepaths))
|
||||
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
|
191
synapse/media/v1/thumbnail_resource.py
Normal file
191
synapse/media/v1/thumbnail_resource.py
Normal file
@ -0,0 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThumbnailResource(BaseMediaResource):
|
||||
isLeaf = True
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id = self._parse_media_id(request)
|
||||
width = self._parse_integer(request, "width")
|
||||
height = self._parse_integer(request, "height")
|
||||
method = self._parse_string(request, "method", "scale")
|
||||
m_type = self._parse_string(request, "type", "image/png")
|
||||
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type
|
||||
)
|
||||
else:
|
||||
yield self._respond_remote_thumbnail(
|
||||
request, server_name, media_id,
|
||||
width, height, method, m_type
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||
method, m_type):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
self._respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
|
||||
if thumbnail_infos:
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
|
||||
file_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield self._respond_with_file(request, t_type, file_path)
|
||||
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||
height, method, m_type):
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead.
|
||||
media_info = yield self._get_remote_media(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
)
|
||||
|
||||
if thumbnail_infos:
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
file_id = thumbnail_info["filesystem_id"]
|
||||
|
||||
file_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield self._respond_with_file(request, t_type, file_path)
|
||||
else:
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_default_thumbnail(self, request, media_info, width, height,
|
||||
method, m_type):
|
||||
media_type = media_info["media_type"]
|
||||
top_level_type = media_type.split("/")[0]
|
||||
sub_type = media_type.split("/")[-1].split(";")[0]
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, sub_type,
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
"_default", "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
self._respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, "crop", m_type, thumbnail_infos
|
||||
)
|
||||
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
|
||||
file_path = self.filepaths.default_thumbnail(
|
||||
top_level_type, sub_type, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield self.respond_with_file(request, t_type, file_path)
|
||||
|
||||
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
||||
desired_type, thumbnail_infos):
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
if desired_method.lower() == "crop":
|
||||
info_list = []
|
||||
for info in thumbnail_infos:
|
||||
t_w = info["thumbnail_width"]
|
||||
t_h = info["thumbnail_height"]
|
||||
t_method = info["thumbnail_method"]
|
||||
if t_method == "scale" or t_method == "crop":
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
info_list.append((
|
||||
aspect_quality, size_quality, type_quality,
|
||||
length_quality, info
|
||||
))
|
||||
if info_list:
|
||||
return min(info_list)[-1]
|
||||
else:
|
||||
info_list = []
|
||||
info_list2 = []
|
||||
for info in thumbnail_infos:
|
||||
t_w = info["thumbnail_width"]
|
||||
t_h = info["thumbnail_height"]
|
||||
t_method = info["thumbnail_method"]
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
|
||||
info_list.append((
|
||||
size_quality, type_quality, length_quality, info
|
||||
))
|
||||
elif t_method == "scale":
|
||||
info_list2.append((
|
||||
size_quality, type_quality, length_quality, info
|
||||
))
|
||||
if info_list:
|
||||
return min(info_list)[-1]
|
||||
else:
|
||||
return min(info_list2)[-1]
|
89
synapse/media/v1/thumbnailer.py
Normal file
89
synapse/media/v1/thumbnailer.py
Normal file
@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import PIL.Image as Image
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class Thumbnailer(object):
|
||||
|
||||
FORMATS = {
|
||||
"image/jpeg": "JPEG",
|
||||
"image/png": "PNG",
|
||||
}
|
||||
|
||||
def __init__(self, input_path):
|
||||
self.image = Image.open(input_path)
|
||||
self.width, self.height = self.image.size
|
||||
|
||||
def aspect(self, max_width, max_height):
|
||||
"""Calculate the largest size that preserves aspect ratio which
|
||||
fits within the given rectangle::
|
||||
|
||||
(w_in / h_in) = (w_out / h_out)
|
||||
w_out = min(w_max, h_max * (w_in / h_in))
|
||||
h_out = min(h_max, w_max * (h_in / w_in))
|
||||
|
||||
Args:
|
||||
max_width: The largest possible width.
|
||||
max_height: The larget possible height.
|
||||
"""
|
||||
|
||||
if max_width * self.height < max_height * self.width:
|
||||
return (max_width, (max_width * self.height) // self.width)
|
||||
else:
|
||||
return ((max_height * self.width) // self.height, max_height)
|
||||
|
||||
def scale(self, output_path, width, height, output_type):
|
||||
"""Rescales the image to the given dimensions"""
|
||||
scaled = self.image.resize((width, height), Image.BILINEAR)
|
||||
return self.save_image(scaled, output_type, output_path)
|
||||
|
||||
def crop(self, output_path, width, height, output_type):
|
||||
"""Rescales and crops the image to the given dimensions preserving
|
||||
aspect::
|
||||
(w_in / h_in) = (w_scaled / h_scaled)
|
||||
w_scaled = max(w_out, h_out * (w_in / h_in))
|
||||
h_scaled = max(h_out, w_out * (h_in / w_in))
|
||||
|
||||
Args:
|
||||
max_width: The largest possible width.
|
||||
max_height: The larget possible height.
|
||||
"""
|
||||
if width * self.height > height * self.width:
|
||||
scaled_height = (width * self.height) // self.width
|
||||
scaled_image = self.image.resize(
|
||||
(width, scaled_height), Image.BILINEAR
|
||||
)
|
||||
crop_top = (scaled_height - height) // 2
|
||||
crop_bottom = height + crop_top
|
||||
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
|
||||
else:
|
||||
scaled_width = (height * self.width) // self.height
|
||||
scaled_image = self.image.resize(
|
||||
(scaled_width, height), Image.BILINEAR
|
||||
)
|
||||
crop_left = (scaled_width - width) // 2
|
||||
crop_right = width + crop_left
|
||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
return self.save_image(cropped, output_type, output_path)
|
||||
|
||||
def save_image(self, output_image, output_type, output_path):
|
||||
output_bytes_io = BytesIO()
|
||||
output_image.save(output_bytes_io, self.FORMATS[output_type])
|
||||
output_bytes = output_bytes_io.getvalue()
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_file.write(output_bytes)
|
||||
return len(output_bytes)
|
113
synapse/media/v1/upload_resource.py
Normal file
113
synapse/media/v1/upload_resource.py
Normal file
@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException
|
||||
)
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadResource(BaseMediaResource):
|
||||
def render_POST(self, request):
|
||||
self._async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
try:
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader("Content-Type"):
|
||||
media_type = headers.getRawHeaders("Content-Type")[0]
|
||||
else:
|
||||
raise SynapseError(
|
||||
msg="Upload request missing 'Content-Type'",
|
||||
code=400,
|
||||
)
|
||||
|
||||
#if headers.hasHeader("Content-Disposition"):
|
||||
# disposition = headers.getRawHeaders("Content-Disposition")[0]
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
media_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.local_media_filepath(media_id)
|
||||
self._makedirs(fname)
|
||||
|
||||
# This shouldn't block for very long because the content will have
|
||||
# already been uploaded at this point.
|
||||
with open(fname, "wb") as f:
|
||||
f.write(request.content.read())
|
||||
|
||||
yield self.store.store_local_media(
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=None,
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
)
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": content_length,
|
||||
}
|
||||
|
||||
yield self._generate_local_thumbnails(media_id, media_info)
|
||||
|
||||
content_uri = "mxc://%s/%s" % (self.server_name, media_id)
|
||||
|
||||
respond_with_json(
|
||||
request, 200, {"content_uri": content_uri}, send_cors=True
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json(request, e.code, cs_exception(e), send_cors=True)
|
||||
except:
|
||||
logger.exception("Failed to store file")
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
@ -146,7 +146,11 @@ class Notifier(object):
|
||||
Will wake up all listeners for the given users and rooms.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
# TODO(paul): This is horrible, having to manually list every event
|
||||
# source here individually
|
||||
presence_source = self.event_sources.sources["presence"]
|
||||
typing_source = self.event_sources.sources["typing"]
|
||||
|
||||
listeners = set()
|
||||
|
||||
@ -158,19 +162,33 @@ class Notifier(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify(listener):
|
||||
events, end_key = yield presence_source.get_new_events_for_user(
|
||||
listener.user,
|
||||
listener.from_token.presence_key,
|
||||
listener.limit,
|
||||
presence_events, presence_end_key = (
|
||||
yield presence_source.get_new_events_for_user(
|
||||
listener.user,
|
||||
listener.from_token.presence_key,
|
||||
listener.limit,
|
||||
)
|
||||
)
|
||||
typing_events, typing_end_key = (
|
||||
yield typing_source.get_new_events_for_user(
|
||||
listener.user,
|
||||
listener.from_token.typing_key,
|
||||
listener.limit,
|
||||
)
|
||||
)
|
||||
|
||||
if events:
|
||||
if presence_events or typing_events:
|
||||
end_token = listener.from_token.copy_and_replace(
|
||||
"presence_key", end_key
|
||||
"presence_key", presence_end_key
|
||||
).copy_and_replace(
|
||||
"typing_key", typing_end_key
|
||||
)
|
||||
|
||||
listener.notify(
|
||||
self, events, listener.from_token, end_token
|
||||
self,
|
||||
presence_events + typing_events,
|
||||
listener.from_token,
|
||||
end_token
|
||||
)
|
||||
|
||||
def eb(failure):
|
||||
|
@ -28,7 +28,7 @@ class RestServletFactory(object):
|
||||
speaking, they serve as wrappers around events and the handlers that
|
||||
process them.
|
||||
|
||||
See synapse.api.events for information on synapse events.
|
||||
See synapse.events for information on synapse events.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -35,7 +35,7 @@ class WhoisRestServlet(RestServlet):
|
||||
if not is_admin and target_user != auth_user:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
if not target_user.is_mine:
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "Can only whois a local user")
|
||||
|
||||
ret = yield self.handlers.admin_handler.get_whois(target_user)
|
||||
|
@ -63,12 +63,10 @@ class RestServlet(object):
|
||||
self.hs = hs
|
||||
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
|
||||
self.validator = hs.get_event_validator()
|
||||
|
||||
def register(self, http_server):
|
||||
""" Register this servlet with the given HTTP server. """
|
||||
if hasattr(self, "PATTERN"):
|
||||
|
@ -21,7 +21,6 @@ from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -36,9 +35,7 @@ class ClientDirectoryServer(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_alias):
|
||||
room_alias = self.hs.parse_roomalias(
|
||||
urllib.unquote(room_alias).decode("utf-8")
|
||||
)
|
||||
room_alias = self.hs.parse_roomalias(room_alias)
|
||||
|
||||
dir_handler = self.handlers.directory_handler
|
||||
res = yield dir_handler.get_association(room_alias)
|
||||
@ -56,9 +53,7 @@ class ClientDirectoryServer(RestServlet):
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
room_alias = self.hs.parse_roomalias(
|
||||
urllib.unquote(room_alias).decode("utf-8")
|
||||
)
|
||||
room_alias = self.hs.parse_roomalias(room_alias)
|
||||
|
||||
logger.debug("Got room name: %s", room_alias.to_string())
|
||||
|
||||
@ -97,9 +92,7 @@ class ClientDirectoryServer(RestServlet):
|
||||
|
||||
dir_handler = self.handlers.directory_handler
|
||||
|
||||
room_alias = self.hs.parse_roomalias(
|
||||
urllib.unquote(room_alias).decode("utf-8")
|
||||
)
|
||||
room_alias = self.hs.parse_roomalias(room_alias)
|
||||
|
||||
yield dir_handler.delete_association(
|
||||
user.to_string(), room_alias
|
||||
|
@ -47,8 +47,8 @@ class LoginRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def do_password_login(self, login_submission):
|
||||
if not login_submission["user"].startswith('@'):
|
||||
login_submission["user"] = UserID.create_local(
|
||||
login_submission["user"], self.hs).to_string()
|
||||
login_submission["user"] = UserID.create(
|
||||
login_submission["user"], self.hs.hostname).to_string()
|
||||
|
||||
handler = self.handlers.login_handler
|
||||
token = yield handler.login(
|
||||
|
@ -22,7 +22,6 @@ from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -33,7 +32,6 @@ class PresenceStatusRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
state = yield self.handlers.presence_handler.get_state(
|
||||
@ -44,7 +42,6 @@ class PresenceStatusRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
state = {}
|
||||
@ -80,10 +77,9 @@ class PresenceListRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if auth_user != user:
|
||||
@ -101,10 +97,9 @@ class PresenceListRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
if not user.is_mine:
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if auth_user != user:
|
||||
|
@ -19,7 +19,6 @@ from twisted.internet import defer
|
||||
from base import RestServlet, client_path_pattern
|
||||
|
||||
import json
|
||||
import urllib
|
||||
|
||||
|
||||
class ProfileDisplaynameRestServlet(RestServlet):
|
||||
@ -27,7 +26,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
@ -39,7 +37,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
try:
|
||||
@ -62,7 +59,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
@ -74,7 +70,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
try:
|
||||
@ -97,7 +92,6 @@ class ProfileRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user_id = urllib.unquote(user_id)
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
|
@ -21,6 +21,8 @@ from synapse.api.constants import LoginType
|
||||
from base import RestServlet, client_path_pattern
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
from hashlib import sha1
|
||||
import hmac
|
||||
import json
|
||||
@ -233,7 +235,7 @@ class RegisterRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_password(self, request, register_json, session):
|
||||
yield
|
||||
yield run_on_reactor()
|
||||
if (self.hs.config.enable_registration_captcha and
|
||||
not session[LoginType.RECAPTCHA]):
|
||||
# captcha should've been done by this stage!
|
||||
|
@ -19,8 +19,7 @@ from twisted.internet import defer
|
||||
from base import RestServlet, client_path_pattern
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.api.events.room import RoomMemberEvent, RoomRedactionEvent
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
import json
|
||||
import logging
|
||||
@ -129,9 +128,9 @@ class RoomStateEventRestServlet(RestServlet):
|
||||
msg_handler = self.handlers.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
user_id=user.to_string(),
|
||||
room_id=urllib.unquote(room_id),
|
||||
event_type=urllib.unquote(event_type),
|
||||
state_key=urllib.unquote(state_key),
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
if not data:
|
||||
@ -143,32 +142,23 @@ class RoomStateEventRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_type, state_key):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
event_type = urllib.unquote(event_type)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=event_type, # already urldecoded
|
||||
content=content,
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
state_key=urllib.unquote(state_key)
|
||||
)
|
||||
event_dict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
}
|
||||
|
||||
self.validator.validate(event)
|
||||
if state_key is not None:
|
||||
event_dict["state_key"] = state_key
|
||||
|
||||
if event_type == RoomMemberEvent.TYPE:
|
||||
# membership events are special
|
||||
handler = self.handlers.room_member_handler
|
||||
yield handler.change_membership(event)
|
||||
defer.returnValue((200, {}))
|
||||
else:
|
||||
# store random bits of state
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.store_room_data(
|
||||
event=event
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_event(event_dict)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
# TODO: Needs unit testing for generic events + feedback
|
||||
@ -184,17 +174,15 @@ class RoomSendEventRestServlet(RestServlet):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=urllib.unquote(event_type),
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
content=content
|
||||
)
|
||||
|
||||
self.validator.validate(event)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.send_message(event)
|
||||
event = yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"event_id": event.event_id}))
|
||||
|
||||
@ -235,14 +223,10 @@ class JoinRoomAliasServlet(RestServlet):
|
||||
identifier = None
|
||||
is_room_alias = False
|
||||
try:
|
||||
identifier = self.hs.parse_roomalias(
|
||||
urllib.unquote(room_identifier)
|
||||
)
|
||||
identifier = self.hs.parse_roomalias(room_identifier)
|
||||
is_room_alias = True
|
||||
except SynapseError:
|
||||
identifier = self.hs.parse_roomid(
|
||||
urllib.unquote(room_identifier)
|
||||
)
|
||||
identifier = self.hs.parse_roomid(room_identifier)
|
||||
|
||||
# TODO: Support for specifying the home server to join with?
|
||||
|
||||
@ -251,18 +235,17 @@ class JoinRoomAliasServlet(RestServlet):
|
||||
ret_dict = yield handler.join_room_alias(user, identifier)
|
||||
defer.returnValue((200, ret_dict))
|
||||
else: # room id
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
content={"membership": Membership.JOIN},
|
||||
room_id=urllib.unquote(identifier.to_string()),
|
||||
user_id=user.to_string(),
|
||||
state_key=user.to_string()
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
"room_id": identifier.to_string(),
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
self.validator.validate(event)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
yield handler.change_membership(event)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -301,7 +284,7 @@ class RoomMemberListRestServlet(RestServlet):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.room_member_handler
|
||||
members = yield handler.get_room_members_as_pagination_chunk(
|
||||
room_id=urllib.unquote(room_id),
|
||||
room_id=room_id,
|
||||
user_id=user.to_string())
|
||||
|
||||
for event in members["chunk"]:
|
||||
@ -327,13 +310,13 @@ class RoomMessageListRestServlet(RestServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
pagination_config = PaginationConfig.from_request(request,
|
||||
default_limit=10,
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
with_feedback = "feedback" in request.args
|
||||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_messages(
|
||||
room_id=urllib.unquote(room_id),
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
feedback=with_feedback)
|
||||
@ -351,7 +334,7 @@ class RoomStateRestServlet(RestServlet):
|
||||
handler = self.handlers.message_handler
|
||||
# Get all the current state for this room
|
||||
events = yield handler.get_state_events(
|
||||
room_id=urllib.unquote(room_id),
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
)
|
||||
defer.returnValue((200, events))
|
||||
@ -366,7 +349,7 @@ class RoomInitialSyncRestServlet(RestServlet):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.handlers.message_handler.room_initial_sync(
|
||||
room_id=urllib.unquote(room_id),
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
pagin_config=pagination_config,
|
||||
)
|
||||
@ -378,8 +361,10 @@ class RoomTriggerBackfill(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
remote_server = urllib.unquote(request.args["remote"][0])
|
||||
room_id = urllib.unquote(room_id)
|
||||
remote_server = urllib.unquote(
|
||||
request.args["remote"][0]
|
||||
).decode("UTF-8")
|
||||
|
||||
limit = int(request.args["limit"][0])
|
||||
|
||||
handler = self.handlers.federation_handler
|
||||
@ -414,18 +399,17 @@ class RoomMembershipRestServlet(RestServlet):
|
||||
if membership_action == "kick":
|
||||
membership_action = "leave"
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
content={"membership": unicode(membership_action)},
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
state_key=state_key
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": {"membership": unicode(membership_action)},
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"state_key": state_key,
|
||||
}
|
||||
)
|
||||
|
||||
self.validator.validate(event)
|
||||
|
||||
handler = self.handlers.room_member_handler
|
||||
yield handler.change_membership(event)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -453,18 +437,16 @@ class RoomRedactEventRestServlet(RestServlet):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
content = _parse_json(request)
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomRedactionEvent.TYPE,
|
||||
room_id=urllib.unquote(room_id),
|
||||
user_id=user.to_string(),
|
||||
content=content,
|
||||
redacts=urllib.unquote(event_id),
|
||||
)
|
||||
|
||||
self.validator.validate(event)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.send_message(event)
|
||||
event = yield msg_handler.create_and_send_event(
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": user.to_string(),
|
||||
"redacts": event_id,
|
||||
}
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"event_id": event.event_id}))
|
||||
|
||||
@ -483,6 +465,39 @@ class RoomRedactEventRestServlet(RestServlet):
|
||||
defer.returnValue(response)
|
||||
|
||||
|
||||
class RoomTypingRestServlet(RestServlet):
|
||||
PATTERN = client_path_pattern(
|
||||
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
auth_user = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_id = urllib.unquote(room_id)
|
||||
target_user = self.hs.parse_userid(urllib.unquote(user_id))
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
typing_handler = self.handlers.typing_notification_handler
|
||||
|
||||
if content["typing"]:
|
||||
yield typing_handler.started_typing(
|
||||
target_user=target_user,
|
||||
auth_user=auth_user,
|
||||
room_id=room_id,
|
||||
timeout=content.get("timeout", 30000),
|
||||
)
|
||||
else:
|
||||
yield typing_handler.stopped_typing(
|
||||
target_user=target_user,
|
||||
auth_user=auth_user,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
@ -538,3 +553,4 @@ def register_servlets(hs, http_server):
|
||||
RoomStateRestServlet(hs).register(http_server)
|
||||
RoomInitialSyncRestServlet(hs).register(http_server)
|
||||
RoomRedactEventRestServlet(hs).register(http_server)
|
||||
RoomTypingRestServlet(hs).register(http_server)
|
||||
|
@ -20,6 +20,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
|
||||
class HttpTransactionStore(object):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -20,9 +20,7 @@
|
||||
|
||||
# Imports required for the default HomeServer() implementation
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.api.events import serialize_event
|
||||
from synapse.api.events.factory import EventFactory
|
||||
from synapse.api.events.validator import EventValidator
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.notifier import Notifier
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.handlers import Handlers
|
||||
@ -36,6 +34,7 @@ from synapse.util.lockutils import LockManager
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
|
||||
|
||||
class BaseHomeServer(object):
|
||||
@ -65,7 +64,6 @@ class BaseHomeServer(object):
|
||||
'persistence_service',
|
||||
'replication_layer',
|
||||
'datastore',
|
||||
'event_factory',
|
||||
'handlers',
|
||||
'auth',
|
||||
'rest_servlet_factory',
|
||||
@ -78,10 +76,11 @@ class BaseHomeServer(object):
|
||||
'resource_for_web_client',
|
||||
'resource_for_content_repo',
|
||||
'resource_for_server_key',
|
||||
'resource_for_media_repository',
|
||||
'event_sources',
|
||||
'ratelimiter',
|
||||
'keyring',
|
||||
'event_validator',
|
||||
'event_builder_factory',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
@ -133,22 +132,22 @@ class BaseHomeServer(object):
|
||||
def parse_userid(self, s):
|
||||
"""Parse the string given by 's' as a User ID and return a UserID
|
||||
object."""
|
||||
return UserID.from_string(s, hs=self)
|
||||
return UserID.from_string(s)
|
||||
|
||||
def parse_roomalias(self, s):
|
||||
"""Parse the string given by 's' as a Room Alias and return a RoomAlias
|
||||
object."""
|
||||
return RoomAlias.from_string(s, hs=self)
|
||||
return RoomAlias.from_string(s)
|
||||
|
||||
def parse_roomid(self, s):
|
||||
"""Parse the string given by 's' as a Room ID and return a RoomID
|
||||
object."""
|
||||
return RoomID.from_string(s, hs=self)
|
||||
return RoomID.from_string(s)
|
||||
|
||||
def parse_eventid(self, s):
|
||||
"""Parse the string given by 's' as a Event ID and return a EventID
|
||||
object."""
|
||||
return EventID.from_string(s, hs=self)
|
||||
return EventID.from_string(s)
|
||||
|
||||
def serialize_event(self, e):
|
||||
return serialize_event(self, e)
|
||||
@ -165,6 +164,9 @@ class BaseHomeServer(object):
|
||||
|
||||
return ip_addr
|
||||
|
||||
def is_mine(self, domain_specific_string):
|
||||
return domain_specific_string.domain == self.hostname
|
||||
|
||||
# Build magic accessors for every dependency
|
||||
for depname in BaseHomeServer.DEPENDENCIES:
|
||||
BaseHomeServer._make_dependency_method(depname)
|
||||
@ -192,9 +194,6 @@ class HomeServer(BaseHomeServer):
|
||||
def build_datastore(self):
|
||||
return DataStore(self)
|
||||
|
||||
def build_event_factory(self):
|
||||
return EventFactory(self)
|
||||
|
||||
def build_handlers(self):
|
||||
return Handlers(self)
|
||||
|
||||
@ -225,8 +224,11 @@ class HomeServer(BaseHomeServer):
|
||||
def build_keyring(self):
|
||||
return Keyring(self)
|
||||
|
||||
def build_event_validator(self):
|
||||
return EventValidator(self)
|
||||
def build_event_builder_factory(self):
|
||||
return EventBuilderFactory(
|
||||
clock=self.get_clock(),
|
||||
hostname=self.hostname,
|
||||
)
|
||||
|
||||
def register_servlets(self):
|
||||
""" Register all servlets associated with this HomeServer.
|
||||
|
175
synapse/state.py
175
synapse/state.py
@ -18,11 +18,11 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.events.room import RoomPowerLevelsEvent
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events.snapshot import EventContext
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import hashlib
|
||||
|
||||
@ -43,71 +43,6 @@ class StateHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def annotate_event_with_state(self, event, old_state=None):
|
||||
""" Annotates the event with the current state events as of that event.
|
||||
|
||||
This method adds three new attributes to the event:
|
||||
* `state_events`: The state up to and including the event. Encoded
|
||||
as a dict mapping tuple (type, state_key) -> event.
|
||||
* `old_state_events`: The state up to, but excluding, the event.
|
||||
Encoded similarly as `state_events`.
|
||||
* `state_group`: If there is an existing state group that can be
|
||||
used, then return that. Otherwise return `None`. See state
|
||||
storage for more information.
|
||||
|
||||
If the argument `old_state` is given (in the form of a list of
|
||||
events), then they are used as a the values for `old_state_events` and
|
||||
the value for `state_events` is generated from it. `state_group` is
|
||||
set to None.
|
||||
|
||||
This needs to be called before persisting the event.
|
||||
"""
|
||||
yield run_on_reactor()
|
||||
|
||||
if old_state:
|
||||
event.state_group = None
|
||||
event.old_state_events = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
}
|
||||
event.state_events = event.old_state_events
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
event.state_events[(event.type, event.state_key)] = event
|
||||
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
if hasattr(event, "outlier") and event.outlier:
|
||||
event.state_group = None
|
||||
event.old_state_events = None
|
||||
event.state_events = None
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
ids = [e for e, _ in event.prev_events]
|
||||
|
||||
ret = yield self.resolve_state_groups(ids)
|
||||
state_group, new_state = ret
|
||||
|
||||
event.old_state_events = copy.deepcopy(new_state)
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
key = (event.type, event.state_key)
|
||||
if key in new_state:
|
||||
event.replaces_state = new_state[key].event_id
|
||||
new_state[key] = event
|
||||
elif state_group:
|
||||
event.state_group = state_group
|
||||
event.state_events = new_state
|
||||
defer.returnValue(False)
|
||||
|
||||
event.state_group = None
|
||||
event.state_events = new_state
|
||||
|
||||
defer.returnValue(hasattr(event, "state_key"))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
""" Returns the current state for the room as a list. This is done by
|
||||
@ -135,9 +70,92 @@ class StateHandler(object):
|
||||
|
||||
defer.returnValue(res[1].values())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
""" Fills out the context with the `current state` of the graph. The
|
||||
`current state` here is defined to be the state of the event graph
|
||||
just before the event - i.e. it never includes `event`
|
||||
|
||||
If `event` has `auth_events` then this will also fill out the
|
||||
`auth_events` field on `context` from the `current_state`.
|
||||
|
||||
Args:
|
||||
event (EventBase)
|
||||
Returns:
|
||||
an EventContext
|
||||
"""
|
||||
context = EventContext()
|
||||
|
||||
yield run_on_reactor()
|
||||
|
||||
if old_state:
|
||||
context.current_state = {
|
||||
(s.type, s.state_key): s for s in old_state
|
||||
}
|
||||
context.state_group = None
|
||||
|
||||
if hasattr(event, "auth_events") and event.auth_events:
|
||||
auth_ids = zip(*event.auth_events)[0]
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
else:
|
||||
context.auth_events = {}
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
replaces = context.current_state[key]
|
||||
if replaces.event_id != event.event_id: # Paranoia check
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
if event.is_state():
|
||||
ret = yield self.resolve_state_groups(
|
||||
[e for e, _ in event.prev_events],
|
||||
event_type=event.type,
|
||||
state_key=event.state_key,
|
||||
)
|
||||
else:
|
||||
ret = yield self.resolve_state_groups(
|
||||
[e for e, _ in event.prev_events],
|
||||
)
|
||||
|
||||
group, curr_state, prev_state = ret
|
||||
|
||||
context.current_state = curr_state
|
||||
context.state_group = group if not event.is_state() else None
|
||||
|
||||
prev_state = yield self.store.add_event_hashes(
|
||||
prev_state
|
||||
)
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.current_state:
|
||||
replaces = context.current_state[key]
|
||||
event.unsigned["replaces_state"] = replaces.event_id
|
||||
|
||||
if hasattr(event, "auth_events") and event.auth_events:
|
||||
auth_ids = zip(*event.auth_events)[0]
|
||||
context.auth_events = {
|
||||
k: v
|
||||
for k, v in context.current_state.items()
|
||||
if v.event_id in auth_ids
|
||||
}
|
||||
else:
|
||||
context.auth_events = {}
|
||||
|
||||
context.prev_state_events = prev_state
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(self, event_ids):
|
||||
def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
@ -156,7 +174,14 @@ class StateHandler(object):
|
||||
(e.type, e.state_key): e
|
||||
for e in state_list
|
||||
}
|
||||
defer.returnValue((name, state))
|
||||
prev_state = state.get((event_type, state_key), None)
|
||||
if prev_state:
|
||||
prev_state = prev_state.event_id
|
||||
prev_states = [prev_state]
|
||||
else:
|
||||
prev_states = []
|
||||
|
||||
defer.returnValue((name, state, prev_states))
|
||||
|
||||
state = {}
|
||||
for group, g_state in state_groups.items():
|
||||
@ -177,6 +202,14 @@ class StateHandler(object):
|
||||
if len(v.values()) > 1
|
||||
}
|
||||
|
||||
if event_type:
|
||||
prev_states_events = conflicted_state.get(
|
||||
(event_type, state_key), []
|
||||
)
|
||||
prev_states = [s.event_id for s in prev_states_events]
|
||||
else:
|
||||
prev_states = []
|
||||
|
||||
try:
|
||||
new_state = {}
|
||||
new_state.update(unconflicted_state)
|
||||
@ -186,11 +219,11 @@ class StateHandler(object):
|
||||
logger.exception("Failed to resolve state")
|
||||
raise
|
||||
|
||||
defer.returnValue((None, new_state))
|
||||
defer.returnValue((None, new_state, prev_states))
|
||||
|
||||
def _get_power_level_from_event_state(self, event, user_id):
|
||||
if hasattr(event, "old_state_events") and event.old_state_events:
|
||||
key = (RoomPowerLevelsEvent.TYPE, "", )
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = event.old_state_events.get(key)
|
||||
level = None
|
||||
if power_level_event:
|
||||
|
@ -15,12 +15,8 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
|
||||
RoomRedactionEvent,
|
||||
)
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from .directory import DirectoryStore
|
||||
from .feedback import FeedbackStore
|
||||
@ -33,11 +29,13 @@ from .stream import StreamStore
|
||||
from .transactions import TransactionStore
|
||||
from .keys import KeyStore
|
||||
from .event_federation import EventFederationStore
|
||||
from .media_repository import MediaRepositoryStore
|
||||
|
||||
from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
@ -62,12 +60,13 @@ SCHEMAS = [
|
||||
"state",
|
||||
"event_edges",
|
||||
"event_signatures",
|
||||
"media_repository",
|
||||
]
|
||||
|
||||
|
||||
# Remember to update this number every time an incompatible change is made to
|
||||
# database schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 8
|
||||
SCHEMA_VERSION = 10
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
@ -81,11 +80,12 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
||||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
EventFederationStore, ):
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DataStore, self).__init__(hs)
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.hs = hs
|
||||
|
||||
self.min_token_deferred = self._get_min_token()
|
||||
@ -93,8 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, backfilled=False, is_new_state=True,
|
||||
current_state=None):
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
@ -107,6 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
@ -117,50 +118,64 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, allow_none=False):
|
||||
events_dict = yield self._simple_select_one(
|
||||
"events",
|
||||
{"event_id": event_id},
|
||||
[
|
||||
"event_id",
|
||||
"type",
|
||||
"room_id",
|
||||
"content",
|
||||
"unrecognized_keys",
|
||||
"depth",
|
||||
],
|
||||
allow_none=allow_none,
|
||||
)
|
||||
events = yield self._get_events([event_id])
|
||||
|
||||
if not events_dict:
|
||||
defer.returnValue(None)
|
||||
if not events:
|
||||
if allow_none:
|
||||
defer.returnValue(None)
|
||||
else:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
event = yield self._parse_events([events_dict])
|
||||
defer.returnValue(event[0])
|
||||
defer.returnValue(events[0])
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
|
||||
is_new_state=True, current_state=None):
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == FeedbackEvent.TYPE:
|
||||
elif event.type == EventTypes.Feedback:
|
||||
self._store_feedback_txn(txn, event)
|
||||
elif event.type == RoomNameEvent.TYPE:
|
||||
elif event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == RoomTopicEvent.TYPE:
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == RoomRedactionEvent.TYPE:
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
outlier = False
|
||||
if hasattr(event, "outlier"):
|
||||
outlier = event.outlier
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
event_dict = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json.decode("UTF-8"),
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
vals = {
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": json.dumps(event.content),
|
||||
"content": json.dumps(event.get_dict()["content"]),
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
@ -171,7 +186,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_full_dict().items()
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
@ -206,7 +221,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
room_id=event.room_id,
|
||||
)
|
||||
|
||||
self._store_state_groups_txn(txn, event)
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
if current_state:
|
||||
txn.execute(
|
||||
@ -300,16 +316,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
if hasattr(event, "signatures"):
|
||||
logger.debug("sigs: %s", event.signatures)
|
||||
for name, sigs in event.signatures.items():
|
||||
for key_id, signature_base64 in sigs.items():
|
||||
signature_bytes = decode_base64(signature_base64)
|
||||
self._store_event_signature_txn(
|
||||
txn, event.event_id, name, key_id,
|
||||
signature_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
@ -411,86 +417,6 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
],
|
||||
)
|
||||
|
||||
def snapshot_room(self, event):
|
||||
"""Snapshot the room for an update by a user
|
||||
Args:
|
||||
room_id (synapse.types.RoomId): The room to snapshot.
|
||||
user_id (synapse.types.UserId): The user to snapshot the room for.
|
||||
state_type (str): Optional state type to snapshot.
|
||||
state_key (str): Optional state key to snapshot.
|
||||
Returns:
|
||||
synapse.storage.Snapshot: A snapshot of the state of the room.
|
||||
"""
|
||||
def _snapshot(txn):
|
||||
prev_events = self._get_latest_events_in_room(
|
||||
txn,
|
||||
event.room_id
|
||||
)
|
||||
|
||||
prev_state = None
|
||||
state_key = None
|
||||
if hasattr(event, "state_key"):
|
||||
state_key = event.state_key
|
||||
prev_state = self._get_latest_state_in_room(
|
||||
txn,
|
||||
event.room_id,
|
||||
type=event.type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
return Snapshot(
|
||||
store=self,
|
||||
room_id=event.room_id,
|
||||
user_id=event.user_id,
|
||||
prev_events=prev_events,
|
||||
prev_state=prev_state,
|
||||
state_type=event.type,
|
||||
state_key=state_key,
|
||||
)
|
||||
|
||||
return self.runInteraction("snapshot_room", _snapshot)
|
||||
|
||||
|
||||
class Snapshot(object):
|
||||
"""Snapshot of the state of a room
|
||||
Args:
|
||||
store (DataStore): The datastore.
|
||||
room_id (RoomId): The room of the snapshot.
|
||||
user_id (UserId): The user this snapshot is for.
|
||||
prev_events (list): The list of event ids this snapshot is after.
|
||||
membership_state (RoomMemberEvent): The current state of the user in
|
||||
the room.
|
||||
state_type (str, optional): State type captured by the snapshot
|
||||
state_key (str, optional): State key captured by the snapshot
|
||||
prev_state_pdu (PduEntry, optional): pdu id of
|
||||
the previous value of the state type and key in the room.
|
||||
"""
|
||||
|
||||
def __init__(self, store, room_id, user_id, prev_events,
|
||||
prev_state, state_type=None, state_key=None):
|
||||
self.store = store
|
||||
self.room_id = room_id
|
||||
self.user_id = user_id
|
||||
self.prev_events = prev_events
|
||||
self.prev_state = prev_state
|
||||
self.state_type = state_type
|
||||
self.state_key = state_key
|
||||
|
||||
def fill_out_prev_events(self, event):
|
||||
if not hasattr(event, "prev_events"):
|
||||
event.prev_events = [
|
||||
(event_id, hashes)
|
||||
for event_id, hashes, _ in self.prev_events
|
||||
]
|
||||
|
||||
if self.prev_events:
|
||||
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
|
||||
else:
|
||||
event.depth = 0
|
||||
|
||||
if not hasattr(event, "prev_state") and self.prev_state is not None:
|
||||
event.prev_state = self.prev_state
|
||||
|
||||
|
||||
def schema_path(schema):
|
||||
""" Get a filesystem path for the named database schema
|
||||
@ -518,6 +444,14 @@ def read_schema(schema):
|
||||
return schema_file.read()
|
||||
|
||||
|
||||
class PrepareDatabaseException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UpgradeDatabaseException(PrepareDatabaseException):
|
||||
pass
|
||||
|
||||
|
||||
def prepare_database(db_conn):
|
||||
""" Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
|
||||
don't have to worry about overwriting existing content.
|
||||
@ -542,6 +476,10 @@ def prepare_database(db_conn):
|
||||
|
||||
# Run every version since after the current version.
|
||||
for v in range(user_version + 1, SCHEMA_VERSION + 1):
|
||||
if v == 10:
|
||||
raise UpgradeDatabaseException(
|
||||
"No delta for version 10"
|
||||
)
|
||||
sql_script = read_schema("delta/v%d" % (v))
|
||||
c.executescript(sql_script)
|
||||
|
||||
|
@ -15,15 +15,14 @@
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.events.utils import prune_event
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
|
||||
from syutil.base64util import encode_base64
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
@ -84,7 +83,6 @@ class SQLBaseStore(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self._db_pool = hs.get_db_pool()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self._clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -436,42 +434,69 @@ class SQLBaseStore(object):
|
||||
|
||||
return self.runInteraction("_simple_max_id", func)
|
||||
|
||||
def _parse_event_from_row(self, row_dict):
|
||||
d = copy.deepcopy({k: v for k, v in row_dict.items()})
|
||||
|
||||
d.pop("stream_ordering", None)
|
||||
d.pop("topological_ordering", None)
|
||||
d.pop("processed", None)
|
||||
d["origin_server_ts"] = d.pop("ts", 0)
|
||||
replaces_state = d.pop("prev_state", None)
|
||||
|
||||
if replaces_state:
|
||||
d["replaces_state"] = replaces_state
|
||||
|
||||
d.update(json.loads(row_dict["unrecognized_keys"]))
|
||||
d["content"] = json.loads(d["content"])
|
||||
del d["unrecognized_keys"]
|
||||
|
||||
if "age_ts" not in d:
|
||||
# For compatibility
|
||||
d["age_ts"] = d.get("origin_server_ts", 0)
|
||||
|
||||
return self.event_factory.create_event(
|
||||
etype=d["type"],
|
||||
**d
|
||||
def _get_events(self, event_ids):
|
||||
return self.runInteraction(
|
||||
"_get_events", self._get_events_txn, event_ids
|
||||
)
|
||||
|
||||
def _get_events_txn(self, txn, event_ids):
|
||||
# FIXME (erikj): This should be batched?
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
|
||||
|
||||
event_rows = []
|
||||
events = []
|
||||
for e_id in event_ids:
|
||||
c = txn.execute(sql, (e_id,))
|
||||
event_rows.extend(self.cursor_to_dict(c))
|
||||
ev = self._get_event_txn(txn, e_id)
|
||||
|
||||
return self._parse_events_txn(txn, event_rows)
|
||||
if ev:
|
||||
events.append(ev)
|
||||
|
||||
return events
|
||||
|
||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||
get_prev_content=True):
|
||||
sql = (
|
||||
"SELECT internal_metadata, json, r.event_id FROM event_json as e "
|
||||
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
|
||||
"WHERE e.event_id = ? "
|
||||
"LIMIT 1 "
|
||||
)
|
||||
|
||||
txn.execute(sql, (event_id,))
|
||||
|
||||
res = txn.fetchone()
|
||||
|
||||
if not res:
|
||||
return None
|
||||
|
||||
internal_metadata, js, redacted = res
|
||||
|
||||
d = json.loads(js)
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
|
||||
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
|
||||
|
||||
if check_redacted and redacted:
|
||||
ev = prune_event(ev)
|
||||
|
||||
ev.unsigned["redacted_by"] = redacted
|
||||
# Get the redaction event.
|
||||
|
||||
because = self._get_event_txn(
|
||||
txn,
|
||||
redacted,
|
||||
check_redacted=False
|
||||
)
|
||||
|
||||
if because:
|
||||
ev.unsigned["redacted_because"] = because
|
||||
|
||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||
prev = self._get_event_txn(
|
||||
txn,
|
||||
ev.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
)
|
||||
if prev:
|
||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
|
||||
return ev
|
||||
|
||||
def _parse_events(self, rows):
|
||||
return self.runInteraction(
|
||||
@ -479,80 +504,9 @@ class SQLBaseStore(object):
|
||||
)
|
||||
|
||||
def _parse_events_txn(self, txn, rows):
|
||||
events = [self._parse_event_from_row(r) for r in rows]
|
||||
event_ids = [r["event_id"] for r in rows]
|
||||
|
||||
select_event_sql = (
|
||||
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
|
||||
)
|
||||
|
||||
for i, ev in enumerate(events):
|
||||
signatures = self._get_event_signatures_txn(
|
||||
txn, ev.event_id,
|
||||
)
|
||||
|
||||
ev.signatures = {
|
||||
n: {
|
||||
k: encode_base64(v) for k, v in s.items()
|
||||
}
|
||||
for n, s in signatures.items()
|
||||
}
|
||||
|
||||
hashes = self._get_event_content_hashes_txn(
|
||||
txn, ev.event_id,
|
||||
)
|
||||
|
||||
ev.hashes = {
|
||||
k: encode_base64(v) for k, v in hashes.items()
|
||||
}
|
||||
|
||||
prevs = self._get_prev_events_and_state(txn, ev.event_id)
|
||||
|
||||
ev.prev_events = [
|
||||
(e_id, h)
|
||||
for e_id, h, is_state in prevs
|
||||
if is_state == 0
|
||||
]
|
||||
|
||||
ev.auth_events = self._get_auth_events(txn, ev.event_id)
|
||||
|
||||
if hasattr(ev, "state_key"):
|
||||
ev.prev_state = [
|
||||
(e_id, h)
|
||||
for e_id, h, is_state in prevs
|
||||
if is_state == 1
|
||||
]
|
||||
|
||||
if hasattr(ev, "replaces_state"):
|
||||
# Load previous state_content.
|
||||
# FIXME (erikj): Handle multiple prev_states.
|
||||
cursor = txn.execute(
|
||||
select_event_sql,
|
||||
(ev.replaces_state,)
|
||||
)
|
||||
prevs = self.cursor_to_dict(cursor)
|
||||
if prevs:
|
||||
prev = self._parse_event_from_row(prevs[0])
|
||||
ev.prev_content = prev.content
|
||||
|
||||
if not hasattr(ev, "redacted"):
|
||||
logger.debug("Doesn't have redacted key: %s", ev)
|
||||
ev.redacted = self._has_been_redacted_txn(txn, ev)
|
||||
|
||||
if ev.redacted:
|
||||
# Get the redaction event.
|
||||
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
txn.execute(select_event_sql, (ev.redacted,))
|
||||
|
||||
del_evs = self._parse_events_txn(
|
||||
txn, self.cursor_to_dict(txn)
|
||||
)
|
||||
|
||||
if del_evs:
|
||||
ev = prune_event(ev)
|
||||
events[i] = ev
|
||||
ev.redacted_because = del_evs[0]
|
||||
|
||||
return events
|
||||
return self._get_events_txn(txn, event_ids)
|
||||
|
||||
def _has_been_redacted_txn(self, txn, event):
|
||||
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
|
||||
@ -650,7 +604,7 @@ class JoinHelper(object):
|
||||
to dump the results into.
|
||||
|
||||
Attributes:
|
||||
taples (list): List of `Table` classes
|
||||
tables (list): List of `Table` classes
|
||||
EntryType (type)
|
||||
"""
|
||||
|
||||
|
@ -32,39 +32,33 @@ class EventFederationStore(SQLBaseStore):
|
||||
and backfilling from another server respectively.
|
||||
"""
|
||||
|
||||
def get_auth_chain(self, event_id):
|
||||
def get_auth_chain(self, event_ids):
|
||||
return self.runInteraction(
|
||||
"get_auth_chain",
|
||||
self._get_auth_chain_txn,
|
||||
event_id
|
||||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_txn(self, txn, event_id):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_id)
|
||||
def _get_auth_chain_txn(self, txn, event_ids):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
rows = []
|
||||
for ev_id in results:
|
||||
c = txn.execute(sql, (ev_id,))
|
||||
rows.extend(self.cursor_to_dict(c))
|
||||
return self._get_events_txn(txn, results)
|
||||
|
||||
return self._parse_events_txn(txn, rows)
|
||||
|
||||
def get_auth_chain_ids(self, event_id):
|
||||
def get_auth_chain_ids(self, event_ids):
|
||||
return self.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
event_id
|
||||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_id):
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids):
|
||||
results = set()
|
||||
|
||||
base_sql = (
|
||||
"SELECT auth_id FROM event_auth WHERE %s"
|
||||
)
|
||||
|
||||
front = set([event_id])
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
sql = base_sql % (
|
||||
" OR ".join(["event_id=?"] * len(front)),
|
||||
@ -177,14 +171,15 @@ class EventFederationStore(SQLBaseStore):
|
||||
retcols=["prev_event_id", "is_state"],
|
||||
)
|
||||
|
||||
hashes = self._get_prev_event_hashes_txn(txn, event_id)
|
||||
|
||||
results = []
|
||||
for d in res:
|
||||
hashes = self._get_event_reference_hashes_txn(
|
||||
txn,
|
||||
d["prev_event_id"]
|
||||
)
|
||||
edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
|
||||
edge_hash.update(hashes.get(d["prev_event_id"], {}))
|
||||
prev_hashes = {
|
||||
k: encode_base64(v) for k, v in hashes.items()
|
||||
k: encode_base64(v)
|
||||
for k, v in edge_hash.items()
|
||||
if k == "sha256"
|
||||
}
|
||||
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
|
||||
|
129
synapse/storage/media_repository.py
Normal file
129
synapse/storage/media_repository.py
Normal file
@ -0,0 +1,129 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore
|
||||
|
||||
|
||||
class MediaRepositoryStore(SQLBaseStore):
|
||||
"""Persistence for attachments and avatars"""
|
||||
|
||||
def get_default_thumbnails(self, top_level_type, sub_type):
|
||||
return []
|
||||
|
||||
def get_local_media(self, media_id):
|
||||
"""Get the metadata for a local piece of media
|
||||
Returns:
|
||||
None if the meia_id doesn't exist.
|
||||
"""
|
||||
return self._simple_select_one(
|
||||
"local_media_repository",
|
||||
{"media_id": media_id},
|
||||
("media_type", "media_length", "upload_name", "created_ts"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||
media_length, user_id):
|
||||
return self._simple_insert(
|
||||
"local_media_repository",
|
||||
{
|
||||
"media_id": media_id,
|
||||
"media_type": media_type,
|
||||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"user_id": user_id.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
def get_local_media_thumbnails(self, media_id):
|
||||
return self._simple_select_list(
|
||||
"local_media_repository_thumbnails",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length",
|
||||
)
|
||||
)
|
||||
|
||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||
thumbnail_height, thumbnail_type,
|
||||
thumbnail_method, thumbnail_length):
|
||||
return self._simple_insert(
|
||||
"local_media_repository_thumbnails",
|
||||
{
|
||||
"media_id": media_id,
|
||||
"thumbnail_width": thumbnail_width,
|
||||
"thumbnail_height": thumbnail_height,
|
||||
"thumbnail_method": thumbnail_method,
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
}
|
||||
)
|
||||
|
||||
def get_cached_remote_media(self, origin, media_id):
|
||||
return self._simple_select_one(
|
||||
"remote_media_cache",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
"media_type", "media_length", "upload_name", "created_ts",
|
||||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||
media_length, time_now_ms, upload_name,
|
||||
filesystem_id):
|
||||
return self._simple_insert(
|
||||
"remote_media_cache",
|
||||
{
|
||||
"media_origin": origin,
|
||||
"media_id": media_id,
|
||||
"media_type": media_type,
|
||||
"media_length": media_length,
|
||||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
)
|
||||
|
||||
def get_remote_media_thumbnails(self, origin, media_id):
|
||||
return self._simple_select_list(
|
||||
"remote_media_cache_thumbnails",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||
)
|
||||
)
|
||||
|
||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||
thumbnail_width, thumbnail_height,
|
||||
thumbnail_type, thumbnail_method,
|
||||
thumbnail_length):
|
||||
return self._simple_insert(
|
||||
"remote_media_cache_thumbnails",
|
||||
{
|
||||
"media_origin": origin,
|
||||
"media_id": media_id,
|
||||
"thumbnail_width": thumbnail_width,
|
||||
"thumbnail_height": thumbnail_height,
|
||||
"thumbnail_method": thumbnail_method,
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
)
|
79
synapse/storage/schema/delta/v9.sql
Normal file
79
synapse/storage/schema/delta/v9.sql
Normal file
@ -0,0 +1,79 @@
|
||||
/* Copyright 2014 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- To track destination health
|
||||
CREATE TABLE IF NOT EXISTS destinations(
|
||||
destination TEXT PRIMARY KEY,
|
||||
retry_last_ts INTEGER,
|
||||
retry_interval INTEGER
|
||||
);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS local_media_repository (
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
media_type TEXT, -- The MIME-type of the media.
|
||||
media_length INTEGER, -- Length of the media in bytes.
|
||||
created_ts INTEGER, -- When the content was uploaded in ms.
|
||||
upload_name TEXT, -- The name the media was uploaded with.
|
||||
user_id TEXT, -- The user who uploaded the file.
|
||||
CONSTRAINT uniqueness UNIQUE (media_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
|
||||
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
|
||||
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
|
||||
thumbnail_method TEXT, -- The method used to make the thumbnail.
|
||||
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
|
||||
CONSTRAINT uniqueness UNIQUE (
|
||||
media_id, thumbnail_width, thumbnail_height, thumbnail_type
|
||||
)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
|
||||
ON local_media_repository_thumbnails (media_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS remote_media_cache (
|
||||
media_origin TEXT, -- The remote HS the media came from.
|
||||
media_id TEXT, -- The id used to refer to the media on that server.
|
||||
media_type TEXT, -- The MIME-type of the media.
|
||||
created_ts INTEGER, -- When the content was uploaded in ms.
|
||||
upload_name TEXT, -- The name the media was uploaded with.
|
||||
media_length INTEGER, -- Length of the media in bytes.
|
||||
filesystem_id TEXT, -- The name used to store the media on disk.
|
||||
CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
|
||||
media_origin TEXT, -- The remote HS the media came from.
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
|
||||
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
|
||||
thumbnail_method TEXT, -- The method used to make the thumbnail
|
||||
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
|
||||
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
|
||||
filesystem_id TEXT, -- The name used to store the media on disk.
|
||||
CONSTRAINT uniqueness UNIQUE (
|
||||
media_origin, media_id, thumbnail_width, thumbnail_height,
|
||||
thumbnail_type, thumbnail_type
|
||||
)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
|
||||
ON local_media_repository_thumbnails (media_id);
|
||||
|
||||
|
||||
PRAGMA user_version = 9;
|
@ -32,6 +32,19 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
|
||||
CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
|
||||
CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS event_json(
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
internal_metadata NOT NULL,
|
||||
json BLOB NOT NULL,
|
||||
CONSTRAINT ev_j_uniq UNIQUE (event_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
|
||||
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
|
||||
|
||||
|
||||
CREATE TABLE IF NOT EXISTS state_events(
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
|
68
synapse/storage/schema/media_repository.sql
Normal file
68
synapse/storage/schema/media_repository.sql
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2014 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS local_media_repository (
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
media_type TEXT, -- The MIME-type of the media.
|
||||
media_length INTEGER, -- Length of the media in bytes.
|
||||
created_ts INTEGER, -- When the content was uploaded in ms.
|
||||
upload_name TEXT, -- The name the media was uploaded with.
|
||||
user_id TEXT, -- The user who uploaded the file.
|
||||
CONSTRAINT uniqueness UNIQUE (media_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS local_media_repository_thumbnails (
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
|
||||
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
|
||||
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
|
||||
thumbnail_method TEXT, -- The method used to make the thumbnail.
|
||||
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
|
||||
CONSTRAINT uniqueness UNIQUE (
|
||||
media_id, thumbnail_width, thumbnail_height, thumbnail_type
|
||||
)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS local_media_repository_thumbnails_media_id
|
||||
ON local_media_repository_thumbnails (media_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS remote_media_cache (
|
||||
media_origin TEXT, -- The remote HS the media came from.
|
||||
media_id TEXT, -- The id used to refer to the media on that server.
|
||||
media_type TEXT, -- The MIME-type of the media.
|
||||
created_ts INTEGER, -- When the content was uploaded in ms.
|
||||
upload_name TEXT, -- The name the media was uploaded with.
|
||||
media_length INTEGER, -- Length of the media in bytes.
|
||||
filesystem_id TEXT, -- The name used to store the media on disk.
|
||||
CONSTRAINT uniqueness UNIQUE (media_origin, media_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails (
|
||||
media_origin TEXT, -- The remote HS the media came from.
|
||||
media_id TEXT, -- The id used to refer to the media.
|
||||
thumbnail_width INTEGER, -- The width of the thumbnail in pixels.
|
||||
thumbnail_height INTEGER, -- The height of the thumbnail in pixels.
|
||||
thumbnail_method TEXT, -- The method used to make the thumbnail
|
||||
thumbnail_type TEXT, -- The MIME-type of the thumbnail.
|
||||
thumbnail_length INTEGER, -- The length of the thumbnail in bytes.
|
||||
filesystem_id TEXT, -- The name used to store the media on disk.
|
||||
CONSTRAINT uniqueness UNIQUE (
|
||||
media_origin, media_id, thumbnail_width, thumbnail_height,
|
||||
thumbnail_type, thumbnail_type
|
||||
)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS remote_media_cache_thumbnails_media_id
|
||||
ON local_media_repository_thumbnails (media_id);
|
@ -29,7 +29,8 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
|
||||
|
||||
CREATE TABLE IF NOT EXISTS event_to_state_groups(
|
||||
event_id TEXT NOT NULL,
|
||||
state_group INTEGER NOT NULL
|
||||
state_group INTEGER NOT NULL,
|
||||
CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);
|
||||
|
@ -59,3 +59,9 @@ CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(tra
|
||||
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
|
||||
CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
|
||||
|
||||
-- To track destination health
|
||||
CREATE TABLE IF NOT EXISTS destinations(
|
||||
destination TEXT PRIMARY KEY,
|
||||
retry_last_ts INTEGER,
|
||||
retry_interval INTEGER
|
||||
);
|
||||
|
@ -13,8 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from _base import SQLBaseStore
|
||||
|
||||
from syutil.base64util import encode_base64
|
||||
|
||||
|
||||
class SignatureStore(SQLBaseStore):
|
||||
"""Persistence for event signatures and hashes"""
|
||||
@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore):
|
||||
f
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_event_hashes(self, event_ids):
|
||||
hashes = yield self.get_event_reference_hashes(
|
||||
event_ids
|
||||
)
|
||||
hashes = [
|
||||
{
|
||||
k: encode_base64(v) for k, v in h.items()
|
||||
if k == "sha256"
|
||||
}
|
||||
for h in hashes
|
||||
]
|
||||
|
||||
defer.returnValue(zip(event_ids, hashes))
|
||||
|
||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
||||
"""Get all the hashes for a given PDU.
|
||||
Args:
|
||||
|
@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
|
||||
self._store_state_groups_txn, event
|
||||
)
|
||||
|
||||
def _store_state_groups_txn(self, txn, event):
|
||||
if event.state_events is None:
|
||||
def _store_state_groups_txn(self, txn, event, context):
|
||||
if context.current_state is None:
|
||||
return
|
||||
|
||||
state_group = event.state_group
|
||||
state_events = context.current_state
|
||||
|
||||
if event.is_state():
|
||||
state_events[(event.type, event.state_key)] = event
|
||||
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
state_group = self._simple_insert_txn(
|
||||
txn,
|
||||
@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
for state in event.state_events.values():
|
||||
for state in state_events.values():
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_groups_state",
|
||||
|
@ -17,6 +17,8 @@ from ._base import SQLBaseStore, Table
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -26,6 +28,10 @@ class TransactionStore(SQLBaseStore):
|
||||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
|
||||
# a write-through cache of DestinationsTable.EntryType indexed by
|
||||
# destination string
|
||||
destination_retry_cache = {}
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
"""For an incoming transaction from a given origin, check if we have
|
||||
already responded to it. If so, return the response code and response
|
||||
@ -114,7 +120,7 @@ class TransactionStore(SQLBaseStore):
|
||||
def _prep_send_transaction(self, txn, transaction_id, destination,
|
||||
origin_server_ts):
|
||||
|
||||
# First we find out what the prev_txs should be.
|
||||
# First we find out what the prev_txns should be.
|
||||
# Since we know that we are only sending one transaction at a time,
|
||||
# we can simply take the last one.
|
||||
query = "%s ORDER BY id DESC LIMIT 1" % (
|
||||
@ -205,6 +211,92 @@ class TransactionStore(SQLBaseStore):
|
||||
|
||||
return ReceivedTransactionsTable.decode_results(txn.fetchall())
|
||||
|
||||
def get_destination_retry_timings(self, destination):
|
||||
"""Gets the current retry timings (if any) for a given destination.
|
||||
|
||||
Args:
|
||||
destination (str)
|
||||
|
||||
Returns:
|
||||
None if not retrying
|
||||
Otherwise a DestinationsTable.EntryType for the retry scheme
|
||||
"""
|
||||
if destination in self.destination_retry_cache:
|
||||
return defer.succeed(self.destination_retry_cache[destination])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_destination_retry_timings",
|
||||
self._get_destination_retry_timings, destination)
|
||||
|
||||
def _get_destination_retry_timings(cls, txn, destination):
|
||||
query = DestinationsTable.select_statement("destination = ?")
|
||||
txn.execute(query, (destination,))
|
||||
result = txn.fetchall()
|
||||
if result:
|
||||
result = DestinationsTable.decode_single_result(result)
|
||||
if result.retry_last_ts > 0:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_destination_retry_timings(self, destination,
|
||||
retry_last_ts, retry_interval):
|
||||
"""Sets the current retry timings for a given destination.
|
||||
Both timings should be zero if retrying is no longer occuring.
|
||||
|
||||
Args:
|
||||
destination (str)
|
||||
retry_last_ts (int) - time of last retry attempt in unix epoch ms
|
||||
retry_interval (int) - how long until next retry in ms
|
||||
"""
|
||||
|
||||
self.destination_retry_cache[destination] = (
|
||||
DestinationsTable.EntryType(
|
||||
destination,
|
||||
retry_last_ts,
|
||||
retry_interval
|
||||
)
|
||||
)
|
||||
|
||||
# XXX: we could chose to not bother persisting this if our cache thinks
|
||||
# this is a NOOP
|
||||
return self.runInteraction(
|
||||
"set_destination_retry_timings",
|
||||
self._set_destination_retry_timings,
|
||||
destination,
|
||||
retry_last_ts,
|
||||
retry_interval,
|
||||
)
|
||||
|
||||
def _set_destination_retry_timings(cls, txn, destination,
|
||||
retry_last_ts, retry_interval):
|
||||
|
||||
query = (
|
||||
"INSERT OR REPLACE INTO %s "
|
||||
"(destination, retry_last_ts, retry_interval) "
|
||||
"VALUES (?, ?, ?) "
|
||||
) % DestinationsTable.table_name
|
||||
|
||||
txn.execute(query, (destination, retry_last_ts, retry_interval))
|
||||
|
||||
def get_destinations_needing_retry(self):
|
||||
"""Get all destinations which are due a retry for sending a transaction.
|
||||
|
||||
Returns:
|
||||
list: A list of `DestinationsTable.EntryType`
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
"get_destinations_needing_retry",
|
||||
self._get_destinations_needing_retry
|
||||
)
|
||||
|
||||
def _get_destinations_needing_retry(cls, txn):
|
||||
where = "retry_last_ts > 0 and retry_next_ts < now()"
|
||||
query = DestinationsTable.select_statement(where)
|
||||
txn.execute(query)
|
||||
return DestinationsTable.decode_results(txn.fetchall())
|
||||
|
||||
|
||||
class ReceivedTransactionsTable(Table):
|
||||
table_name = "received_transactions"
|
||||
@ -247,3 +339,15 @@ class TransactionsToPduTable(Table):
|
||||
]
|
||||
|
||||
EntryType = namedtuple("TransactionsToPduEntry", fields)
|
||||
|
||||
|
||||
class DestinationsTable(Table):
|
||||
table_name = "destinations"
|
||||
|
||||
fields = [
|
||||
"destination",
|
||||
"retry_last_ts",
|
||||
"retry_interval",
|
||||
]
|
||||
|
||||
EntryType = namedtuple("DestinationsEntry", fields)
|
||||
|
@ -19,7 +19,7 @@ from collections import namedtuple
|
||||
|
||||
|
||||
class DomainSpecificString(
|
||||
namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine"))
|
||||
namedtuple("DomainSpecificString", ("localpart", "domain"))
|
||||
):
|
||||
"""Common base class among ID/name strings that have a local part and a
|
||||
domain name, prefixed with a sigil.
|
||||
@ -28,15 +28,13 @@ class DomainSpecificString(
|
||||
|
||||
'localpart' : The local part of the name (without the leading sigil)
|
||||
'domain' : The domain part of the name
|
||||
'is_mine' : Boolean indicating if the domain name is recognised by the
|
||||
HomeServer as being its own
|
||||
"""
|
||||
|
||||
# Deny iteration because it will bite you if you try to create a singleton
|
||||
# set by:
|
||||
# users = set(user)
|
||||
def __iter__(self):
|
||||
raise ValueError("Attempted to iterate a %s" % (type(self).__name__))
|
||||
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
|
||||
|
||||
# Because this class is a namedtuple of strings and booleans, it is deeply
|
||||
# immutable.
|
||||
@ -47,7 +45,7 @@ class DomainSpecificString(
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, s, hs):
|
||||
def from_string(cls, s):
|
||||
"""Parse the string given by 's' into a structure object."""
|
||||
if s[0] != cls.SIGIL:
|
||||
raise SynapseError(400, "Expected %s string to start with '%s'" % (
|
||||
@ -66,22 +64,15 @@ class DomainSpecificString(
|
||||
|
||||
# This code will need changing if we want to support multiple domain
|
||||
# names on one HS
|
||||
is_mine = domain == hs.hostname
|
||||
return cls(localpart=parts[0], domain=domain, is_mine=is_mine)
|
||||
return cls(localpart=parts[0], domain=domain)
|
||||
|
||||
def to_string(self):
|
||||
"""Return a string encoding the fields of the structure object."""
|
||||
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
||||
|
||||
@classmethod
|
||||
def create_local(cls, localpart, hs):
|
||||
"""Create a structure on the local domain"""
|
||||
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
|
||||
|
||||
@classmethod
|
||||
def create(cls, localpart, domain, hs):
|
||||
is_mine = domain == hs.hostname
|
||||
return cls(localpart=localpart, domain=domain, is_mine=is_mine)
|
||||
def create(cls, localpart, domain,):
|
||||
return cls(localpart=localpart, domain=domain)
|
||||
|
||||
|
||||
class UserID(DomainSpecificString):
|
||||
|
@ -115,10 +115,10 @@ class Signal(object):
|
||||
failure.value,
|
||||
failure.getTracebackObject()))
|
||||
if not self.suppress_failures:
|
||||
raise failure
|
||||
failure.raiseException()
|
||||
deferreds.append(d.addErrback(eb))
|
||||
|
||||
result = yield defer.DeferredList(
|
||||
deferreds, fireOnOneErrback=not self.suppress_failures
|
||||
)
|
||||
defer.returnValue(result)
|
||||
results = []
|
||||
for deferred in deferreds:
|
||||
result = yield deferred
|
||||
results.append(result)
|
||||
defer.returnValue(results)
|
||||
|
46
synapse/util/frozenutils.py
Normal file
46
synapse/util/frozenutils.py
Normal file
@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
|
||||
def freeze(o):
|
||||
if isinstance(o, dict) or isinstance(o, frozendict):
|
||||
return frozendict({k: freeze(v) for k, v in o.items()})
|
||||
|
||||
if isinstance(o, basestring):
|
||||
return o
|
||||
|
||||
try:
|
||||
return tuple([freeze(i) for i in o])
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
return o
|
||||
|
||||
|
||||
def unfreeze(o):
|
||||
if isinstance(o, frozendict) or isinstance(o, dict):
|
||||
return dict({k: unfreeze(v) for k, v in o.items()})
|
||||
|
||||
if isinstance(o, basestring):
|
||||
return o
|
||||
|
||||
try:
|
||||
return [unfreeze(i) for i in o]
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
return o
|
@ -1,217 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.events import SynapseEvent
|
||||
from synapse.api.events.validator import EventValidator
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class SynapseTemplateCheckTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.validator = EventValidator(None)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def test_top_level_keys(self):
|
||||
template = {
|
||||
"person": {},
|
||||
"friends": ["string"]
|
||||
}
|
||||
|
||||
content = {
|
||||
"person": {"name": "bob"},
|
||||
"friends": ["jill", "mike"]
|
||||
}
|
||||
|
||||
event = MockSynapseEvent(template)
|
||||
event.content = content
|
||||
self.assertTrue(self.validator.validate(event))
|
||||
|
||||
content = {
|
||||
"person": {"name": "bob"},
|
||||
"friends": ["jill"],
|
||||
"enemies": ["mike"]
|
||||
}
|
||||
event.content = content
|
||||
self.assertTrue(self.validator.validate(event))
|
||||
|
||||
content = {
|
||||
"person": {"name": "bob"},
|
||||
# missing friends
|
||||
"enemies": ["mike", "jill"]
|
||||
}
|
||||
event.content = content
|
||||
self.assertRaises(
|
||||
SynapseError,
|
||||
self.validator.validate,
|
||||
event
|
||||
)
|
||||
|
||||
def test_lists(self):
|
||||
template = {
|
||||
"person": {},
|
||||
"friends": [{"name":"string"}]
|
||||
}
|
||||
|
||||
content = {
|
||||
"person": {"name": "bob"},
|
||||
"friends": ["jill", "mike"] # should be in objects
|
||||
}
|
||||
|
||||
event = MockSynapseEvent(template)
|
||||
event.content = content
|
||||
self.assertRaises(
|
||||
SynapseError,
|
||||
self.validator.validate,
|
||||
event
|
||||
)
|
||||
|
||||
content = {
|
||||
"person": {"name": "bob"},
|
||||
"friends": [{"name": "jill"}, {"name": "mike"}]
|
||||
}
|
||||
event.content = content
|
||||
self.assertTrue(self.validator.validate(event))
|
||||
|
||||
def test_nested_lists(self):
|
||||
template = {
|
||||
"results": {
|
||||
"families": [
|
||||
{
|
||||
"name": "string",
|
||||
"members": [
|
||||
{}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
content = {
|
||||
"results": {
|
||||
"families": [
|
||||
{
|
||||
"name": "Smith",
|
||||
"members": [
|
||||
"Alice", "Bob" # wrong types
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
event = MockSynapseEvent(template)
|
||||
event.content = content
|
||||
self.assertRaises(
|
||||
SynapseError,
|
||||
self.validator.validate,
|
||||
event
|
||||
)
|
||||
|
||||
content = {
|
||||
"results": {
|
||||
"families": [
|
||||
{
|
||||
"name": "Smith",
|
||||
"members": [
|
||||
{"name": "Alice"}, {"name": "Bob"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
event.content = content
|
||||
self.assertTrue(self.validator.validate(event))
|
||||
|
||||
def test_nested_keys(self):
|
||||
template = {
|
||||
"person": {
|
||||
"attributes": {
|
||||
"hair": "string",
|
||||
"eye": "string"
|
||||
},
|
||||
"age": 0,
|
||||
"fav_books": ["string"]
|
||||
}
|
||||
}
|
||||
event = MockSynapseEvent(template)
|
||||
|
||||
content = {
|
||||
"person": {
|
||||
"attributes": {
|
||||
"hair": "brown",
|
||||
"eye": "green",
|
||||
"skin": "purple"
|
||||
},
|
||||
"age": 33,
|
||||
"fav_books": ["lotr", "hobbit"],
|
||||
"fav_music": ["abba", "beatles"]
|
||||
}
|
||||
}
|
||||
|
||||
event.content = content
|
||||
self.assertTrue(self.validator.validate(event))
|
||||
|
||||
content = {
|
||||
"person": {
|
||||
"attributes": {
|
||||
"hair": "brown"
|
||||
# missing eye
|
||||
},
|
||||
"age": 33,
|
||||
"fav_books": ["lotr", "hobbit"],
|
||||
"fav_music": ["abba", "beatles"]
|
||||
}
|
||||
}
|
||||
|
||||
event.content = content
|
||||
self.assertRaises(
|
||||
SynapseError,
|
||||
self.validator.validate,
|
||||
event
|
||||
)
|
||||
|
||||
content = {
|
||||
"person": {
|
||||
"attributes": {
|
||||
"hair": "brown",
|
||||
"eye": "green",
|
||||
"skin": "purple"
|
||||
},
|
||||
"age": 33,
|
||||
"fav_books": "nothing", # should be a list
|
||||
}
|
||||
}
|
||||
|
||||
event.content = content
|
||||
self.assertRaises(
|
||||
SynapseError,
|
||||
self.validator.validate,
|
||||
event
|
||||
)
|
||||
|
||||
|
||||
class MockSynapseEvent(SynapseEvent):
|
||||
|
||||
def __init__(self, template):
|
||||
self.template = template
|
||||
|
||||
def get_content_template(self):
|
||||
return self.template
|
||||
|
@ -23,24 +23,20 @@ from ..utils import MockHttpResource, MockClock, MockKey
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.api.events import SynapseEvent
|
||||
from synapse.events import FrozenEvent
|
||||
|
||||
from synapse.storage.transactions import DestinationsTable
|
||||
|
||||
|
||||
def make_pdu(prev_pdus=[], **kwargs):
|
||||
"""Provide some default fields for making a PduTuple."""
|
||||
pdu_fields = {
|
||||
"is_state": False,
|
||||
"unrecognized_keys": [],
|
||||
"outlier": False,
|
||||
"have_processed": True,
|
||||
"state_key": None,
|
||||
"power_level": None,
|
||||
"prev_state_id": None,
|
||||
"prev_state_origin": None,
|
||||
"prev_events": prev_pdus,
|
||||
}
|
||||
pdu_fields.update(kwargs)
|
||||
|
||||
return SynapseEvent(prev_pdus=prev_pdus, **pdu_fields)
|
||||
return FrozenEvent(pdu_fields)
|
||||
|
||||
|
||||
class FederationTestCase(unittest.TestCase):
|
||||
@ -55,10 +51,16 @@ class FederationTestCase(unittest.TestCase):
|
||||
"delivered_txn",
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
"get_auth_chain",
|
||||
])
|
||||
self.mock_persistence.get_received_txn_response.return_value = (
|
||||
defer.succeed(None)
|
||||
)
|
||||
self.mock_persistence.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
self.mock_persistence.get_auth_chain.return_value = []
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
self.clock = MockClock()
|
||||
@ -171,7 +173,7 @@ class FederationTestCase(unittest.TestCase):
|
||||
(200, "OK")
|
||||
)
|
||||
|
||||
pdu = SynapseEvent(
|
||||
pdu = make_pdu(
|
||||
event_id="abc123def456",
|
||||
origin="red",
|
||||
user_id="@a:red",
|
||||
@ -180,10 +182,9 @@ class FederationTestCase(unittest.TestCase):
|
||||
origin_server_ts=123456789001,
|
||||
depth=1,
|
||||
content={"text": "Here is the message"},
|
||||
destinations=["remote"],
|
||||
)
|
||||
|
||||
yield self.federation.send_pdu(pdu)
|
||||
yield self.federation.send_pdu(pdu, ["remote"])
|
||||
|
||||
self.mock_http_client.put_json.assert_called_with(
|
||||
"remote",
|
||||
|
@ -16,11 +16,8 @@
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
|
||||
from synapse.api.events.room import (
|
||||
MessageEvent,
|
||||
)
|
||||
|
||||
from synapse.api.events import SynapseEvent
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.federation import FederationHandler
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@ -37,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
self.state_handler = NonCallableMock(spec_set=[
|
||||
"annotate_event_with_state",
|
||||
"compute_event_context",
|
||||
])
|
||||
|
||||
self.auth = NonCallableMock(spec_set=[
|
||||
@ -53,6 +50,8 @@ class FederationTestCase(unittest.TestCase):
|
||||
"persist_event",
|
||||
"store_room",
|
||||
"get_room",
|
||||
"get_destination_retry_timings",
|
||||
"set_destination_retry_timings",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
@ -76,43 +75,47 @@ class FederationTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_msg(self):
|
||||
pdu = SynapseEvent(
|
||||
type=MessageEvent.TYPE,
|
||||
room_id="foo",
|
||||
content={"msgtype": u"fooo"},
|
||||
origin_server_ts=0,
|
||||
event_id="$a:b",
|
||||
user_id="@a:b",
|
||||
origin="b",
|
||||
auth_events=[],
|
||||
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||
)
|
||||
pdu = FrozenEvent({
|
||||
"type": EventTypes.Message,
|
||||
"room_id": "foo",
|
||||
"content": {"msgtype": u"fooo"},
|
||||
"origin_server_ts": 0,
|
||||
"event_id": "$a:b",
|
||||
"user_id":"@a:b",
|
||||
"origin": "b",
|
||||
"auth_events": [],
|
||||
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
|
||||
})
|
||||
|
||||
self.datastore.persist_event.return_value = defer.succeed(None)
|
||||
self.datastore.get_room.return_value = defer.succeed(True)
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
|
||||
def annotate(ev, old_state=None):
|
||||
ev.old_state_events = []
|
||||
return defer.succeed(False)
|
||||
self.state_handler.annotate_event_with_state.side_effect = annotate
|
||||
context = Mock()
|
||||
context.current_state = {}
|
||||
context.auth_events = {}
|
||||
return defer.succeed(context)
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
yield self.handlers.federation_handler.on_receive_pdu(
|
||||
"fo", pdu, False
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
ANY, is_new_state=False, backfilled=False, current_state=None
|
||||
ANY,
|
||||
is_new_state=True,
|
||||
backfilled=False,
|
||||
current_state=None,
|
||||
context=ANY,
|
||||
)
|
||||
|
||||
self.state_handler.annotate_event_with_state.assert_called_once_with(
|
||||
ANY,
|
||||
old_state=None,
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
ANY, old_state=None,
|
||||
)
|
||||
|
||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
ANY,
|
||||
extra_users=[]
|
||||
ANY, extra_users=[]
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ from synapse.api.constants import PresenceState
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.handlers.presence import PresenceHandler, UserPresenceCache
|
||||
from synapse.streams.config import SourcePaginationConfig
|
||||
|
||||
from synapse.storage.transactions import DestinationsTable
|
||||
|
||||
OFFLINE = PresenceState.OFFLINE
|
||||
UNAVAILABLE = PresenceState.UNAVAILABLE
|
||||
@ -528,6 +528,7 @@ class PresencePushTestCase(unittest.TestCase):
|
||||
"delivered_txn",
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
]),
|
||||
handlers=None,
|
||||
resource_for_client=Mock(),
|
||||
@ -539,6 +540,9 @@ class PresencePushTestCase(unittest.TestCase):
|
||||
hs.handlers = JustPresenceHandlers(hs)
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.datastore.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
|
||||
def get_received_txn_response(*args):
|
||||
return defer.succeed(None)
|
||||
@ -1037,6 +1041,7 @@ class PresencePollingTestCase(unittest.TestCase):
|
||||
"delivered_txn",
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
]),
|
||||
handlers=None,
|
||||
resource_for_client=Mock(),
|
||||
@ -1048,6 +1053,9 @@ class PresencePollingTestCase(unittest.TestCase):
|
||||
hs.handlers = JustPresenceHandlers(hs)
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.datastore.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
|
||||
def get_received_txn_response(*args):
|
||||
return defer.succeed(None)
|
||||
|
@ -17,10 +17,7 @@
|
||||
from twisted.internet import defer
|
||||
from tests import unittest
|
||||
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent,
|
||||
)
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.server import HomeServer
|
||||
@ -47,7 +44,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
"get_room_member",
|
||||
"get_room",
|
||||
"store_room",
|
||||
"snapshot_room",
|
||||
"get_latest_events_in_room",
|
||||
]),
|
||||
resource_for_federation=NonCallableMock(),
|
||||
http_client=NonCallableMock(spec_set=[]),
|
||||
@ -63,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
"check_host_in_room",
|
||||
]),
|
||||
state_handler=NonCallableMock(spec_set=[
|
||||
"annotate_event_with_state",
|
||||
"compute_event_context",
|
||||
"get_current_state",
|
||||
]),
|
||||
config=self.mock_config,
|
||||
@ -91,9 +88,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
self.handlers.profile_handler = ProfileHandler(self.hs)
|
||||
self.room_member_handler = self.handlers.room_member_handler
|
||||
|
||||
self.snapshot = Mock()
|
||||
self.datastore.snapshot_room.return_value = self.snapshot
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
@ -104,50 +98,70 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
target_user_id = "@red:blue"
|
||||
content = {"membership": Membership.INVITE}
|
||||
|
||||
event = self.hs.get_event_factory().create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user_id,
|
||||
state_key=target_user_id,
|
||||
room_id=room_id,
|
||||
membership=Membership.INVITE,
|
||||
content=content,
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": target_user_id,
|
||||
"room_id": room_id,
|
||||
"content": content,
|
||||
})
|
||||
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.persist_event.return_value = defer.succeed(store_id)
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.datastore.get_room_member.return_value = defer.succeed(None)
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
event.old_state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
event.state_events = event.old_state_events
|
||||
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
def send_invite(domain, event):
|
||||
return defer.succeed(event)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot,
|
||||
self.federation.send_invite.side_effect = send_invite
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
set(["red", "green"]),
|
||||
set(event.destinations)
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.state_handler.compute_event_context.assert_called_once_with(
|
||||
builder
|
||||
)
|
||||
|
||||
self.auth.add_auth_events.assert_called_once_with(
|
||||
builder, context
|
||||
)
|
||||
|
||||
self.federation.send_invite.assert_called_once_with(
|
||||
"blue", event,
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event
|
||||
event, context=context,
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, extra_users=[self.hs.parse_userid(target_user_id)]
|
||||
@ -162,57 +176,58 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
user_id = "@bob:red"
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
event = self._create_member(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.persist_event.return_value = defer.succeed(store_id)
|
||||
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
|
||||
|
||||
prev_state = NonCallableMock()
|
||||
prev_state.membership = Membership.INVITE
|
||||
prev_state.sender = "@foo:red"
|
||||
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
|
||||
|
||||
join_signal_observer = Mock()
|
||||
self.distributor.observe("user_joined_room", join_signal_observer)
|
||||
|
||||
event.state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, user_id): event,
|
||||
}
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": Membership.JOIN},
|
||||
})
|
||||
|
||||
event.old_state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
}
|
||||
|
||||
event.state_events = event.old_state_events
|
||||
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
set(["red", "green"]),
|
||||
set(event.destinations)
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
membership=Membership.INVITE
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
# Actual invocation
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, None, destinations=set()
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event
|
||||
event, context=context
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, extra_users=[user]
|
||||
@ -222,14 +237,82 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
user=user, room_id=room_id
|
||||
)
|
||||
|
||||
def _create_member(self, user_id, room_id):
|
||||
return self.hs.get_event_factory().create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user_id,
|
||||
state_key=user_id,
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN,
|
||||
content={"membership": Membership.JOIN},
|
||||
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
return builder.build()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_simple_leave(self):
|
||||
room_id = "!foo:red"
|
||||
user_id = "@bob:red"
|
||||
user = self.hs.parse_userid(user_id)
|
||||
|
||||
builder = self.hs.get_event_builder_factory().new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"state_key": user_id,
|
||||
"room_id": room_id,
|
||||
"content": {"membership": Membership.LEAVE},
|
||||
})
|
||||
|
||||
self.datastore.get_latest_events_in_room.return_value = (
|
||||
defer.succeed([])
|
||||
)
|
||||
|
||||
def annotate(_):
|
||||
ctx = Mock()
|
||||
ctx.current_state = {
|
||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN
|
||||
),
|
||||
}
|
||||
ctx.prev_state_events = []
|
||||
|
||||
return defer.succeed(ctx)
|
||||
|
||||
self.state_handler.compute_event_context.side_effect = annotate
|
||||
|
||||
def add_auth(_, ctx):
|
||||
ctx.auth_events = ctx.current_state[
|
||||
(EventTypes.Member, "@bob:red")
|
||||
]
|
||||
|
||||
return defer.succeed(True)
|
||||
self.auth.add_auth_events.side_effect = add_auth
|
||||
|
||||
room_handler = self.room_member_handler
|
||||
event, context = yield room_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
leave_signal_observer = Mock()
|
||||
self.distributor.observe("user_left_room", leave_signal_observer)
|
||||
|
||||
# Actual invocation
|
||||
yield room_handler.change_membership(event, context)
|
||||
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, None, destinations=set(['red'])
|
||||
)
|
||||
|
||||
self.datastore.persist_event.assert_called_once_with(
|
||||
event, context=context
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, extra_users=[user]
|
||||
)
|
||||
|
||||
leave_signal_observer.assert_called_with(
|
||||
user=user, room_id=room_id
|
||||
)
|
||||
|
||||
|
||||
@ -254,13 +337,9 @@ class RoomCreationTest(unittest.TestCase):
|
||||
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
|
||||
handlers=NonCallableMock(spec_set=[
|
||||
"room_creation_handler",
|
||||
"room_member_handler",
|
||||
"federation_handler",
|
||||
"message_handler",
|
||||
]),
|
||||
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
|
||||
state_handler=NonCallableMock(spec_set=[
|
||||
"annotate_event_with_state",
|
||||
]),
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
@ -271,30 +350,12 @@ class RoomCreationTest(unittest.TestCase):
|
||||
"handle_new_event",
|
||||
])
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.hs = hs
|
||||
|
||||
self.handlers.federation_handler = self.federation
|
||||
|
||||
self.handlers.room_creation_handler = RoomCreationHandler(self.hs)
|
||||
self.handlers.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_creation_handler = self.handlers.room_creation_handler
|
||||
|
||||
self.handlers.room_member_handler = NonCallableMock(spec_set=[
|
||||
"change_membership"
|
||||
])
|
||||
self.room_member_handler = self.handlers.room_member_handler
|
||||
|
||||
def annotate(event):
|
||||
event.state_events = {}
|
||||
return defer.succeed(None)
|
||||
self.state_handler.annotate_event_with_state.side_effect = annotate
|
||||
|
||||
def hosts(room):
|
||||
return defer.succeed([])
|
||||
self.datastore.get_joined_hosts_for_room.side_effect = hosts
|
||||
self.message_handler = self.handlers.message_handler
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -311,14 +372,37 @@ class RoomCreationTest(unittest.TestCase):
|
||||
config=config,
|
||||
)
|
||||
|
||||
self.assertTrue(self.room_member_handler.change_membership.called)
|
||||
join_event = self.room_member_handler.change_membership.call_args[0][0]
|
||||
self.assertTrue(self.message_handler.create_and_send_event.called)
|
||||
|
||||
self.assertEquals(RoomMemberEvent.TYPE, join_event.type)
|
||||
self.assertEquals(room_id, join_event.room_id)
|
||||
self.assertEquals(user_id, join_event.user_id)
|
||||
self.assertEquals(user_id, join_event.state_key)
|
||||
event_dicts = [
|
||||
e[0][0]
|
||||
for e in self.message_handler.create_and_send_event.call_args_list
|
||||
]
|
||||
|
||||
self.assertTrue(self.state_handler.annotate_event_with_state.called)
|
||||
self.assertTrue(len(event_dicts) > 3)
|
||||
|
||||
self.assertTrue(self.federation.handle_new_event.called)
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"type": EventTypes.Create,
|
||||
"sender": user_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
event_dicts[0]
|
||||
)
|
||||
|
||||
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
|
||||
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"sender": user_id,
|
||||
"room_id": room_id,
|
||||
"state_key": user_id,
|
||||
},
|
||||
event_dicts[1]
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
Membership.JOIN,
|
||||
event_dicts[1]["content"]["membership"]
|
||||
)
|
||||
|
@ -22,9 +22,12 @@ import json
|
||||
|
||||
from ..utils import MockHttpResource, MockClock, DeferredMockCallable, MockKey
|
||||
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.server import HomeServer
|
||||
from synapse.handlers.typing import TypingNotificationHandler
|
||||
|
||||
from synapse.storage.transactions import DestinationsTable
|
||||
|
||||
|
||||
def _expect_edu(destination, edu_type, content, origin="test"):
|
||||
return {
|
||||
@ -63,7 +66,13 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
mock_notifier = Mock(spec=["on_new_user_event"])
|
||||
self.on_new_user_event = mock_notifier.on_new_user_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
|
||||
hs = HomeServer("test",
|
||||
auth=self.auth,
|
||||
clock=self.clock,
|
||||
db_pool=None,
|
||||
datastore=Mock(spec=[
|
||||
@ -72,8 +81,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
"delivered_txn",
|
||||
"get_received_txn_response",
|
||||
"set_received_txn_response",
|
||||
"get_destination_retry_timings",
|
||||
]),
|
||||
handlers=None,
|
||||
notifier=mock_notifier,
|
||||
resource_for_client=Mock(),
|
||||
resource_for_federation=self.mock_federation_resource,
|
||||
http_client=self.mock_http_client,
|
||||
@ -82,13 +93,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
)
|
||||
hs.handlers = JustTypingNotificationHandlers(hs)
|
||||
|
||||
self.mock_update_client = Mock()
|
||||
self.mock_update_client.return_value = defer.succeed(None)
|
||||
|
||||
self.handler = hs.get_handlers().typing_notification_handler
|
||||
self.handler.push_update_to_clients = self.mock_update_client
|
||||
|
||||
self.event_source = hs.get_event_sources().sources["typing"]
|
||||
|
||||
self.datastore = hs.get_datastore()
|
||||
self.datastore.get_destination_retry_timings.return_value = (
|
||||
defer.succeed(DestinationsTable.EntryType("", 0, 0))
|
||||
)
|
||||
|
||||
def get_received_txn_response(*args):
|
||||
return defer.succeed(None)
|
||||
@ -125,7 +137,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
if ignore_user is not None and member == ignore_user:
|
||||
continue
|
||||
|
||||
if member.is_mine:
|
||||
if hs.is_mine(member):
|
||||
if localusers is not None:
|
||||
localusers.add(member)
|
||||
else:
|
||||
@ -134,6 +146,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
self.room_member_handler.fetch_room_distributions_into = (
|
||||
fetch_room_distributions_into)
|
||||
|
||||
def check_joined_room(room_id, user_id):
|
||||
if user_id not in [u.to_string() for u in self.room_members]:
|
||||
raise AuthError(401, "User is not in the room")
|
||||
|
||||
self.auth.check_joined_room = check_joined_room
|
||||
|
||||
# Some local users to test with
|
||||
self.u_apple = hs.parse_userid("@apple:test")
|
||||
self.u_banana = hs.parse_userid("@banana:test")
|
||||
@ -145,6 +163,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
def test_started_typing_local(self):
|
||||
self.room_members = [self.u_apple, self.u_banana]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
@ -152,13 +172,22 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
timeout=20000,
|
||||
)
|
||||
|
||||
self.mock_update_client.assert_has_calls([
|
||||
call(observer_user=self.u_banana,
|
||||
observed_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
typing=True),
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
call(rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [self.u_apple.to_string()],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_started_typing_remote_send(self):
|
||||
self.room_members = [self.u_apple, self.u_onion]
|
||||
@ -192,6 +221,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
def test_started_typing_remote_recv(self):
|
||||
self.room_members = [self.u_apple, self.u_onion]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.mock_federation_resource.trigger("PUT",
|
||||
"/_matrix/federation/v1/send/1000000/",
|
||||
_make_edu_json("farm", "m.typing",
|
||||
@ -203,13 +234,22 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.mock_update_client.assert_has_calls([
|
||||
call(observer_user=self.u_apple,
|
||||
observed_user=self.u_onion,
|
||||
room_id=self.room_id,
|
||||
typing=True),
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
call(rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [self.u_onion.to_string()],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_stopped_typing(self):
|
||||
self.room_members = [self.u_apple, self.u_banana, self.u_onion]
|
||||
@ -232,9 +272,14 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
|
||||
# Gut-wrenching
|
||||
from synapse.handlers.typing import RoomMember
|
||||
self.handler._member_typing_until[
|
||||
RoomMember(self.room_id, self.u_apple)
|
||||
] = 1002000
|
||||
member = RoomMember(self.room_id, self.u_apple)
|
||||
self.handler._member_typing_until[member] = 1002000
|
||||
self.handler._member_typing_timer[member] = (
|
||||
self.clock.call_later(1002, lambda: 0)
|
||||
)
|
||||
self.handler._room_typing[self.room_id] = set((self.u_apple,))
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.stopped_typing(
|
||||
target_user=self.u_apple,
|
||||
@ -242,11 +287,68 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
room_id=self.room_id,
|
||||
)
|
||||
|
||||
self.mock_update_client.assert_has_calls([
|
||||
call(observer_user=self.u_banana,
|
||||
observed_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
typing=False),
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
call(rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
yield put_json.await_calls()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_typing_timeout(self):
|
||||
self.room_members = [self.u_apple, self.u_banana]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
timeout=10000,
|
||||
)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
call(rooms=[self.room_id]),
|
||||
])
|
||||
self.on_new_user_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.u_apple, 0, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [self.u_apple.to_string()],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
||||
self.clock.advance_time(11)
|
||||
|
||||
self.on_new_user_event.assert_has_calls([
|
||||
call(rooms=[self.room_id]),
|
||||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.u_apple, 1, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
@ -113,9 +113,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
@ -127,7 +124,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||
db_pool=db_pool,
|
||||
http_client=None,
|
||||
replication_layer=Mock(),
|
||||
persistence_service=persistence_service,
|
||||
clock=Mock(spec=[
|
||||
"call_later",
|
||||
"cancel_call_later",
|
||||
|
@ -503,7 +503,7 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_member_list_mixed_memberships(self):
|
||||
room_creator = "@some_other_guy:blue"
|
||||
room_creator = "@some_other_guy:red"
|
||||
room_id = yield self.create_room_as(room_creator)
|
||||
room_path = "/rooms/%s/members" % room_id
|
||||
yield self.invite(room=room_id, src=room_creator,
|
||||
@ -1066,7 +1066,3 @@ class RoomInitialSyncTestCase(RestTestCase):
|
||||
}
|
||||
self.assertTrue(self.user_id in presence_by_user)
|
||||
self.assertEquals("m.presence", presence_by_user[self.user_id]["type"])
|
||||
|
||||
# (code, response) = yield self.mock_resource.trigger("GET", path, None)
|
||||
# self.assertEquals(200, code, msg=str(response))
|
||||
# self.assert_dict(json.loads(content), response)
|
||||
|
115
tests/rest/test_typing.py
Normal file
115
tests/rest/test_typing.py
Normal file
@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests REST events for /rooms paths."""
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.rest.room
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
|
||||
from .utils import RestTestCase
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
|
||||
PATH_PREFIX = "/_matrix/client/api/v1"
|
||||
|
||||
|
||||
class RoomTypingTestCase(RestTestCase):
|
||||
""" Tests /rooms/$room_id/typing/$user_id REST API. """
|
||||
user_id = "@sid:red"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
|
||||
self.auth_user_id = self.user_id
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
db_pool = SQLiteMemoryDbPool()
|
||||
yield db_pool.prepare()
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=db_pool,
|
||||
http_client=None,
|
||||
replication_layer=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.hs = hs
|
||||
|
||||
self.event_source = hs.get_event_sources().sources["typing"]
|
||||
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_token(token=None):
|
||||
return {
|
||||
"user": hs.parse_userid(self.auth_user_id),
|
||||
"admin": False,
|
||||
"device_id": None,
|
||||
}
|
||||
|
||||
hs.get_auth().get_user_by_token = _get_user_by_token
|
||||
|
||||
def _insert_client_ip(*args, **kwargs):
|
||||
return defer.succeed(None)
|
||||
hs.get_datastore().insert_client_ip = _insert_client_ip
|
||||
|
||||
synapse.rest.room.register_servlets(hs, self.mock_resource)
|
||||
|
||||
self.room_id = yield self.create_room_as(self.user_id)
|
||||
# Need another user to make notifications actually work
|
||||
yield self.join(self.room_id, user="@jim:red")
|
||||
|
||||
def tearDown(self):
|
||||
self.hs.get_handlers().typing_notification_handler.tearDown()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_typing(self):
|
||||
(code, _) = yield self.mock_resource.trigger("PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
'{"typing": true, "timeout": 30000}'
|
||||
)
|
||||
self.assertEquals(200, code)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(
|
||||
self.event_source.get_new_events_for_user(self.user_id, 0, None)[0],
|
||||
[
|
||||
{"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {
|
||||
"user_ids": [self.user_id],
|
||||
}},
|
||||
]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_not_typing(self):
|
||||
(code, _) = yield self.mock_resource.trigger("PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
'{"typing": false}'
|
||||
)
|
||||
self.assertEquals(200, code)
|
@ -18,12 +18,11 @@ from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.events.room import (
|
||||
RoomMemberEvent, MessageEvent, RoomRedactionEvent,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
from tests.utils import SQLiteMemoryDbPool
|
||||
from tests.utils import SQLiteMemoryDbPool, MockKey
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class RedactionTestCase(unittest.TestCase):
|
||||
@ -33,13 +32,21 @@ class RedactionTestCase(unittest.TestCase):
|
||||
db_pool = SQLiteMemoryDbPool()
|
||||
yield db_pool.prepare()
|
||||
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"test",
|
||||
db_pool=db_pool,
|
||||
config=self.mock_config,
|
||||
resource_for_federation=Mock(),
|
||||
http_client=None,
|
||||
)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.message_handler = self.handlers.message_handler
|
||||
|
||||
self.u_alice = hs.parse_userid("@alice:test")
|
||||
self.u_bob = hs.parse_userid("@bob:test")
|
||||
@ -49,35 +56,23 @@ class RedactionTestCase(unittest.TestCase):
|
||||
self.depth = 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership, prev_state=None,
|
||||
def inject_room_member(self, room, user, membership, replaces_state=None,
|
||||
extra_content={}):
|
||||
self.depth += 1
|
||||
content = {"membership": membership}
|
||||
content.update(extra_content)
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": content,
|
||||
})
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
state_key=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
membership=membership,
|
||||
content={"membership": membership},
|
||||
depth=self.depth,
|
||||
prev_events=[],
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.content.update(extra_content)
|
||||
|
||||
if prev_state:
|
||||
event.prev_state = prev_state
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.prev_state = []
|
||||
event.auth_events = []
|
||||
|
||||
# Have to create a join event using the eventfactory
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@ -85,46 +80,38 @@ class RedactionTestCase(unittest.TestCase):
|
||||
def inject_message(self, room, user, body):
|
||||
self.depth += 1
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=MessageEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
content={"body": body, "msgtype": u"message"},
|
||||
depth=self.depth,
|
||||
prev_events=[],
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Message,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"body": body, "msgtype": u"message"},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.auth_events = []
|
||||
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_redaction(self, room, event_id, user, reason):
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomRedactionEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
content={"reason": reason},
|
||||
depth=self.depth,
|
||||
redacts=event_id,
|
||||
prev_events=[],
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Redaction,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"reason": reason},
|
||||
"redacts": event_id,
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.auth_events = []
|
||||
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
|
||||
defer.returnValue(event)
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_redact(self):
|
||||
@ -152,14 +139,14 @@ class RedactionTestCase(unittest.TestCase):
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": MessageEvent.TYPE,
|
||||
"type": EventTypes.Message,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {"body": "t", "msgtype": "message"},
|
||||
},
|
||||
event,
|
||||
)
|
||||
|
||||
self.assertFalse(hasattr(event, "redacted_because"))
|
||||
self.assertFalse("redacted_because" in event.unsigned)
|
||||
|
||||
# Redact event
|
||||
reason = "Because I said so"
|
||||
@ -180,24 +167,26 @@ class RedactionTestCase(unittest.TestCase):
|
||||
|
||||
event = results[0]
|
||||
|
||||
self.assertEqual(msg_event.event_id, event.event_id)
|
||||
|
||||
self.assertTrue("redacted_because" in event.unsigned)
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": MessageEvent.TYPE,
|
||||
"type": EventTypes.Message,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {},
|
||||
},
|
||||
event,
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(event, "redacted_because"))
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": RoomRedactionEvent.TYPE,
|
||||
"type": EventTypes.Redaction,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {"reason": reason},
|
||||
},
|
||||
event.redacted_because,
|
||||
event.unsigned["redacted_because"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@ -229,7 +218,7 @@ class RedactionTestCase(unittest.TestCase):
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": RoomMemberEvent.TYPE,
|
||||
"type": EventTypes.Member,
|
||||
"user_id": self.u_bob.to_string(),
|
||||
"content": {"membership": Membership.JOIN, "blue": "red"},
|
||||
},
|
||||
@ -257,22 +246,22 @@ class RedactionTestCase(unittest.TestCase):
|
||||
|
||||
event = results[0]
|
||||
|
||||
self.assertTrue("redacted_because" in event.unsigned)
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": RoomMemberEvent.TYPE,
|
||||
"type": EventTypes.Member,
|
||||
"user_id": self.u_bob.to_string(),
|
||||
"content": {"membership": Membership.JOIN},
|
||||
},
|
||||
event,
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(event, "redacted_because"))
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": RoomRedactionEvent.TYPE,
|
||||
"type": EventTypes.Redaction,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {"reason": reason},
|
||||
},
|
||||
event.redacted_because,
|
||||
event.unsigned["redacted_because"],
|
||||
)
|
||||
|
@ -18,9 +18,7 @@ from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.api.events.room import (
|
||||
RoomNameEvent, RoomTopicEvent
|
||||
)
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from tests.utils import SQLiteMemoryDbPool
|
||||
|
||||
@ -131,7 +129,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
|
||||
name = u"A-Room-Name"
|
||||
|
||||
yield self.inject_room_event(
|
||||
etype=RoomNameEvent.TYPE,
|
||||
etype=EventTypes.Name,
|
||||
name=name,
|
||||
content={"name": name},
|
||||
depth=1,
|
||||
@ -154,7 +152,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
|
||||
topic = u"A place for things"
|
||||
|
||||
yield self.inject_room_event(
|
||||
etype=RoomTopicEvent.TYPE,
|
||||
etype=EventTypes.Topic,
|
||||
topic=topic,
|
||||
content={"topic": topic},
|
||||
depth=1,
|
||||
|
@ -18,10 +18,11 @@ from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.events.room import RoomMemberEvent
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
from tests.utils import SQLiteMemoryDbPool
|
||||
from tests.utils import SQLiteMemoryDbPool, MockKey
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class RoomMemberStoreTestCase(unittest.TestCase):
|
||||
@ -31,14 +32,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||
db_pool = SQLiteMemoryDbPool()
|
||||
yield db_pool.prepare()
|
||||
|
||||
hs = HomeServer("test",
|
||||
db_pool=db_pool,
|
||||
)
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"test",
|
||||
db_pool=db_pool,
|
||||
config=self.mock_config,
|
||||
resource_for_federation=Mock(),
|
||||
http_client=None,
|
||||
)
|
||||
# We can't test the RoomMemberStore on its own without the other event
|
||||
# storage logic
|
||||
self.store = hs.get_datastore()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.message_handler = self.handlers.message_handler
|
||||
|
||||
self.u_alice = hs.parse_userid("@alice:test")
|
||||
self.u_bob = hs.parse_userid("@bob:test")
|
||||
@ -49,27 +58,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
|
||||
self.room = hs.parse_roomid("!abc123:test")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership):
|
||||
# Have to create a join event using the eventfactory
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
state_key=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
membership=membership,
|
||||
content={"membership": membership},
|
||||
depth=1,
|
||||
prev_events=[],
|
||||
def inject_room_member(self, room, user, membership, replaces_state=None):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.prev_state = {}
|
||||
event.auth_events = {}
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_one_member(self):
|
||||
|
@ -18,10 +18,11 @@ from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.events.room import RoomMemberEvent, MessageEvent
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
|
||||
from tests.utils import SQLiteMemoryDbPool
|
||||
from tests.utils import SQLiteMemoryDbPool, MockKey
|
||||
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class StreamStoreTestCase(unittest.TestCase):
|
||||
@ -31,13 +32,21 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
db_pool = SQLiteMemoryDbPool()
|
||||
yield db_pool.prepare()
|
||||
|
||||
self.mock_config = Mock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"test",
|
||||
db_pool=db_pool,
|
||||
config=self.mock_config,
|
||||
resource_for_federation=Mock(),
|
||||
http_client=None,
|
||||
)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.message_handler = self.handlers.message_handler
|
||||
|
||||
self.u_alice = hs.parse_userid("@alice:test")
|
||||
self.u_bob = hs.parse_userid("@bob:test")
|
||||
@ -48,33 +57,22 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
self.depth = 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership, replaces_state=None):
|
||||
def inject_room_member(self, room, user, membership):
|
||||
self.depth += 1
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
state_key=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
membership=membership,
|
||||
content={"membership": membership},
|
||||
depth=self.depth,
|
||||
prev_events=[],
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.prev_state = []
|
||||
event.auth_events = []
|
||||
|
||||
if replaces_state:
|
||||
event.prev_state = [(replaces_state, "hash")]
|
||||
event.replaces_state = replaces_state
|
||||
|
||||
# Have to create a join event using the eventfactory
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@ -82,23 +80,19 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
def inject_message(self, room, user, body):
|
||||
self.depth += 1
|
||||
|
||||
event = self.event_factory.create_event(
|
||||
etype=MessageEvent.TYPE,
|
||||
user_id=user.to_string(),
|
||||
room_id=room.to_string(),
|
||||
content={"body": body, "msgtype": u"message"},
|
||||
depth=self.depth,
|
||||
prev_events=[],
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Message,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"body": body, "msgtype": u"message"},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
event.state_events = None
|
||||
event.hashes = {}
|
||||
event.auth_events = []
|
||||
|
||||
# Have to create a join event using the eventfactory
|
||||
yield self.store.persist_event(
|
||||
event
|
||||
)
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_event_stream_get_other(self):
|
||||
@ -130,7 +124,7 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": MessageEvent.TYPE,
|
||||
"type": EventTypes.Message,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {"body": "test", "msgtype": "message"},
|
||||
},
|
||||
@ -167,7 +161,7 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{
|
||||
"type": MessageEvent.TYPE,
|
||||
"type": EventTypes.Message,
|
||||
"user_id": self.u_alice.to_string(),
|
||||
"content": {"body": "test", "msgtype": "message"},
|
||||
},
|
||||
@ -220,7 +214,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
|
||||
event2 = yield self.inject_room_member(
|
||||
self.room1, self.u_alice, Membership.JOIN,
|
||||
replaces_state=event1.event_id,
|
||||
)
|
||||
|
||||
end = yield self.store.get_room_events_max_id()
|
||||
@ -238,6 +231,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||
event = results[0]
|
||||
|
||||
self.assertTrue(
|
||||
hasattr(event, "prev_content"),
|
||||
"prev_content" in event.unsigned,
|
||||
msg="No prev_content key"
|
||||
)
|
||||
|
@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests import unittest
|
||||
from . import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
|
||||
class DistributorTestCase(unittest.TestCase):
|
||||
@ -26,6 +27,7 @@ class DistributorTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dist = Distributor()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch(self):
|
||||
self.dist.declare("alert")
|
||||
|
||||
@ -33,10 +35,11 @@ class DistributorTestCase(unittest.TestCase):
|
||||
self.dist.observe("alert", observer)
|
||||
|
||||
d = self.dist.fire("alert", 1, 2, 3)
|
||||
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
observer.assert_called_with(1, 2, 3)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_dispatch_deferred(self):
|
||||
self.dist.declare("whine")
|
||||
|
||||
@ -50,8 +53,10 @@ class DistributorTestCase(unittest.TestCase):
|
||||
self.assertFalse(d_outer.called)
|
||||
|
||||
d_inner.callback(None)
|
||||
yield d_outer
|
||||
self.assertTrue(d_outer.called)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_catch(self):
|
||||
self.dist.declare("alarm")
|
||||
|
||||
@ -65,6 +70,7 @@ class DistributorTestCase(unittest.TestCase):
|
||||
spec=["warning"]
|
||||
) as mock_logger:
|
||||
d = self.dist.fire("alarm", "Go")
|
||||
yield d
|
||||
self.assertTrue(d.called)
|
||||
|
||||
observers[0].assert_called_once("Go")
|
||||
@ -81,23 +87,28 @@ class DistributorTestCase(unittest.TestCase):
|
||||
|
||||
self.dist.declare("whail")
|
||||
|
||||
observer = Mock()
|
||||
observer.return_value = defer.fail(
|
||||
Exception("Oopsie")
|
||||
)
|
||||
class MyException(Exception):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def observer():
|
||||
yield run_on_reactor()
|
||||
raise MyException("Oopsie")
|
||||
|
||||
self.dist.observe("whail", observer)
|
||||
|
||||
d = self.dist.fire("whail")
|
||||
|
||||
yield self.assertFailure(d, Exception)
|
||||
yield self.assertFailure(d, MyException)
|
||||
self.dist.suppress_failures = True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_signal_prereg(self):
|
||||
observer = Mock()
|
||||
self.dist.observe("flare", observer)
|
||||
|
||||
self.dist.declare("flare")
|
||||
self.dist.fire("flare", 4, 5)
|
||||
yield self.dist.fire("flare", 4, 5)
|
||||
|
||||
observer.assert_called_with(4, 5)
|
||||
|
||||
|
@ -26,6 +26,7 @@ class StateTestCase(unittest.TestCase):
|
||||
self.store = Mock(
|
||||
spec_set=[
|
||||
"get_state_groups",
|
||||
"add_event_hashes",
|
||||
]
|
||||
)
|
||||
hs = Mock(spec=["get_datastore"])
|
||||
@ -44,17 +45,20 @@ class StateTestCase(unittest.TestCase):
|
||||
self.create_event(type="test2", state_key=""),
|
||||
]
|
||||
|
||||
yield self.state.annotate_event_with_state(event, old_state=old_state)
|
||||
context = yield self.state.compute_event_context(
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(set(old_state), set(event.old_state_events.values()))
|
||||
self.assertDictEqual(event.old_state_events, event.state_events)
|
||||
self.assertEqual(
|
||||
set(old_state), set(context.current_state.values())
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_annotate_with_old_state(self):
|
||||
@ -66,21 +70,21 @@ class StateTestCase(unittest.TestCase):
|
||||
self.create_event(type="test2", state_key=""),
|
||||
]
|
||||
|
||||
yield self.state.annotate_event_with_state(event, old_state=old_state)
|
||||
context = yield self.state.compute_event_context(
|
||||
event, old_state=old_state
|
||||
)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set(old_state + [event]),
|
||||
set(event.old_state_events.values())
|
||||
set(old_state),
|
||||
set(context.current_state.values())
|
||||
)
|
||||
|
||||
self.assertDictEqual(event.old_state_events, event.state_events)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_trivial_annotate_message(self):
|
||||
@ -99,30 +103,19 @@ class StateTestCase(unittest.TestCase):
|
||||
group_name: old_state,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
)
|
||||
|
||||
self.assertDictEqual(
|
||||
{
|
||||
k: v.event_id
|
||||
for k, v in event.old_state_events.items()
|
||||
},
|
||||
{
|
||||
k: v.event_id
|
||||
for k, v in event.state_events.items()
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(group_name, event.state_group)
|
||||
self.assertEqual(group_name, context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_trivial_annotate_state(self):
|
||||
@ -141,38 +134,19 @@ class StateTestCase(unittest.TestCase):
|
||||
group_name: old_state,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
for k, v in event.old_state_events.items():
|
||||
for k, v in context.current_state.items():
|
||||
type, state_key = k
|
||||
self.assertEqual(type, v.type)
|
||||
self.assertEqual(state_key, v.state_key)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
set([e.event_id for e in context.current_state.values()])
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in old_state] + [event.event_id]),
|
||||
set([e.event_id for e in event.state_events.values()])
|
||||
)
|
||||
|
||||
new_state = {
|
||||
k: v.event_id
|
||||
for k, v in event.state_events.items()
|
||||
}
|
||||
old_state = {
|
||||
k: v.event_id
|
||||
for k, v in event.old_state_events.items()
|
||||
}
|
||||
old_state[(event.type, event.state_key)] = event.event_id
|
||||
self.assertDictEqual(
|
||||
old_state,
|
||||
new_state
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_resolve_message_conflict(self):
|
||||
@ -199,16 +173,11 @@ class StateTestCase(unittest.TestCase):
|
||||
group_name_2: old_state_2,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
self.assertEqual(len(event.old_state_events), 5)
|
||||
self.assertEqual(len(context.current_state), 5)
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in event.state_events.values()]),
|
||||
set([e.event_id for e in event.old_state_events.values()])
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_resolve_state_conflict(self):
|
||||
@ -235,19 +204,11 @@ class StateTestCase(unittest.TestCase):
|
||||
group_name_2: old_state_2,
|
||||
}
|
||||
|
||||
yield self.state.annotate_event_with_state(event)
|
||||
context = yield self.state.compute_event_context(event)
|
||||
|
||||
self.assertEqual(len(event.old_state_events), 5)
|
||||
self.assertEqual(len(context.current_state), 5)
|
||||
|
||||
expected_new = event.old_state_events
|
||||
expected_new[(event.type, event.state_key)] = event
|
||||
|
||||
self.assertEqual(
|
||||
set([e.event_id for e in expected_new.values()]),
|
||||
set([e.event_id for e in event.state_events.values()]),
|
||||
)
|
||||
|
||||
self.assertIsNone(event.state_group)
|
||||
self.assertIsNone(context.state_group)
|
||||
|
||||
def create_event(self, name=None, type=None, state_key=None):
|
||||
self.event_id += 1
|
||||
@ -266,6 +227,9 @@ class StateTestCase(unittest.TestCase):
|
||||
event.state_key = state_key
|
||||
event.event_id = event_id
|
||||
|
||||
event.is_state = lambda: (state_key is not None)
|
||||
event.unsigned = {}
|
||||
|
||||
event.user_id = "@user_id:example.com"
|
||||
event.room_id = "!room_id:example.com"
|
||||
|
||||
|
70
tests/test_test_utils.py
Normal file
70
tests/test_test_utils.py
Normal file
@ -0,0 +1,70 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from tests.utils import MockClock
|
||||
|
||||
class MockClockTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
|
||||
def test_advance_time(self):
|
||||
start_time = self.clock.time()
|
||||
|
||||
self.clock.advance_time(20)
|
||||
|
||||
self.assertEquals(20, self.clock.time() - start_time)
|
||||
|
||||
def test_later(self):
|
||||
invoked = [0, 0]
|
||||
|
||||
def _cb0():
|
||||
invoked[0] = 1
|
||||
self.clock.call_later(10, _cb0)
|
||||
|
||||
def _cb1():
|
||||
invoked[1] = 1
|
||||
self.clock.call_later(20, _cb1)
|
||||
|
||||
self.assertFalse(invoked[0])
|
||||
|
||||
self.clock.advance_time(15)
|
||||
|
||||
self.assertTrue(invoked[0])
|
||||
self.assertFalse(invoked[1])
|
||||
|
||||
self.clock.advance_time(5)
|
||||
|
||||
self.assertTrue(invoked[1])
|
||||
|
||||
def test_cancel_later(self):
|
||||
invoked = [0, 0]
|
||||
|
||||
def _cb0():
|
||||
invoked[0] = 1
|
||||
t0 = self.clock.call_later(10, _cb0)
|
||||
|
||||
def _cb1():
|
||||
invoked[1] = 1
|
||||
t1 = self.clock.call_later(20, _cb1)
|
||||
|
||||
self.clock.cancel_call_later(t0)
|
||||
|
||||
self.clock.advance_time(30)
|
||||
|
||||
self.assertFalse(invoked[0])
|
||||
self.assertTrue(invoked[1])
|
@ -23,21 +23,21 @@ mock_homeserver = BaseHomeServer(hostname="my.domain")
|
||||
class UserIDTestCase(unittest.TestCase):
|
||||
|
||||
def test_parse(self):
|
||||
user = UserID.from_string("@1234abcd:my.domain", hs=mock_homeserver)
|
||||
user = UserID.from_string("@1234abcd:my.domain")
|
||||
|
||||
self.assertEquals("1234abcd", user.localpart)
|
||||
self.assertEquals("my.domain", user.domain)
|
||||
self.assertEquals(True, user.is_mine)
|
||||
self.assertEquals(True, mock_homeserver.is_mine(user))
|
||||
|
||||
def test_build(self):
|
||||
user = UserID("5678efgh", "my.domain", True)
|
||||
user = UserID("5678efgh", "my.domain")
|
||||
|
||||
self.assertEquals(user.to_string(), "@5678efgh:my.domain")
|
||||
|
||||
def test_compare(self):
|
||||
userA = UserID.from_string("@userA:my.domain", hs=mock_homeserver)
|
||||
userAagain = UserID.from_string("@userA:my.domain", hs=mock_homeserver)
|
||||
userB = UserID.from_string("@userB:my.domain", hs=mock_homeserver)
|
||||
userA = UserID.from_string("@userA:my.domain")
|
||||
userAagain = UserID.from_string("@userA:my.domain")
|
||||
userB = UserID.from_string("@userB:my.domain")
|
||||
|
||||
self.assertTrue(userA == userAagain)
|
||||
self.assertTrue(userA != userB)
|
||||
@ -52,14 +52,14 @@ class UserIDTestCase(unittest.TestCase):
|
||||
class RoomAliasTestCase(unittest.TestCase):
|
||||
|
||||
def test_parse(self):
|
||||
room = RoomAlias.from_string("#channel:my.domain", hs=mock_homeserver)
|
||||
room = RoomAlias.from_string("#channel:my.domain")
|
||||
|
||||
self.assertEquals("channel", room.localpart)
|
||||
self.assertEquals("my.domain", room.domain)
|
||||
self.assertEquals(True, room.is_mine)
|
||||
self.assertEquals(True, mock_homeserver.is_mine(room))
|
||||
|
||||
def test_build(self):
|
||||
room = RoomAlias("channel", "my.domain", True)
|
||||
room = RoomAlias("channel", "my.domain")
|
||||
|
||||
self.assertEquals(room.to_string(), "#channel:my.domain")
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user