Merge pull request #25 from matrix-org/events_refactor

Event refactor
This commit is contained in:
Mark Haines 2014-12-16 13:53:43 +00:00
commit 2af40cfa14
60 changed files with 2129 additions and 1827 deletions

138
graph/graph2.py Normal file
View File

@ -0,0 +1,138 @@
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
import pydot
import cgi
import json
import datetime
import argparse
from synapse.events import FrozenEvent
def make_graph(db_name, room_id, file_prefix):
conn = sqlite3.connect(db_name)
c = conn.execute(
"SELECT json FROM event_json where room_id = ?",
(room_id,)
)
events = [FrozenEvent(json.loads(e[0])) for e in c.fetchall()]
events.sort(key=lambda e: e.depth)
node_map = {}
state_groups = {}
graph = pydot.Dot(graph_name="Test")
for event in events:
c = conn.execute(
"SELECT state_group FROM event_to_state_groups "
"WHERE event_id = ?",
(event.event_id,)
)
res = c.fetchone()
state_group = res[0] if res else None
if state_group is not None:
state_groups.setdefault(state_group, []).append(event.event_id)
t = datetime.datetime.fromtimestamp(
float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(event.get_dict()["content"])
label = (
"<"
"<b>%(name)s </b><br/>"
"Type: <b>%(type)s </b><br/>"
"State key: <b>%(state_key)s </b><br/>"
"Content: <b>%(content)s </b><br/>"
"Time: <b>%(time)s </b><br/>"
"Depth: <b>%(depth)s </b><br/>"
"State group: %(state_group)s<br/>"
">"
) % {
"name": event.event_id,
"type": event.type,
"state_key": event.get("state_key", None),
"content": cgi.escape(content, quote=True),
"time": t,
"depth": event.depth,
"state_group": state_group,
}
node = pydot.Node(
name=event.event_id,
label=label,
)
node_map[event.event_id] = node
graph.add_node(node)
for event in events:
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
end_node = pydot.Node(
name=prev_id,
label="<<b>%s</b>>" % (prev_id,),
)
node_map[prev_id] = end_node
graph.add_node(end_node)
edge = pydot.Edge(node_map[event.event_id], end_node)
graph.add_edge(edge)
for group, event_ids in state_groups.items():
if len(event_ids) <= 1:
continue
cluster = pydot.Cluster(
str(group),
label="<State Group: %s>" % (str(group),)
)
for event_id in event_ids:
cluster.add_node(node_map[event_id])
graph.add_subgraph(cluster)
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
graph.write_svg("%s.svg" % file_prefix, prog='dot')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate a PDU graph for a given room by talking "
"to the given homeserver to get the list of PDUs. \n"
"Requires pydot."
)
parser.add_argument(
"-p", "--prefix", dest="prefix",
help="String to prefix output files with"
)
parser.add_argument('db')
parser.add_argument('room')
args = parser.parse_args()
make_graph(args.db, args.room, args.prefix)

View File

@ -18,6 +18,9 @@ class dictobj(dict):
def get_full_dict(self):
return dict(self)
def get_pdu_json(self):
return dict(self)
def main():
parser = argparse.ArgumentParser()

View File

@ -0,0 +1,296 @@
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
import dns.resolver
import hashlib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
def get_key(server_name):
print "Getting keys for: %s" % (server_name,)
targets = []
if ":" in server_name:
target, port = server_name.split(":")
targets.append((target, int(port)))
return
try:
answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
for srv in answers:
targets.append((srv.target, srv.port))
except dns.resolver.NXDOMAIN:
targets.append((server_name, 8448))
except:
print "Failed to lookup keys for %s" % (server_name,)
return {}
for target, port in targets:
url = "https://%s:%i/_matrix/key/v1" % (target, port)
try:
keys = json.load(urllib2.urlopen(url, timeout=2))
verify_keys = {}
for key_id, key_base64 in keys["verify_keys"].items():
verify_key = decode_verify_key_bytes(
key_id, decode_base64(key_base64)
)
verify_signed_json(keys, server_name, verify_key)
verify_keys[key_id] = verify_key
print "Got keys for: %s" % (server_name,)
return verify_keys
except urllib2.URLError:
pass
print "Failed to get keys for %s" % (server_name,)
return {}
def reinsert_events(cursor, server_name, signing_key):
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
for event in events:
for alg_name in event.hashes:
if check_event_content_hash(event, algorithms[alg_name]):
pass
else:
pass
print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
reinsert_events(cursor, server_name, signing_key)
conn.commit()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])

View File

@ -41,6 +41,7 @@ setup(
"pynacl",
"daemonize",
"py-bcrypt",
"frozendict>=0.4",
"pillow",
],
dependency_links=[

View File

@ -17,14 +17,10 @@
from twisted.internet import defer
from synapse.api.constants import Membership, JoinRules
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomCreateEvent, RoomAliasesEvent,
)
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
from synapse.util.async import run_on_reactor
import logging
@ -53,15 +49,17 @@ class Auth(object):
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == RoomCreateEvent.TYPE:
if event.type == EventTypes.Create:
# FIXME
return True
# FIXME: Temp hack
if event.type == RoomAliasesEvent.TYPE:
if event.type == EventTypes.Aliases:
return True
if event.type == RoomMemberEvent.TYPE:
logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
event, auth_events
)
@ -74,10 +72,10 @@ class Auth(object):
self.check_event_sender_in_room(event, auth_events)
self._can_send_event(event, auth_events)
if event.type == RoomPowerLevelsEvent.TYPE:
if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events)
if event.type == RoomRedactionEvent.TYPE:
if event.type == EventTypes.Redaction:
self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
@ -93,7 +91,7 @@ class Auth(object):
def check_joined_room(self, room_id, user_id):
member = yield self.state.get_current_state(
room_id=room_id,
event_type=RoomMemberEvent.TYPE,
event_type=EventTypes.Member,
state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
@ -104,7 +102,7 @@ class Auth(object):
curr_state = yield self.state.get_current_state(room_id)
for event in curr_state:
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
try:
if self.hs.parse_userid(event.state_key).domain != host:
continue
@ -118,7 +116,7 @@ class Auth(object):
defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events):
key = (RoomMemberEvent.TYPE, event.user_id, )
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return self._check_joined_room(
@ -140,7 +138,7 @@ class Auth(object):
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (RoomCreateEvent.TYPE, "", )
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
@ -149,19 +147,19 @@ class Auth(object):
target_user_id = event.state_key
# get info about the caller
key = (RoomMemberEvent.TYPE, event.user_id, )
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (RoomMemberEvent.TYPE, target_user_id, )
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
key = (RoomJoinRulesEvent.TYPE, "", )
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
@ -256,7 +254,7 @@ class Auth(object):
return True
def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key)
level = None
if power_level_event:
@ -264,7 +262,7 @@ class Auth(object):
if not level:
level = power_level_event.content.get("users_default", 0)
else:
key = (RoomCreateEvent.TYPE, "", )
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
@ -273,7 +271,7 @@ class Auth(object):
return level
def _get_ops_level_from_event_state(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key)
if power_level_event:
@ -351,29 +349,31 @@ class Auth(object):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, event):
if event.type == RoomCreateEvent.TYPE:
event.auth_events = []
def add_auth_events(self, builder, context):
yield run_on_reactor()
if builder.type == EventTypes.Create:
builder.auth_events = []
return
auth_events = []
auth_ids = []
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
key = (EventTypes.PowerLevels, "", )
power_level_event = context.current_state.get(key)
if power_level_event:
auth_events.append(power_level_event.event_id)
auth_ids.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key)
key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.old_state_events.get(key)
key = (EventTypes.Member, builder.user_id, )
member_event = context.current_state.get(key)
key = (RoomCreateEvent.TYPE, "", )
create_event = event.old_state_events.get(key)
key = (EventTypes.Create, "", )
create_event = context.current_state.get(key)
if create_event:
auth_events.append(create_event.event_id)
auth_ids.append(create_event.event_id)
if join_rule_event:
join_rule = join_rule_event.content.get("join_rule")
@ -381,33 +381,37 @@ class Auth(object):
else:
is_public = False
if event.type == RoomMemberEvent.TYPE:
e_type = event.content["membership"]
if builder.type == EventTypes.Member:
e_type = builder.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_events.append(join_rule_event.event_id)
auth_ids.append(join_rule_event.event_id)
if e_type == Membership.JOIN:
if member_event and not is_public:
auth_events.append(member_event.event_id)
auth_ids.append(member_event.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event:
if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id)
auth_ids.append(member_event.event_id)
hashes = yield self.store.get_event_reference_hashes(
auth_events
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
event.auth_events = zip(auth_events, hashes)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function
def _can_send_event(self, event, auth_events):
key = (RoomPowerLevelsEvent.TYPE, "", )
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:

View File

@ -59,3 +59,18 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
class EventTypes(object):
Member = "m.room.member"
Create = "m.room.create"
JoinRules = "m.room.join_rules"
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
# These are used for validation
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"

View File

@ -1,148 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.jsonobject import JsonEncodedObject
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, SynapseEvent):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d:
d["age"] = int(hs.get_clock().time_msec()) - d["age_ts"]
del d["age_ts"]
return d
class SynapseEvent(JsonEncodedObject):
"""Base class for Synapse events. These are JSON objects which must abide
by a certain well-defined structure.
"""
# Attributes that are currently assumed by the federation side:
# Mandatory:
# - event_id
# - room_id
# - type
# - is_state
#
# Optional:
# - state_key (mandatory when is_state is True)
# - prev_events (these can be filled out by the federation layer itself.)
# - prev_state
valid_keys = [
"event_id",
"type",
"room_id",
"user_id", # sender/initiator
"content", # HTTP body, JSON
"state_key",
"age_ts",
"prev_content",
"replaces_state",
"redacted_because",
"origin_server_ts",
]
internal_keys = [
"is_state",
"depth",
"destinations",
"origin",
"outlier",
"redacted",
"prev_events",
"hashes",
"signatures",
"prev_state",
"auth_events",
"state_hash",
]
required_keys = [
"event_id",
"room_id",
"content",
]
outlier = False
def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs)
# if "content" in kwargs:
# self.check_json(self.content, raises=raises)
def get_content_template(self):
""" Retrieve the JSON template for this event as a dict.
The template must be a dict representing the JSON to match. Only
required keys should be present. The values of the keys in the template
are checked via type() to the values of the same keys in the actual
event JSON.
NB: If loading content via json.loads, you MUST define strings as
unicode.
For example:
Content:
{
"name": u"bob",
"age": 18,
"friends": [u"mike", u"jill"]
}
Template:
{
"name": u"string",
"age": 0,
"friends": [u"string"]
}
The values "string" and 0 could be anything, so long as the types
are the same as the content.
"""
raise NotImplementedError("get_content_template not implemented.")
def get_pdu_json(self, time_now=None):
pdu_json = self.get_full_dict()
pdu_json.pop("destinations", None)
pdu_json.pop("outlier", None)
pdu_json.pop("replaces_state", None)
pdu_json.pop("redacted", None)
pdu_json.pop("prev_content", None)
state_hash = pdu_json.pop("state_hash", None)
if state_hash is not None:
pdu_json.setdefault("unsigned", {})["state_hash"] = state_hash
content = pdu_json.get("content", {})
content.pop("prev", None)
if time_now is not None and "age_ts" in pdu_json:
age = time_now - pdu_json["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["age_ts"]
user_id = pdu_json.pop("user_id")
pdu_json["sender"] = user_id
return pdu_json
class SynapseStateEvent(SynapseEvent):
def __init__(self, **kwargs):
if "state_key" not in kwargs:
kwargs["state_key"] = ""
super(SynapseStateEvent, self).__init__(**kwargs)

View File

@ -1,90 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events.room import (
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
RoomPowerLevelsEvent, RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
)
from synapse.types import EventID
from synapse.util.stringutils import random_string
class EventFactory(object):
_event_classes = [
RoomTopicEvent,
RoomNameEvent,
MessageEvent,
RoomMemberEvent,
FeedbackEvent,
InviteJoinEvent,
RoomConfigEvent,
RoomPowerLevelsEvent,
RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
]
def __init__(self, hs):
self._event_list = {} # dict of TYPE to event class
for event_class in EventFactory._event_classes:
self._event_list[event_class.TYPE] = event_class
self.clock = hs.get_clock()
self.hs = hs
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create_local(local_part, self.hs)
return e_id.to_string()
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
kwargs["event_id"] = self.create_event_id()
kwargs["origin"] = self.hs.hostname
else:
ev_id = self.hs.parse_eventid(kwargs["event_id"])
kwargs["origin"] = ev_id.domain
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
# The "age" key is a delta timestamp that should be converted into an
# absolute timestamp the minute we see it.
if "age" in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec()) - int(kwargs["age"])
del kwargs["age"]
elif "age_ts" not in kwargs:
kwargs["age_ts"] = int(self.clock.time_msec())
if etype in self._event_list:
handler = self._event_list[etype]
else:
handler = GenericEvent
return handler(**kwargs)

View File

@ -1,170 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import Feedback, Membership
from synapse.api.errors import SynapseError
from . import SynapseEvent, SynapseStateEvent
class GenericEvent(SynapseEvent):
def get_content_template(self):
return {}
class RoomTopicEvent(SynapseEvent):
TYPE = "m.room.topic"
internal_keys = SynapseEvent.internal_keys + [
"topic",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "topic" in kwargs["content"]:
kwargs["topic"] = kwargs["content"]["topic"]
super(RoomTopicEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"topic": u"string"}
class RoomNameEvent(SynapseEvent):
TYPE = "m.room.name"
internal_keys = SynapseEvent.internal_keys + [
"name",
]
def __init__(self, **kwargs):
kwargs["state_key"] = ""
if "name" in kwargs["content"]:
kwargs["name"] = kwargs["content"]["name"]
super(RoomNameEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"name": u"string"}
class RoomMemberEvent(SynapseEvent):
TYPE = "m.room.member"
valid_keys = SynapseEvent.valid_keys + [
# target is the state_key
"membership", # action
]
def __init__(self, **kwargs):
if "membership" not in kwargs:
kwargs["membership"] = kwargs.get("content", {}).get("membership")
if not kwargs["membership"] in Membership.LIST:
raise SynapseError(400, "Bad membership value.")
super(RoomMemberEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"membership": u"string"}
class MessageEvent(SynapseEvent):
TYPE = "m.room.message"
valid_keys = SynapseEvent.valid_keys + [
"msg_id", # unique per room + user combo
]
def __init__(self, **kwargs):
super(MessageEvent, self).__init__(**kwargs)
def get_content_template(self):
return {"msgtype": u"string"}
class FeedbackEvent(SynapseEvent):
TYPE = "m.room.message.feedback"
valid_keys = SynapseEvent.valid_keys
def __init__(self, **kwargs):
super(FeedbackEvent, self).__init__(**kwargs)
if not kwargs["content"]["type"] in Feedback.LIST:
raise SynapseError(400, "Bad feedback value.")
def get_content_template(self):
return {
"type": u"string",
"target_event_id": u"string"
}
class InviteJoinEvent(SynapseEvent):
TYPE = "m.room.invite_join"
valid_keys = SynapseEvent.valid_keys + [
# target_user_id is the state_key
"target_host",
]
def __init__(self, **kwargs):
super(InviteJoinEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomConfigEvent(SynapseEvent):
TYPE = "m.room.config"
def __init__(self, **kwargs):
kwargs["state_key"] = ""
super(RoomConfigEvent, self).__init__(**kwargs)
def get_content_template(self):
return {}
class RoomCreateEvent(SynapseStateEvent):
TYPE = "m.room.create"
def get_content_template(self):
return {}
class RoomJoinRulesEvent(SynapseStateEvent):
TYPE = "m.room.join_rules"
def get_content_template(self):
return {}
class RoomPowerLevelsEvent(SynapseStateEvent):
TYPE = "m.room.power_levels"
def get_content_template(self):
return {}
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"
def get_content_template(self):
return {}
class RoomRedactionEvent(SynapseEvent):
TYPE = "m.room.redaction"
valid_keys = SynapseEvent.valid_keys + ["redacts"]
def get_content_template(self):
return {}

View File

@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
class EventValidator(object):
def __init__(self, hs):
pass
def validate(self, event):
"""Checks the given JSON content abides by the rules of the template.
Args:
content : A JSON object to check.
raises: True to raise a SynapseError if the check fails.
Returns:
True if the content passes the template. Returns False if the check
fails and raises=False.
Raises:
SynapseError if the check fails and raises=True.
"""
# recursively call to inspect each layer
err_msg = self._check_json_template(
event.content,
event.get_content_template()
)
if err_msg:
raise SynapseError(400, err_msg, Codes.BAD_JSON)
else:
return True
def _check_json_template(self, content, template):
"""Check content and template matches.
If the template is a dict, each key in the dict will be validated with
the content, else it will just compare the types of content and
template. This basic type check is required because this function will
be recursively called and could be called with just strs or ints.
Args:
content: The content to validate.
template: The validation template.
Returns:
str: An error message if the validation fails, else None.
"""
if type(content) != type(template):
return "Mismatched types: %s" % template
if type(template) == dict:
for key in template:
if key not in content:
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
msg = self._check_json_template(
content[key],
template[key]
)
if msg:
return msg
elif type(content[key]) == list:
# make sure each item type in content matches the template
for entry in content[key]:
msg = self._check_json_template(
entry,
template[key][0]
)
if msg:
return msg

View File

@ -15,7 +15,7 @@
# limitations under the License.
from synapse.api.events.utils import prune_event
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json
@ -29,17 +29,17 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(computed_hash.digest()))
if computed_hash.name not in event.hashes:
name, expected_hash = compute_content_hash(event, hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
if name not in event.hashes:
raise SynapseError(
400,
"Algorithm %s not in hashes %s" % (
computed_hash.name, list(event.hashes),
name, list(event.hashes),
),
Codes.UNAUTHORIZED,
)
message_hash_base64 = event.hashes[computed_hash.name]
message_hash_base64 = event.hashes[name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
except:
@ -48,10 +48,10 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"Invalid base64: %s" % (message_hash_base64,),
Codes.UNAUTHORIZED,
)
return message_hash_bytes == computed_hash.digest()
return message_hash_bytes == expected_hash
def _compute_content_hash(event, hash_algorithm):
def compute_content_hash(event, hash_algorithm):
event_json = event.get_pdu_json()
event_json.pop("age_ts", None)
event_json.pop("unsigned", None)
@ -59,8 +59,11 @@ def _compute_content_hash(event, hash_algorithm):
event_json.pop("hashes", None)
event_json.pop("outlier", None)
event_json.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_json)
return hash_algorithm(event_json_bytes)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
@ -79,27 +82,28 @@ def compute_event_signature(event, signature_name, signing_key):
redact_json = tmp_event.get_pdu_json()
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", redact_json)
logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key)
logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
if hasattr(event, "old_state_events"):
state_json_bytes = encode_canonical_json(
[e.event_id for e in event.old_state_events.values()]
)
hashed = hash_algorithm(state_json_bytes)
event.state_hash = {
hashed.name: encode_base64(hashed.digest())
}
# if hasattr(event, "old_state_events"):
# state_json_bytes = encode_canonical_json(
# [e.event_id for e in event.old_state_events.values()]
# )
# hashed = hash_algorithm(state_json_bytes)
# event.state_hash = {
# hashed.name: encode_base64(hashed.digest())
# }
hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm)
if not hasattr(event, "hashes"):
event.hashes = {}
event.hashes[hashed.name] = encode_base64(hashed.digest())
event.hashes[name] = encode_base64(digest)
event.signatures = compute_event_signature(
event,

149
synapse/events/__init__.py Normal file
View File

@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.frozenutils import freeze, unfreeze
import copy
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = copy.deepcopy(internal_metadata_dict)
def get_dict(self):
return dict(self.__dict__)
def is_outlier(self):
return hasattr(self, "outlier") and self.outlier
def _event_dict_property(key):
def getter(self):
return self._event_dict[key]
def setter(self, v):
self._event_dict[key] = v
def delete(self):
del self._event_dict[key]
return property(
getter,
setter,
delete,
)
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}):
self.signatures = copy.deepcopy(signatures)
self.unsigned = copy.deepcopy(unsigned)
self._event_dict = copy.deepcopy(event_dict)
self.internal_metadata = _EventInternalMetadata(
internal_metadata_dict
)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
prev_events = _event_dict_property("prev_events")
prev_state = _event_dict_property("prev_state")
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
def membership(self):
return self.content["membership"]
def is_state(self):
return hasattr(self, "state_key")
def get_dict(self):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": self.unsigned,
})
return d
def get(self, key, default):
return self._event_dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None):
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
age = time_now - pdu_json["unsigned"]["age_ts"]
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"]
return pdu_json
def __set__(self, instance, value):
raise AttributeError("Unrecognized attribute %s" % (instance,))
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}):
event_dict = copy.deepcopy(event_dict)
signatures = copy.deepcopy(event_dict.pop("signatures", {}))
unsigned = copy.deepcopy(event_dict.pop("unsigned", {}))
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
)
@staticmethod
def from_event(event):
e = FrozenEvent(
event.get_pdu_json()
)
e.internal_metadata = event.internal_metadata
return e
def get_dict(self):
# We need to unfreeze what we return
return unfreeze(super(FrozenEvent, self).get_dict())
def __str__(self):
return self.__repr__()
def __repr__(self):
return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
self.event_id, self.type, self.get("state_key", None),
)

77
synapse/events/builder.py Normal file
View File

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from synapse.types import EventID
from synapse.util.stringutils import random_string
import copy
class EventBuilder(EventBase):
def __init__(self, key_values={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
unsigned=unsigned
)
def update_event_key(self, key, value):
self._event_dict[key] = value
def update_event_keys(self, other_dict):
self._event_dict.update(other_dict)
def build(self):
return FrozenEvent.from_event(self)
class EventBuilderFactory(object):
def __init__(self, clock, hostname):
self.clock = clock
self.hostname = hostname
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create(local_part, self.hostname)
return e_id.to_string()
def new(self, key_values={}):
key_values["event_id"] = self.create_event_id()
time_now = int(self.clock.time_msec())
key_values.setdefault("origin", self.hostname)
key_values.setdefault("origin_server_ts", time_now)
key_values.setdefault("unsigned", {})
age = key_values["unsigned"].pop("age", 0)
key_values["unsigned"].setdefault("age_ts", time_now - age)
key_values["signatures"] = {}
return EventBuilder(key_values=key_values,)

View File

@ -13,3 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class EventContext(object):
def __init__(self, current_state=None, auth_events=None):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None

View File

@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .room import (
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
RoomAliasesEvent, RoomCreateEvent,
)
from synapse.api.constants import EventTypes
from . import EventBase
def prune_event(event):
@ -31,7 +29,7 @@ def prune_event(event):
allowed_keys = [
"event_id",
"user_id",
"sender",
"room_id",
"hashes",
"signatures",
@ -44,6 +42,7 @@ def prune_event(event):
"auth_events",
"origin",
"origin_server_ts",
"membership",
]
new_content = {}
@ -53,13 +52,13 @@ def prune_event(event):
if field in event.content:
new_content[field] = event.content[field]
if event_type == RoomMemberEvent.TYPE:
if event_type == EventTypes.Member:
add_fields("membership")
elif event_type == RoomCreateEvent.TYPE:
elif event_type == EventTypes.Create:
add_fields("creator")
elif event_type == RoomJoinRulesEvent.TYPE:
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
elif event_type == RoomPowerLevelsEvent.TYPE:
elif event_type == EventTypes.PowerLevels:
add_fields(
"users",
"users_default",
@ -71,15 +70,61 @@ def prune_event(event):
"kick",
"redact",
)
elif event_type == RoomAliasesEvent.TYPE:
elif event_type == EventTypes.Aliases:
add_fields("aliases")
allowed_fields = {
k: v
for k, v in event.get_full_dict().items()
for k, v in event.get_dict().items()
if k in allowed_keys
}
allowed_fields["content"] = new_content
return type(event)(**allowed_fields)
allowed_fields["unsigned"] = {}
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(allowed_fields)
def serialize_event(hs, e):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"]
)
del d["unsigned"]["redacted_because"]
if "redacted_by" in e.unsigned:
d["redacted_by"] = e.unsigned["redacted_by"]
del d["unsigned"]["redacted_by"]
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
return d

View File

@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
class EventValidator(object):
def validate(self, event):
EventID.from_string(event.event_id)
RoomID.from_string(event.room_id)
required = [
# "auth_events",
"content",
# "hashes",
"origin",
# "prev_events",
"sender",
"type",
]
for k in required:
if not hasattr(event, k):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
strings = [
"origin",
"sender",
"type",
]
if hasattr(event, "state_key"):
strings.append("state_key")
for s in strings:
if not isinstance(getattr(event, s), basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
# Check that the following keys have dictionary values
# TODO
# Check that the following keys have the correct format for DAGs
# TODO
def validate_new(self, event):
self.validate(event)
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
strings = [
"body",
"msgtype",
]
self._ensure_strings(event.content, strings)
elif event.type == EventTypes.Topic:
self._ensure_strings(event.content, ["topic"])
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], basestring):
raise SynapseError(400, "Not '%s' a string type" % (s,))

View File

@ -25,6 +25,7 @@ from .persistence import TransactionActions
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging
@ -73,7 +74,7 @@ class ReplicationLayer(object):
self._clock = hs.get_clock()
self.event_factory = hs.get_event_factory()
self.event_builder_factory = hs.get_event_builder_factory()
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
@ -112,7 +113,7 @@ class ReplicationLayer(object):
self.query_handlers[query_type] = handler
@log_function
def send_pdu(self, pdu):
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
@ -131,7 +132,7 @@ class ReplicationLayer(object):
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
@ -438,7 +439,9 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
@ -557,7 +560,13 @@ class ReplicationLayer(object):
origin, pdu.event_id, do_auth=False
)
if existing and (not existing.outlier or pdu.outlier):
already_seen = (
existing and (
not existing.internal_metadata.outlier
or pdu.internal_metadata.outlier
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
@ -595,7 +604,7 @@ class ReplicationLayer(object):
# )
# Get missing pdus if necessary.
if not pdu.outlier:
if not pdu.internal_metadata.outlier:
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
@ -658,19 +667,14 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
#TODO: Check we have all the PDU keys here
pdu_json.setdefault("hashes", {})
pdu_json.setdefault("signatures", {})
sender = pdu_json.pop("sender", None)
if sender is not None:
pdu_json["user_id"] = sender
state_hash = pdu_json.get("unsigned", {}).pop("state_hash", None)
if state_hash is not None:
pdu_json["state_hash"] = state_hash
return self.event_factory.create_event(
pdu_json["type"], outlier=outlier, **pdu_json
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@ -706,15 +710,13 @@ class _TransactionQueue(object):
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, order):
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set([
d for d in pdu.destinations
if d != self.server_name
])
destinations = set(destinations)
destinations.discard(self.server_name)
logger.debug("Sending to: %s", str(destinations))

View File

@ -15,11 +15,12 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership
from synapse.api.constants import Membership, EventTypes
from synapse.events.snapshot import EventContext
import logging
@ -31,10 +32,8 @@ class BaseHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.room_lock = hs.get_room_lock_manager()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
@ -44,6 +43,8 @@ class BaseHandler(object):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@ -57,62 +58,100 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[], suppress_auth=False,
do_invite_host=None):
def _create_new_client_event(self, builder):
yield run_on_reactor()
snapshot.fill_out_prev_events(event)
context = EventContext()
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
logger.debug("Signing event...")
add_hashes_and_signatures(
event, self.server_name, self.signing_key
latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
)
logger.debug("Signed event.")
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [(e, h) for e, h, _ in latest_ret]
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
ret = yield state_handler.annotate_context_with_state(
builder,
context,
)
prev_state = ret
if builder.is_state():
builder.prev_state = prev_state
yield self.auth.add_auth_events(builder, context)
add_hashes_and_signatures(
builder, self.server_name, self.signing_key
)
event = builder.build()
logger.debug(
"Created event %s with auth_events: %s, current state: %s",
event.event_id, context.auth_events, context.current_state,
)
defer.returnValue(
(event, context,)
)
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
logger.debug("Authing...")
self.auth.check(event, auth_events=event.old_state_events)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
self.auth.check(event, auth_events=context.auth_events)
if do_invite_host:
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)
yield self.store.persist_event(event, context=context)
# FIXME: We need to check if the remote changed anything else
event.signatures = invite_event.signatures
federation_handler = self.hs.get_handlers().federation_handler
yield self.store.persist_event(event)
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
invitee = self.hs.parse_userid(event.state_key)
if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
invitee.domain,
event,
)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(
returned_invite.signatures
)
destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room.
for k, s in event.state_events.items():
for k, s in context.current_state.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
event.destinations = list(destinations)
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
federation_handler = self.hs.get_handlers().federation_handler
yield federation_handler.handle_new_event(event, snapshot)
yield federation_handler.handle_new_event(
event,
None,
destinations=destinations,
)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.events.room import RoomAliasesEvent
from synapse.api.constants import EventTypes
import logging
@ -40,7 +40,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Do auth.
if not room_alias.is_mine:
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.
@ -64,7 +64,7 @@ class DirectoryHandler(BaseHandler):
def delete_association(self, user_id, room_alias):
# TODO Check if server admin
if not room_alias.is_mine:
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias)
@ -75,7 +75,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def get_association(self, room_alias):
room_id = None
if room_alias.is_mine:
if self.hs.is_mine(room_alias):
result = yield self.store.get_association_from_room_alias(
room_alias
)
@ -123,7 +123,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = self.hs.parse_roomalias(args["room_alias"])
if not room_alias.is_mine:
if not self.hs.is_mine(room_alias):
raise SynapseError(
400, "Room Alias is not hosted on this Home Server"
)
@ -148,16 +148,12 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
event = self.event_factory.create_event(
etype=RoomAliasesEvent.TYPE,
state_key=self.hs.hostname,
room_id=room_id,
user_id=user_id,
content={"aliases": aliases},
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({
"type": EventTypes.Aliases,
"state_key": self.hs.hostname,
"room_id": room_id,
"sender": user_id,
"content": {"aliases": aliases},
})
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True
)

View File

@ -17,12 +17,12 @@
from ._base import BaseHandler
from synapse.api.events.utils import prune_event
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError,
)
from synapse.api.events.room import RoomMemberEvent, RoomCreateEvent
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes, Membership
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
@ -76,7 +76,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
def handle_new_event(self, event, snapshot, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
remote home servers that may be interested.
@ -92,12 +92,7 @@ class FederationHandler(BaseHandler):
yield run_on_reactor()
pdu = event
if not hasattr(pdu, "destinations") or not pdu.destinations:
pdu.destinations = []
yield self.replication_layer.send_pdu(pdu)
yield self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
@ -140,7 +135,7 @@ class FederationHandler(BaseHandler):
if not check_event_content_hash(event):
logger.warn(
"Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_full_dict())
event.event_id, encode_canonical_json(event.get_dict())
)
event = redacted_event
@ -153,7 +148,7 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
if not is_in_room and not event.outlier:
if not is_in_room and not event.internal_metadata.outlier:
logger.debug("Got event for room we're not in.")
replication_layer = self.replication_layer
@ -164,7 +159,7 @@ class FederationHandler(BaseHandler):
)
for e in auth_chain:
e.outlier = True
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e, fetch_missing=False)
except:
@ -184,7 +179,7 @@ class FederationHandler(BaseHandler):
if state:
for e in state:
e.outlier = True
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e)
except:
@ -229,7 +224,7 @@ class FederationHandler(BaseHandler):
if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
@ -238,7 +233,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users
)
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire(
@ -265,11 +260,18 @@ class FederationHandler(BaseHandler):
event = pdu
# FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_event_with_state(event)
context = EventContext()
yield self.state_handler.annotate_context_with_state(event, context)
events.append(event)
events.append(
(event, context)
)
yield self.store.persist_event(event, backfilled=True)
yield self.store.persist_event(
event,
context=context,
backfilled=True
)
defer.returnValue(events)
@ -286,8 +288,6 @@ class FederationHandler(BaseHandler):
pdu=event
)
defer.returnValue(pdu)
@defer.inlineCallbacks
@ -332,42 +332,55 @@ class FederationHandler(BaseHandler):
event = pdu
# We should assert some things.
assert(event.type == RoomMemberEvent.TYPE)
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.outlier = False
event.internal_metadata.outlier = False
self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
event.get_pdu_json()
)
handled_events = set()
try:
event.event_id = self.event_factory.create_event_id()
event.origin = self.hs.hostname
event.content = content
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
builder.content = content
if not hasattr(event, "signatures"):
event.signatures = {}
builder.signatures = {}
add_hashes_and_signatures(
event,
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
new_event = builder.build()
ret = yield self.replication_layer.send_join(
target_host,
event
new_event
)
state = ret["state"]
auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth)
handled_events.update([s.event_id for s in state])
handled_events.update([a.event_id for a in auth_chain])
handled_events.add(new_event.event_id)
logger.debug("do_invite_join auth_chain: %s", auth_chain)
logger.debug("do_invite_join state: %s", state)
logger.debug("do_invite_join event: %s", event)
logger.debug("do_invite_join event: %s", new_event)
try:
yield self.store.store_room(
@ -380,7 +393,7 @@ class FederationHandler(BaseHandler):
pass
for e in auth_chain:
e.outlier = True
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e, fetch_missing=False)
except:
@ -391,7 +404,7 @@ class FederationHandler(BaseHandler):
for e in state:
# FIXME: Auth these.
e.outlier = True
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(
e,
@ -404,13 +417,13 @@ class FederationHandler(BaseHandler):
)
yield self._handle_new_event(
event,
new_event,
state=state,
current_state=state,
)
yield self.notifier.on_new_room_event(
event, extra_users=[joinee]
new_event, extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id)
@ -419,6 +432,9 @@ class FederationHandler(BaseHandler):
del self.room_queues[room_id]
for p, origin in room_queue:
if p.event_id in handled_events:
continue
try:
self.on_receive_pdu(origin, p, backfilled=False)
except:
@ -428,25 +444,24 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_make_join_request(self, context, user_id):
def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or
process it until the other server has signed it and sent it back.
"""
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
content={"membership": Membership.JOIN},
room_id=context,
user_id=user_id,
state_key=user_id,
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.JOIN},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
snapshot = yield self.store.snapshot_room(event)
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_event_with_state(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, auth_events=event.old_state_events)
self.auth.check(event, auth_events=context.auth_events)
pdu = event
@ -460,12 +475,24 @@ class FederationHandler(BaseHandler):
"""
event = pdu
event.outlier = False
logger.debug(
"on_send_join_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
yield self._handle_new_event(event)
event.internal_metadata.outlier = False
context = yield self._handle_new_event(event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = []
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
@ -474,7 +501,7 @@ class FederationHandler(BaseHandler):
event, extra_users=extra_users
)
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
yield self.distributor.fire(
@ -485,9 +512,9 @@ class FederationHandler(BaseHandler):
destinations = set()
for k, s in event.state_events.items():
for k, s in context.current_state.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
@ -497,14 +524,18 @@ class FederationHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id
)
new_pdu.destinations = list(destinations)
logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
yield self.replication_layer.send_pdu(new_pdu)
yield self.replication_layer.send_pdu(new_pdu, destinations)
auth_chain = yield self.store.get_auth_chain(event.event_id)
defer.returnValue({
"state": event.state_events.values(),
"state": context.current_state.values(),
"auth_chain": auth_chain,
})
@ -516,7 +547,9 @@ class FederationHandler(BaseHandler):
"""
event = pdu
event.outlier = True
context = EventContext()
event.internal_metadata.outlier = True
event.signatures.update(
compute_event_signature(
@ -526,10 +559,11 @@ class FederationHandler(BaseHandler):
)
)
yield self.state_handler.annotate_event_with_state(event)
yield self.state_handler.annotate_context_with_state(event, context)
yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
@ -559,13 +593,13 @@ class FederationHandler(BaseHandler):
}
event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"):
if event and event.is_state():
# Get previous state
if hasattr(event, "replaces_state") and event.replaces_state:
prev_event = yield self.store.get_event(
event.replaces_state
)
results[(event.type, event.state_key)] = prev_event
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = yield self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@ -651,74 +685,81 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False,
current_state=None, fetch_missing=True):
is_new_state = yield self.state_handler.annotate_event_with_state(
context = EventContext()
logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.state_handler.annotate_context_with_state(
event,
context,
old_state=state
)
if event.old_state_events:
known_ids = set(
[s.event_id for s in event.old_state_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
)
logger.debug(
"_handle_new_event: Before auth fetch: %s, sigs: %s",
event.event_id, event.signatures,
)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in %s",
event.event_id, e_id, known_ids,
)
raise AuthError(403, "Auth events are stale")
is_new_state = not event.internal_metadata.is_outlier()
auth_events = event.old_state_events
else:
# We need to get the auth events from somewhere.
# TODO: Don't just hit the DBs?
auth_events = {}
for e_id, _ in event.auth_events:
known_ids = set(
[s.event_id for s in context.auth_events.values()]
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(
e_id,
allow_none=True,
e_id, allow_none=True,
)
if not e:
e = yield self.replication_layer.get_pdu(
event.origin, e_id, outlier=True
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
)
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Auth events are stale")
if e and fetch_missing:
try:
yield self.on_receive_pdu(event.origin, e, False)
except:
logger.exception(
"Failed to parse auth event %s",
e_id,
)
context.auth_events[(e.type, e.state_key)] = e
if not e:
logger.warn("Can't find auth event %s.", e_id)
logger.debug(
"_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
auth_events[(e.type, e.state_key)] = e
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
context.auth_events[(c.type, c.state_key)] = c
if event.type == RoomMemberEvent.TYPE and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == RoomCreateEvent.TYPE:
auth_events[(c.type, c.state_key)] = c
logger.debug(
"_handle_new_event: Before auth check: %s, sigs: %s",
event.event_id, event.signatures,
)
self.auth.check(event, auth_events=auth_events)
self.auth.check(event, auth_events=context.auth_events)
logger.debug(
"_handle_new_event: Before persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
logger.debug(
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures,
)
defer.returnValue(context)

View File

@ -15,10 +15,13 @@
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events.validator import EventValidator
from ._base import BaseHandler
import logging
@ -32,7 +35,7 @@ class MessageHandler(BaseHandler):
super(MessageHandler, self).__init__(hs)
self.hs = hs
self.clock = hs.get_clock()
self.event_factory = hs.get_event_factory()
self.validator = EventValidator()
@defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None,
@ -79,7 +82,7 @@ class MessageHandler(BaseHandler):
self.ratelimit(event.user_id)
# TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event)
@ -134,19 +137,48 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def store_room_data(self, event=None):
""" Stores data for a room.
def create_and_send_event(self, event_dict):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Persists and notifies local clients and federation.
Args:
event : The room path event
stamp_event (bool) : True to stamp event content with server keys.
Raises:
SynapseError if something went wrong.
event_dict (dict): An entire event
"""
builder = self.event_builder_factory.new(event_dict)
snapshot = yield self.store.snapshot_room(event)
self.validator.validate_new(builder)
yield self._on_new_room_event(event, snapshot)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
joinee = self.hs.parse_userid(builder.state_key)
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data",
joinee,
builder.content
)
event, context = yield self._create_new_client_event(
builder=builder,
)
if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
yield self.handle_new_client_event(
event=event,
context=context,
)
defer.returnValue(event)
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,

View File

@ -147,7 +147,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def is_presence_visible(self, observer_user, observed_user):
assert(observed_user.is_mine)
assert(self.hs.is_mine(observed_user))
if observer_user == observed_user:
defer.returnValue(True)
@ -165,7 +165,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False):
if target_user.is_mine:
if self.hs.is_mine(target_user):
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=target_user
@ -212,7 +212,7 @@ class PresenceHandler(BaseHandler):
# TODO (erikj): Turn this back on. Why did we end up sending EDUs
# everywhere?
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -291,7 +291,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
if user.is_mine:
if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the
@ -309,7 +309,7 @@ class PresenceHandler(BaseHandler):
rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if c.is_mine]:
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
@ -318,14 +318,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def send_invite(self, observer_user, observed_user):
if not observer_user.is_mine:
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.add_presence_list_pending(
observer_user.localpart, observed_user.to_string()
)
if observed_user.is_mine:
if self.hs.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user)
else:
yield self.federation.send_edu(
@ -339,7 +339,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def _should_accept_invite(self, observed_user, observer_user):
if not observed_user.is_mine:
if not self.hs.is_mine(observed_user):
defer.returnValue(False)
row = yield self.store.has_presence_state(observed_user.localpart)
@ -359,7 +359,7 @@ class PresenceHandler(BaseHandler):
observed_user.localpart, observer_user.to_string()
)
if observer_user.is_mine:
if self.hs.is_mine(observer_user):
if accept:
yield self.accept_presence(observed_user, observer_user)
else:
@ -396,7 +396,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def drop(self, observed_user, observer_user):
if not observer_user.is_mine:
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list(
@ -410,7 +410,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
if not observer_user.is_mine:
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence = yield self.store.get_presence_list(
@ -465,7 +465,7 @@ class PresenceHandler(BaseHandler):
)
for target_user in target_users:
if target_user.is_mine:
if self.hs.is_mine(target_user):
self._start_polling_local(user, target_user)
# We want to tell the person that just came online
@ -477,7 +477,7 @@ class PresenceHandler(BaseHandler):
)
deferreds = []
remote_users = [u for u in target_users if not u.is_mine]
remote_users = [u for u in target_users if not self.hs.is_mine(u)]
remoteusers_by_domain = partition(remote_users, lambda u: u.domain)
# Only poll for people in our get_presence_list
for domain in remoteusers_by_domain:
@ -520,7 +520,7 @@ class PresenceHandler(BaseHandler):
def stop_polling_presence(self, user, target_user=None):
logger.debug("Stop polling for presence from %s", user)
if not target_user or target_user.is_mine:
if not target_user or self.hs.is_mine(target_user):
self._stop_polling_local(user, target_user=target_user)
deferreds = []
@ -579,7 +579,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def push_presence(self, user, statuscache):
assert(user.is_mine)
assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user)
@ -696,7 +696,7 @@ class PresenceHandler(BaseHandler):
for poll in content.get("poll", []):
user = self.hs.parse_userid(poll)
if not user.is_mine:
if not self.hs.is_mine(user):
continue
# TODO(paul) permissions checks
@ -711,7 +711,7 @@ class PresenceHandler(BaseHandler):
for unpoll in content.get("unpoll", []):
user = self.hs.parse_userid(unpoll)
if not user.is_mine:
if not self.hs.is_mine(user):
continue
if user in self._remote_sendmap:
@ -730,7 +730,7 @@ class PresenceHandler(BaseHandler):
localusers, remoteusers = partitionbool(
users_to_push,
lambda u: u.is_mine
lambda u: self.hs.is_mine(u)
)
localusers = set(localusers)
@ -788,7 +788,7 @@ class PresenceEventSource(object):
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if observed_user.is_mine:
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(

View File

@ -51,7 +51,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def get_displayname(self, target_user):
if target_user.is_mine:
if self.hs.is_mine(target_user):
displayname = yield self.store.get_profile_displayname(
target_user.localpart
)
@ -81,7 +81,7 @@ class ProfileHandler(BaseHandler):
def set_displayname(self, target_user, auth_user, new_displayname):
"""target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change."""
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -101,7 +101,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
if target_user.is_mine:
if self.hs.is_mine(target_user):
avatar_url = yield self.store.get_profile_avatar_url(
target_user.localpart
)
@ -130,7 +130,7 @@ class ProfileHandler(BaseHandler):
def set_avatar_url(self, target_user, auth_user, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -150,7 +150,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def collect_presencelike_data(self, user, state):
if not user.is_mine:
if not self.hs.is_mine(user):
defer.returnValue(None)
with PreserveLoggingContext():
@ -170,7 +170,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def on_profile_query(self, args):
user = self.hs.parse_userid(args["user_id"])
if not user.is_mine:
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this Home Server")
just_field = args.get("field", None)
@ -191,7 +191,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks
def _update_join_states(self, user):
if not user.is_mine:
if not self.hs.is_mine(user):
return
joins = yield self.store.get_rooms_for_user_where_membership_is(
@ -200,8 +200,6 @@ class ProfileHandler(BaseHandler):
)
for j in joins:
snapshot = yield self.store.snapshot_room(j)
content = {
"membership": j.content["membership"],
}
@ -210,14 +208,11 @@ class ProfileHandler(BaseHandler):
"collect_presencelike_data", user, content
)
new_event = self.event_factory.create_event(
etype=j.type,
room_id=j.room_id,
state_key=j.state_key,
content=content,
user_id=j.state_key,
)
yield self._on_new_room_event(
new_event, snapshot, suppress_auth=True
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({
"type": j.type,
"room_id": j.room_id,
"state_key": j.state_key,
"content": content,
"sender": j.state_key,
})

View File

@ -22,6 +22,7 @@ from synapse.api.errors import (
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient
@ -54,12 +55,13 @@ class RegistrationHandler(BaseHandler):
Raises:
RegistrationError if there was a problem registering.
"""
yield run_on_reactor()
password_hash = None
if password:
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
user = UserID(localpart, self.hs.hostname, True)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self._generate_token(user_id)
@ -78,7 +80,7 @@ class RegistrationHandler(BaseHandler):
while not user_id and not token:
try:
localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname, True)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self._generate_token(user_id)

View File

@ -17,12 +17,8 @@
from twisted.internet import defer
from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import Membership, JoinRules
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
)
from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from ._base import BaseHandler
@ -52,9 +48,9 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id)
if "room_alias_name" in config:
room_alias = RoomAlias.create_local(
room_alias = RoomAlias.create(
config["room_alias_name"],
self.hs
self.hs.hostname,
)
mapping = yield self.store.get_association_from_room_alias(
room_alias
@ -76,8 +72,8 @@ class RoomCreationHandler(BaseHandler):
if room_id:
# Ensure room_id is the correct type
room_id_obj = RoomID.from_string(room_id, self.hs)
if not room_id_obj.is_mine:
room_id_obj = RoomID.from_string(room_id)
if not self.hs.is_mine(room_id_obj):
raise SynapseError(400, "Room id must be local")
yield self.store.store_room(
@ -93,7 +89,10 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
gen_room_id = RoomID.create_local(random_string, self.hs)
gen_room_id = RoomID.create(
random_string,
self.hs.hostname,
)
yield self.store.store_room(
room_id=gen_room_id.to_string(),
room_creator_user_id=user_id,
@ -120,59 +119,37 @@ class RoomCreationHandler(BaseHandler):
user, room_id, is_public=is_public
)
room_member_handler = self.hs.get_handlers().room_member_handler
@defer.inlineCallbacks
def handle_event(event):
snapshot = yield self.store.snapshot_room(event)
logger.debug("Event: %s", event)
if event.type == RoomMemberEvent.TYPE:
yield room_member_handler.change_membership(
event,
do_auth=True
)
else:
yield self._on_new_room_event(
event, snapshot, extra_users=[user], suppress_auth=True
)
msg_handler = self.hs.get_handlers().message_handler
for event in creation_events:
yield handle_event(event)
yield msg_handler.create_and_send_event(event)
if "name" in config:
name = config["name"]
name_event = self.event_factory.create_event(
etype=RoomNameEvent.TYPE,
room_id=room_id,
user_id=user_id,
content={"name": name},
)
yield handle_event(name_event)
yield msg_handler.create_and_send_event({
"type": EventTypes.Name,
"room_id": room_id,
"sender": user_id,
"content": {"name": name},
})
if "topic" in config:
topic = config["topic"]
topic_event = self.event_factory.create_event(
etype=RoomTopicEvent.TYPE,
room_id=room_id,
user_id=user_id,
content={"topic": topic},
)
yield msg_handler.create_and_send_event({
"type": EventTypes.Topic,
"room_id": room_id,
"sender": user_id,
"content": {"topic": topic},
})
yield handle_event(topic_event)
content = {"membership": Membership.INVITE}
for invitee in invite_list:
invite_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=invitee,
room_id=room_id,
user_id=user_id,
content=content
)
yield handle_event(invite_event)
yield msg_handler.create_and_send_event({
"type": EventTypes.Member,
"state_key": invitee,
"room_id": room_id,
"user_id": user_id,
"content": {"membership": Membership.INVITE},
})
result = {"room_id": room_id}
@ -189,40 +166,44 @@ class RoomCreationHandler(BaseHandler):
event_keys = {
"room_id": room_id,
"user_id": creator_id,
"sender": creator_id,
"state_key": "",
}
def create(etype, **content):
return self.event_factory.create_event(
etype=etype,
content=content,
**event_keys
)
def create(etype, content, **kwargs):
e = {
"type": etype,
"content": content,
}
e.update(event_keys)
e.update(kwargs)
return e
creation_event = create(
etype=RoomCreateEvent.TYPE,
creator=creator.to_string(),
etype=EventTypes.Create,
content={"creator": creator.to_string()},
)
join_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
join_event = create(
etype=EventTypes.Member,
state_key=creator_id,
content={
"membership": Membership.JOIN,
},
**event_keys
)
power_levels_event = self.event_factory.create_event(
etype=RoomPowerLevelsEvent.TYPE,
power_levels_event = create(
etype=EventTypes.PowerLevels,
content={
"users": {
creator.to_string(): 100,
},
"users_default": 0,
"events": {
RoomNameEvent.TYPE: 100,
RoomPowerLevelsEvent.TYPE: 100,
EventTypes.Name: 100,
EventTypes.PowerLevels: 100,
},
"events_default": 0,
"state_default": 50,
@ -230,13 +211,12 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50
},
**event_keys
)
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
join_rules_event = create(
etype=RoomJoinRulesEvent.TYPE,
join_rule=join_rule,
etype=EventTypes.JoinRules,
content={"join_rule": join_rule},
)
return [
@ -288,7 +268,7 @@ class RoomMemberHandler(BaseHandler):
if ignore_user is not None and member == ignore_user:
continue
if member.is_mine:
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
@ -349,7 +329,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(member)
@defer.inlineCallbacks
def change_membership(self, event=None, do_auth=True):
def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room.
Args:
@ -359,11 +339,9 @@ class RoomMemberHandler(BaseHandler):
"""
target_user_id = event.state_key
snapshot = yield self.store.snapshot_room(event)
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
prev_state = context.current_state.get(
(EventTypes.Member, target_user_id),
None
)
room_id = event.room_id
@ -372,10 +350,11 @@ class RoomMemberHandler(BaseHandler):
# if this HS is not currently in the room, i.e. we have to do the
# invite/join dance.
if event.membership == Membership.JOIN:
yield self._do_join(event, snapshot, do_auth=do_auth)
yield self._do_join(event, context, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
# FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP.
defer.returnValue({})
@ -384,7 +363,7 @@ class RoomMemberHandler(BaseHandler):
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
context=context,
do_auth=do_auth,
)
@ -412,32 +391,26 @@ class RoomMemberHandler(BaseHandler):
host = hosts[0]
content.update({"membership": Membership.JOIN})
new_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=joinee.to_string(),
room_id=room_id,
user_id=joinee.to_string(),
membership=Membership.JOIN,
content=content,
)
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"state_key": joinee.to_string(),
"room_id": room_id,
"sender": joinee.to_string(),
"membership": Membership.JOIN,
"content": content,
})
event, context = yield self._create_new_client_event(builder)
snapshot = yield self.store.snapshot_room(new_event)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
yield self._do_join(event, context, room_host=host, do_auth=True)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks
def _do_join(self, event, snapshot, room_host=None, do_auth=True):
def _do_join(self, event, context, room_host=None, do_auth=True):
joinee = self.hs.parse_userid(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id
# If event doesn't include a display name, add one.
yield self.distributor.fire(
"collect_presencelike_data", joinee, event.content
)
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
@ -459,31 +432,29 @@ class RoomMemberHandler(BaseHandler):
)
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.user_id, self.hs
)
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not inviter.is_mine and not room
should_do_dance = not self.hs.is_mine(inviter)
room_host = inviter.domain
else:
should_do_dance = False
have_joined = False
if should_do_dance:
handler = self.hs.get_handlers().federation_handler
have_joined = yield handler.do_invite_join(
room_host, room_id, event.user_id, event.content, snapshot
yield handler.do_invite_join(
room_host,
room_id,
event.user_id,
event.get_dict()["content"], # FIXME To get a non-frozen dict
context
)
# We want to do the _do_update inside the room lock.
if not have_joined:
else:
logger.debug("Doing normal join")
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
context=context,
do_auth=do_auth,
)
@ -508,10 +479,10 @@ class RoomMemberHandler(BaseHandler):
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.sender, self.hs
prev_state.sender
)
is_remote_invite_join = not inviter.is_mine and not room
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain
else:
is_remote_invite_join = False
@ -533,25 +504,17 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue(room_ids)
@defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, snapshot,
def _do_local_membership_update(self, event, membership, context,
do_auth):
yield run_on_reactor()
# If we're inviting someone, then we should also send it to that
# HS.
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
if membership == Membership.INVITE and not target_user.is_mine:
do_invite_host = target_user.domain
else:
do_invite_host = None
target_user = self.hs.parse_userid(event.state_key)
yield self._on_new_room_event(
yield self.handle_new_client_event(
event,
snapshot,
context,
extra_users=[target_user],
suppress_auth=(not do_auth),
do_invite_host=do_invite_host,
)

View File

@ -63,7 +63,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -100,7 +100,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -118,7 +118,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
if user.is_mine:
if self.hs.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
yield self._stopped_typing(member)

View File

@ -29,6 +29,7 @@ from twisted.web.util import redirectTo
import collections
import logging
import urllib
logger = logging.getLogger(__name__)
@ -122,9 +123,14 @@ class JsonResource(HttpServer, resource.Resource):
# We found a match! Trigger callback and then return the
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
code, response = yield path_entry.callback(
request,
*m.groups()
*args
)
self._send_response(request, code, response)

View File

@ -28,7 +28,7 @@ class RestServletFactory(object):
speaking, they serve as wrappers around events and the handlers that
process them.
See synapse.api.events for information on synapse events.
See synapse.events for information on synapse events.
"""
def __init__(self, hs):

View File

@ -35,7 +35,7 @@ class WhoisRestServlet(RestServlet):
if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin")
if not target_user.is_mine:
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user)

View File

@ -63,12 +63,10 @@ class RestServlet(object):
self.hs = hs
self.handlers = hs.get_handlers()
self.event_factory = hs.get_event_factory()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
self.validator = hs.get_event_validator()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):

View File

@ -21,7 +21,6 @@ from base import RestServlet, client_path_pattern
import json
import logging
import urllib
logger = logging.getLogger(__name__)
@ -36,9 +35,7 @@ class ClientDirectoryServer(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = self.hs.parse_roomalias(
urllib.unquote(room_alias).decode("utf-8")
)
room_alias = self.hs.parse_roomalias(room_alias)
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
@ -56,9 +53,7 @@ class ClientDirectoryServer(RestServlet):
logger.debug("Got content: %s", content)
room_alias = self.hs.parse_roomalias(
urllib.unquote(room_alias).decode("utf-8")
)
room_alias = self.hs.parse_roomalias(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
@ -97,9 +92,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
room_alias = self.hs.parse_roomalias(
urllib.unquote(room_alias).decode("utf-8")
)
room_alias = self.hs.parse_roomalias(room_alias)
yield dir_handler.delete_association(
user.to_string(), room_alias

View File

@ -47,8 +47,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if not login_submission["user"].startswith('@'):
login_submission["user"] = UserID.create_local(
login_submission["user"], self.hs).to_string()
login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler
token = yield handler.login(

View File

@ -22,7 +22,6 @@ from base import RestServlet, client_path_pattern
import json
import logging
import urllib
logger = logging.getLogger(__name__)
@ -33,7 +32,6 @@ class PresenceStatusRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
state = yield self.handlers.presence_handler.get_state(
@ -44,7 +42,6 @@ class PresenceStatusRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
state = {}
@ -80,10 +77,9 @@ class PresenceListRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
if not user.is_mine:
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
@ -101,10 +97,9 @@ class PresenceListRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
if not user.is_mine:
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:

View File

@ -19,7 +19,6 @@ from twisted.internet import defer
from base import RestServlet, client_path_pattern
import json
import urllib
class ProfileDisplaynameRestServlet(RestServlet):
@ -27,7 +26,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
@ -39,7 +37,6 @@ class ProfileDisplaynameRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
try:
@ -62,7 +59,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
@ -74,7 +70,6 @@ class ProfileAvatarURLRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
try:
@ -97,7 +92,6 @@ class ProfileRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user_id = urllib.unquote(user_id)
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(

View File

@ -21,6 +21,8 @@ from synapse.api.constants import LoginType
from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from hashlib import sha1
import hmac
import json
@ -233,7 +235,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
yield
yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!

View File

@ -19,8 +19,7 @@ from twisted.internet import defer
from base import RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig
from synapse.api.events.room import RoomMemberEvent, RoomRedactionEvent
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes, Membership
import json
import logging
@ -129,9 +128,9 @@ class RoomStateEventRestServlet(RestServlet):
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
user_id=user.to_string(),
room_id=urllib.unquote(room_id),
event_type=urllib.unquote(event_type),
state_key=urllib.unquote(state_key),
room_id=room_id,
event_type=event_type,
state_key=state_key,
)
if not data:
@ -143,32 +142,23 @@ class RoomStateEventRestServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
event_type = urllib.unquote(event_type)
content = _parse_json(request)
event = self.event_factory.create_event(
etype=event_type, # already urldecoded
content=content,
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
state_key=urllib.unquote(state_key)
)
event_dict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
self.validator.validate(event)
if state_key is not None:
event_dict["state_key"] = state_key
if event_type == RoomMemberEvent.TYPE:
# membership events are special
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))
else:
# store random bits of state
msg_handler = self.handlers.message_handler
yield msg_handler.store_room_data(
event=event
)
defer.returnValue((200, {}))
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict)
defer.returnValue((200, {}))
# TODO: Needs unit testing for generic events + feedback
@ -184,17 +174,15 @@ class RoomSendEventRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
event = self.event_factory.create_event(
etype=urllib.unquote(event_type),
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
content=content
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event)
event = yield msg_handler.create_and_send_event(
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
)
defer.returnValue((200, {"event_id": event.event_id}))
@ -235,14 +223,10 @@ class JoinRoomAliasServlet(RestServlet):
identifier = None
is_room_alias = False
try:
identifier = self.hs.parse_roomalias(
urllib.unquote(room_identifier)
)
identifier = self.hs.parse_roomalias(room_identifier)
is_room_alias = True
except SynapseError:
identifier = self.hs.parse_roomid(
urllib.unquote(room_identifier)
)
identifier = self.hs.parse_roomid(room_identifier)
# TODO: Support for specifying the home server to join with?
@ -251,18 +235,17 @@ class JoinRoomAliasServlet(RestServlet):
ret_dict = yield handler.join_room_alias(user, identifier)
defer.returnValue((200, ret_dict))
else: # room id
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
content={"membership": Membership.JOIN},
room_id=urllib.unquote(identifier.to_string()),
user_id=user.to_string(),
state_key=user.to_string()
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": Membership.JOIN},
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
}
)
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))
@defer.inlineCallbacks
@ -301,7 +284,7 @@ class RoomMemberListRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=urllib.unquote(room_id),
room_id=room_id,
user_id=user.to_string())
for event in members["chunk"]:
@ -333,7 +316,7 @@ class RoomMessageListRestServlet(RestServlet):
with_feedback = "feedback" in request.args
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=urllib.unquote(room_id),
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback)
@ -351,7 +334,7 @@ class RoomStateRestServlet(RestServlet):
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=urllib.unquote(room_id),
room_id=room_id,
user_id=user.to_string(),
)
defer.returnValue((200, events))
@ -366,7 +349,7 @@ class RoomInitialSyncRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=urllib.unquote(room_id),
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
)
@ -378,8 +361,10 @@ class RoomTriggerBackfill(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
remote_server = urllib.unquote(request.args["remote"][0])
room_id = urllib.unquote(room_id)
remote_server = urllib.unquote(
request.args["remote"][0]
).decode("UTF-8")
limit = int(request.args["limit"][0])
handler = self.handlers.federation_handler
@ -414,18 +399,17 @@ class RoomMembershipRestServlet(RestServlet):
if membership_action == "kick":
membership_action = "leave"
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
content={"membership": unicode(membership_action)},
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
state_key=state_key
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": unicode(membership_action)},
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
}
)
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))
@defer.inlineCallbacks
@ -453,18 +437,16 @@ class RoomRedactEventRestServlet(RestServlet):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
event = self.event_factory.create_event(
etype=RoomRedactionEvent.TYPE,
room_id=urllib.unquote(room_id),
user_id=user.to_string(),
content=content,
redacts=urllib.unquote(event_id),
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event)
event = yield msg_handler.create_and_send_event(
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
}
)
defer.returnValue((200, {"event_id": event.event_id}))

View File

@ -20,9 +20,7 @@
# Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator
from synapse.events.utils import serialize_event
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
@ -36,6 +34,7 @@ from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
class BaseHomeServer(object):
@ -65,7 +64,6 @@ class BaseHomeServer(object):
'persistence_service',
'replication_layer',
'datastore',
'event_factory',
'handlers',
'auth',
'rest_servlet_factory',
@ -82,7 +80,7 @@ class BaseHomeServer(object):
'event_sources',
'ratelimiter',
'keyring',
'event_validator',
'event_builder_factory',
]
def __init__(self, hostname, **kwargs):
@ -134,22 +132,22 @@ class BaseHomeServer(object):
def parse_userid(self, s):
"""Parse the string given by 's' as a User ID and return a UserID
object."""
return UserID.from_string(s, hs=self)
return UserID.from_string(s)
def parse_roomalias(self, s):
"""Parse the string given by 's' as a Room Alias and return a RoomAlias
object."""
return RoomAlias.from_string(s, hs=self)
return RoomAlias.from_string(s)
def parse_roomid(self, s):
"""Parse the string given by 's' as a Room ID and return a RoomID
object."""
return RoomID.from_string(s, hs=self)
return RoomID.from_string(s)
def parse_eventid(self, s):
"""Parse the string given by 's' as a Event ID and return a EventID
object."""
return EventID.from_string(s, hs=self)
return EventID.from_string(s)
def serialize_event(self, e):
return serialize_event(self, e)
@ -166,6 +164,9 @@ class BaseHomeServer(object):
return ip_addr
def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname
# Build magic accessors for every dependency
for depname in BaseHomeServer.DEPENDENCIES:
BaseHomeServer._make_dependency_method(depname)
@ -193,9 +194,6 @@ class HomeServer(BaseHomeServer):
def build_datastore(self):
return DataStore(self)
def build_event_factory(self):
return EventFactory(self)
def build_handlers(self):
return Handlers(self)
@ -226,8 +224,11 @@ class HomeServer(BaseHomeServer):
def build_keyring(self):
return Keyring(self)
def build_event_validator(self):
return EventValidator(self)
def build_event_builder_factory(self):
return EventBuilderFactory(
clock=self.get_clock(),
hostname=self.hostname,
)
def register_servlets(self):
""" Register all servlets associated with this HomeServer.

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.events.room import RoomPowerLevelsEvent
from synapse.api.constants import EventTypes
from collections import namedtuple
@ -89,7 +89,7 @@ class StateHandler(object):
ids = [e for e, _ in event.prev_events]
ret = yield self.resolve_state_groups(ids)
state_group, new_state = ret
state_group, new_state, _ = ret
event.old_state_events = copy.deepcopy(new_state)
@ -135,9 +135,87 @@ class StateHandler(object):
defer.returnValue(res[1].values())
@defer.inlineCallbacks
def annotate_context_with_state(self, event, context, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
If `event` has `auth_events` then this will also fill out the
`auth_events` field on `context` from the `current_state`.
Args:
event (EventBase)
context (EventContext)
"""
yield run_on_reactor()
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
replaces = context.current_state[key]
if replaces.event_id != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces.event_id
defer.returnValue([])
if event.is_state():
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
event_type=event.type,
state_key=event.state_key,
)
else:
ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events],
)
group, curr_state, prev_state = ret
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = zip(*event.auth_events)[0]
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
defer.returnValue(prev_state)
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, event_ids):
def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@ -156,7 +234,14 @@ class StateHandler(object):
(e.type, e.state_key): e
for e in state_list
}
defer.returnValue((name, state))
prev_state = state.get((event_type, state_key), None)
if prev_state:
prev_state = prev_state.event_id
prev_states = [prev_state]
else:
prev_states = []
defer.returnValue((name, state, prev_states))
state = {}
for group, g_state in state_groups.items():
@ -177,6 +262,13 @@ class StateHandler(object):
if len(v.values()) > 1
}
if event_type:
prev_states = conflicted_state.get(
(event_type, state_key), {}
).keys()
else:
prev_states = []
try:
new_state = {}
new_state.update(unconflicted_state)
@ -186,11 +278,11 @@ class StateHandler(object):
logger.exception("Failed to resolve state")
raise
defer.returnValue((None, new_state))
defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events:
key = (RoomPowerLevelsEvent.TYPE, "", )
key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:

View File

@ -15,12 +15,8 @@
from twisted.internet import defer
from synapse.api.events.room import (
RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
RoomRedactionEvent,
)
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from .directory import DirectoryStore
from .feedback import FeedbackStore
@ -39,6 +35,7 @@ from .state import StateStore
from .signatures import SignatureStore
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash
@ -89,7 +86,6 @@ class DataStore(RoomMemberStore, RoomStore,
def __init__(self, hs):
super(DataStore, self).__init__(hs)
self.event_factory = hs.get_event_factory()
self.hs = hs
self.min_token_deferred = self._get_min_token()
@ -97,8 +93,8 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
def persist_event(self, event, backfilled=False, is_new_state=True,
current_state=None):
def persist_event(self, event, context, backfilled=False,
is_new_state=True, current_state=None):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
@ -111,6 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
@ -121,50 +118,66 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
def get_event(self, event_id, allow_none=False):
events_dict = yield self._simple_select_one(
"events",
{"event_id": event_id},
[
"event_id",
"type",
"room_id",
"content",
"unrecognized_keys",
"depth",
],
allow_none=allow_none,
)
events = yield self._get_events([event_id])
if not events_dict:
defer.returnValue(None)
if not events:
if allow_none:
defer.returnValue(None)
else:
raise RuntimeError("Could not find event %s" % (event_id,))
event = yield self._parse_events([events_dict])
defer.returnValue(event[0])
defer.returnValue(events[0])
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True, current_state=None):
if event.type == RoomMemberEvent.TYPE:
def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True,
current_state=None):
if event.type == EventTypes.Member:
self._store_room_member_txn(txn, event)
elif event.type == FeedbackEvent.TYPE:
elif event.type == EventTypes.Feedback:
self._store_feedback_txn(txn, event)
elif event.type == RoomNameEvent.TYPE:
elif event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == RoomTopicEvent.TYPE:
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == RoomRedactionEvent.TYPE:
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
outlier = False
if hasattr(event, "outlier"):
outlier = event.outlier
if hasattr(event.internal_metadata, "outlier"):
outlier = event.internal_metadata.outlier
event_dict = {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
)
self._simple_insert_txn(
txn,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json.decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
or_replace=True,
)
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": json.dumps(event.content),
"content": json.dumps(event.get_dict()["content"]),
"processed": True,
"outlier": outlier,
"depth": event.depth,
@ -175,7 +188,7 @@ class DataStore(RoomMemberStore, RoomStore,
unrec = {
k: v
for k, v in event.get_full_dict().items()
for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
@ -210,7 +223,8 @@ class DataStore(RoomMemberStore, RoomStore,
room_id=event.room_id,
)
self._store_state_groups_txn(txn, event)
if not outlier:
self._store_state_groups_txn(txn, event, context)
if current_state:
txn.execute(
@ -304,16 +318,6 @@ class DataStore(RoomMemberStore, RoomStore,
txn, event.event_id, hash_alg, hash_bytes,
)
if hasattr(event, "signatures"):
logger.debug("sigs: %s", event.signatures)
for name, sigs in event.signatures.items():
for key_id, signature_base64 in sigs.items():
signature_bytes = decode_base64(signature_base64)
self._store_event_signature_txn(
txn, event.event_id, name, key_id,
signature_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)

View File

@ -15,15 +15,14 @@
import logging
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from syutil.base64util import encode_base64
from twisted.internet import defer
import collections
import copy
import json
import sys
import time
@ -84,7 +83,6 @@ class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
self._db_pool = hs.get_db_pool()
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
@defer.inlineCallbacks
@ -436,42 +434,67 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
replaces_state = d.pop("prev_state", None)
if replaces_state:
d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
del d["unrecognized_keys"]
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
return self.event_factory.create_event(
etype=d["type"],
**d
def _get_events(self, event_ids):
return self.runInteraction(
"_get_events", self._get_events_txn, event_ids
)
def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched?
sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
event_rows = []
events = []
for e_id in event_ids:
c = txn.execute(sql, (e_id,))
event_rows.extend(self.cursor_to_dict(c))
ev = self._get_event_txn(txn, e_id)
return self._parse_events_txn(txn, event_rows)
if ev:
events.append(ev)
return events
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=True):
sql = (
"SELECT internal_metadata, json, r.event_id FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
txn.execute(sql, (event_id,))
res = txn.fetchone()
if not res:
return None
internal_metadata, js, redacted = res
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
if check_redacted and redacted:
ev = prune_event(ev)
ev.unsigned["redacted_by"] = redacted
# Get the redaction event.
because = self._get_event_txn(
txn,
redacted,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
ev.unsigned["prev_content"] = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
).get_dict()["content"]
return ev
def _parse_events(self, rows):
return self.runInteraction(
@ -479,80 +502,9 @@ class SQLBaseStore(object):
)
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
event_ids = [r["event_id"] for r in rows]
select_event_sql = (
"SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
ev.auth_events = self._get_auth_events(txn, ev.event_id)
if hasattr(ev, "state_key"):
ev.prev_state = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 1
]
if hasattr(ev, "replaces_state"):
# Load previous state_content.
# FIXME (erikj): Handle multiple prev_states.
cursor = txn.execute(
select_event_sql,
(ev.replaces_state,)
)
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
ev.redacted = self._has_been_redacted_txn(txn, ev)
if ev.redacted:
# Get the redaction event.
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
)
if del_evs:
ev = prune_event(ev)
events[i] = ev
ev.redacted_because = del_evs[0]
return events
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"

View File

@ -177,14 +177,15 @@ class EventFederationStore(SQLBaseStore):
retcols=["prev_event_id", "is_state"],
)
hashes = self._get_prev_event_hashes_txn(txn, event_id)
results = []
for d in res:
hashes = self._get_event_reference_hashes_txn(
txn,
d["prev_event_id"]
)
edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
edge_hash.update(hashes.get(d["prev_event_id"], {}))
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
k: encode_base64(v)
for k, v in edge_hash.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))

View File

@ -32,6 +32,19 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
CREATE TABLE IF NOT EXISTS state_events(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,

View File

@ -29,7 +29,8 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group INTEGER NOT NULL
state_group INTEGER NOT NULL,
CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);

View File

@ -13,8 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from _base import SQLBaseStore
from syutil.base64util import encode_base64
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore):
f
)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes(
event_ids
)
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
defer.returnValue(zip(event_ids, hashes))
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
Args:

View File

@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
self._store_state_groups_txn, event
)
def _store_state_groups_txn(self, txn, event):
if event.state_events is None:
def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:
return
state_group = event.state_group
state_events = context.current_state
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._simple_insert_txn(
txn,
@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
or_ignore=True,
)
for state in event.state_events.values():
for state in state_events.values():
self._simple_insert_txn(
txn,
table="state_groups_state",

View File

@ -19,7 +19,7 @@ from collections import namedtuple
class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain", "is_mine"))
namedtuple("DomainSpecificString", ("localpart", "domain"))
):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@ -28,15 +28,13 @@ class DomainSpecificString(
'localpart' : The local part of the name (without the leading sigil)
'domain' : The domain part of the name
'is_mine' : Boolean indicating if the domain name is recognised by the
HomeServer as being its own
"""
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__))
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings and booleans, it is deeply
# immutable.
@ -47,7 +45,7 @@ class DomainSpecificString(
return self
@classmethod
def from_string(cls, s, hs):
def from_string(cls, s):
"""Parse the string given by 's' into a structure object."""
if s[0] != cls.SIGIL:
raise SynapseError(400, "Expected %s string to start with '%s'" % (
@ -66,22 +64,15 @@ class DomainSpecificString(
# This code will need changing if we want to support multiple domain
# names on one HS
is_mine = domain == hs.hostname
return cls(localpart=parts[0], domain=domain, is_mine=is_mine)
return cls(localpart=parts[0], domain=domain)
def to_string(self):
"""Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
def create_local(cls, localpart, hs):
"""Create a structure on the local domain"""
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
@classmethod
def create(cls, localpart, domain, hs):
is_mine = domain == hs.hostname
return cls(localpart=localpart, domain=domain, is_mine=is_mine)
def create(cls, localpart, domain,):
return cls(localpart=localpart, domain=domain)
class UserID(DomainSpecificString):

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from frozendict import frozendict
def freeze(o):
if isinstance(o, dict) or isinstance(o, frozendict):
return frozendict({k: freeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return tuple([freeze(i) for i in o])
except TypeError:
pass
return o
def unfreeze(o):
if isinstance(o, frozendict) or isinstance(o, dict):
return dict({k: unfreeze(v) for k, v in o.items()})
if isinstance(o, basestring):
return o
try:
return [unfreeze(i) for i in o]
except TypeError:
pass
return o

View File

@ -1,217 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events import SynapseEvent
from synapse.api.events.validator import EventValidator
from synapse.api.errors import SynapseError
from tests import unittest
class SynapseTemplateCheckTestCase(unittest.TestCase):
def setUp(self):
self.validator = EventValidator(None)
def tearDown(self):
pass
def test_top_level_keys(self):
template = {
"person": {},
"friends": ["string"]
}
content = {
"person": {"name": "bob"},
"friends": ["jill", "mike"]
}
event = MockSynapseEvent(template)
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
"friends": ["jill"],
"enemies": ["mike"]
}
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
# missing friends
"enemies": ["mike", "jill"]
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
def test_lists(self):
template = {
"person": {},
"friends": [{"name":"string"}]
}
content = {
"person": {"name": "bob"},
"friends": ["jill", "mike"] # should be in objects
}
event = MockSynapseEvent(template)
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {"name": "bob"},
"friends": [{"name": "jill"}, {"name": "mike"}]
}
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_lists(self):
template = {
"results": {
"families": [
{
"name": "string",
"members": [
{}
]
}
]
}
}
content = {
"results": {
"families": [
{
"name": "Smith",
"members": [
"Alice", "Bob" # wrong types
]
}
]
}
}
event = MockSynapseEvent(template)
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"results": {
"families": [
{
"name": "Smith",
"members": [
{"name": "Alice"}, {"name": "Bob"}
]
}
]
}
}
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_keys(self):
template = {
"person": {
"attributes": {
"hair": "string",
"eye": "string"
},
"age": 0,
"fav_books": ["string"]
}
}
event = MockSynapseEvent(template)
content = {
"person": {
"attributes": {
"hair": "brown",
"eye": "green",
"skin": "purple"
},
"age": 33,
"fav_books": ["lotr", "hobbit"],
"fav_music": ["abba", "beatles"]
}
}
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {
"attributes": {
"hair": "brown"
# missing eye
},
"age": 33,
"fav_books": ["lotr", "hobbit"],
"fav_music": ["abba", "beatles"]
}
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {
"attributes": {
"hair": "brown",
"eye": "green",
"skin": "purple"
},
"age": 33,
"fav_books": "nothing", # should be a list
}
}
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
class MockSynapseEvent(SynapseEvent):
def __init__(self, template):
self.template = template
def get_content_template(self):
return self.template

View File

@ -23,25 +23,20 @@ from ..utils import MockHttpResource, MockClock, MockKey
from synapse.server import HomeServer
from synapse.federation import initialize_http_replication
from synapse.api.events import SynapseEvent
from synapse.events import FrozenEvent
from synapse.storage.transactions import DestinationsTable
def make_pdu(prev_pdus=[], **kwargs):
"""Provide some default fields for making a PduTuple."""
pdu_fields = {
"is_state": False,
"unrecognized_keys": [],
"outlier": False,
"have_processed": True,
"state_key": None,
"power_level": None,
"prev_state_id": None,
"prev_state_origin": None,
"prev_events": prev_pdus,
}
pdu_fields.update(kwargs)
return SynapseEvent(prev_pdus=prev_pdus, **pdu_fields)
return FrozenEvent(pdu_fields)
class FederationTestCase(unittest.TestCase):
@ -176,7 +171,7 @@ class FederationTestCase(unittest.TestCase):
(200, "OK")
)
pdu = SynapseEvent(
pdu = make_pdu(
event_id="abc123def456",
origin="red",
user_id="@a:red",
@ -185,10 +180,9 @@ class FederationTestCase(unittest.TestCase):
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
destinations=["remote"],
)
yield self.federation.send_pdu(pdu)
yield self.federation.send_pdu(pdu, ["remote"])
self.mock_http_client.put_json.assert_called_with(
"remote",

View File

@ -16,11 +16,8 @@
from twisted.internet import defer
from tests import unittest
from synapse.api.events.room import (
MessageEvent,
)
from synapse.api.events import SynapseEvent
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.handlers.federation import FederationHandler
from synapse.server import HomeServer
@ -37,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
self.mock_config.signing_key = [MockKey()]
self.state_handler = NonCallableMock(spec_set=[
"annotate_event_with_state",
"annotate_context_with_state",
])
self.auth = NonCallableMock(spec_set=[
@ -78,36 +75,42 @@ class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_msg(self):
pdu = SynapseEvent(
type=MessageEvent.TYPE,
room_id="foo",
content={"msgtype": u"fooo"},
origin_server_ts=0,
event_id="$a:b",
user_id="@a:b",
origin="b",
auth_events=[],
hashes={"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
)
pdu = FrozenEvent({
"type": EventTypes.Message,
"room_id": "foo",
"content": {"msgtype": u"fooo"},
"origin_server_ts": 0,
"event_id": "$a:b",
"user_id":"@a:b",
"origin": "b",
"auth_events": [],
"hashes": {"sha256":"AcLrgtUIqqwaGoHhrEvYG1YLDIsVPYJdSRGhkp3jJp8"},
})
self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True)
def annotate(ev, old_state=None):
ev.old_state_events = []
def annotate(ev, context, old_state=None):
context.current_state = {}
context.auth_events = {}
return defer.succeed(False)
self.state_handler.annotate_event_with_state.side_effect = annotate
self.state_handler.annotate_context_with_state.side_effect = annotate
yield self.handlers.federation_handler.on_receive_pdu(
"fo", pdu, False
)
self.datastore.persist_event.assert_called_once_with(
ANY, is_new_state=False, backfilled=False, current_state=None
ANY,
is_new_state=True,
backfilled=False,
current_state=None,
context=ANY,
)
self.state_handler.annotate_event_with_state.assert_called_once_with(
self.state_handler.annotate_context_with_state.assert_called_once_with(
ANY,
ANY,
old_state=None,
)

View File

@ -17,10 +17,7 @@
from twisted.internet import defer
from tests import unittest
from synapse.api.events.room import (
RoomMemberEvent,
)
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes, Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
from synapse.handlers.profile import ProfileHandler
from synapse.server import HomeServer
@ -47,7 +44,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room_member",
"get_room",
"store_room",
"snapshot_room",
"get_latest_events_in_room",
]),
resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]),
@ -63,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"check_host_in_room",
]),
state_handler=NonCallableMock(spec_set=[
"annotate_event_with_state",
"annotate_context_with_state",
"get_current_state",
]),
config=self.mock_config,
@ -91,9 +88,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.handlers.profile_handler = ProfileHandler(self.hs)
self.room_member_handler = self.handlers.room_member_handler
self.snapshot = Mock()
self.datastore.snapshot_room.return_value = self.snapshot
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -104,50 +98,68 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
target_user_id = "@red:blue"
content = {"membership": Membership.INVITE}
event = self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
user_id=user_id,
state_key=target_user_id,
room_id=room_id,
membership=Membership.INVITE,
content=content,
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": target_user_id,
"room_id": room_id,
"content": content,
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.auth.check_host_in_room.return_value = defer.succeed(True)
def annotate(_, ctx):
ctx.current_state = {
(EventTypes.Member, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
return defer.succeed(True)
self.datastore.get_room_member.return_value = defer.succeed(None)
self.state_handler.annotate_context_with_state.side_effect = annotate
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
}
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, target_user_id)] = event
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
# Actual invocation
yield self.room_member_handler.change_membership(event)
def send_invite(domain, event):
return defer.succeed(event)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot,
self.federation.send_invite.side_effect = send_invite
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
self.assertEquals(
set(["red", "green"]),
set(event.destinations)
yield room_handler.change_membership(event, context)
self.state_handler.annotate_context_with_state.assert_called_once_with(
builder, context
)
self.auth.add_auth_events.assert_called_once_with(
builder, context
)
self.federation.send_invite.assert_called_once_with(
"blue", event,
)
self.datastore.persist_event.assert_called_once_with(
event
event, context=context,
)
self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[self.hs.parse_userid(target_user_id)]
@ -162,57 +174,56 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user_id = "@bob:red"
user = self.hs.parse_userid(user_id)
event = self._create_member(
user_id=user_id,
room_id=room_id,
)
self.auth.check_host_in_room.return_value = defer.succeed(True)
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
prev_state = NonCallableMock()
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:red"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer)
event.state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(RoomMemberEvent.TYPE, user_id): event,
}
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.JOIN},
})
event.old_state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
}
event.state_events = event.old_state_events
event.state_events[(RoomMemberEvent.TYPE, user_id)] = event
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
self.assertEquals(
set(["red", "green"]),
set(event.destinations)
def annotate(_, ctx):
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.INVITE
),
}
return defer.succeed(True)
self.state_handler.annotate_context_with_state.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
# Actual invocation
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, None, destinations=set()
)
self.datastore.persist_event.assert_called_once_with(
event
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user]
@ -222,54 +233,82 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
user=user, room_id=room_id
)
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": membership},
})
return builder.build()
@defer.inlineCallbacks
def test_simple_leave(self):
room_id = "!foo:red"
user_id = "@bob:red"
user = self.hs.parse_userid(user_id)
event = self._create_member(
user_id=user_id,
room_id=room_id,
membership=Membership.LEAVE,
builder = self.hs.get_event_builder_factory().new({
"type": EventTypes.Member,
"sender": user_id,
"state_key": user_id,
"room_id": room_id,
"content": {"membership": Membership.LEAVE},
})
self.datastore.get_latest_events_in_room.return_value = (
defer.succeed([])
)
prev_state = NonCallableMock()
prev_state.membership = Membership.JOIN
prev_state.sender = user_id
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
def annotate(_, ctx):
ctx.current_state = {
(EventTypes.Member, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
membership=Membership.JOIN
),
}
event.state_events = {
(RoomMemberEvent.TYPE, user_id): event,
}
return defer.succeed(True)
event.old_state_events = {
(RoomMemberEvent.TYPE, user_id): self._create_member(
user_id=user_id,
room_id=room_id,
),
}
self.state_handler.annotate_context_with_state.side_effect = annotate
def add_auth(_, ctx):
ctx.auth_events = ctx.current_state[
(EventTypes.Member, "@bob:red")
]
return defer.succeed(True)
self.auth.add_auth_events.side_effect = add_auth
room_handler = self.room_member_handler
event, context = yield room_handler._create_new_client_event(
builder
)
leave_signal_observer = Mock()
self.distributor.observe("user_left_room", leave_signal_observer)
# Actual invocation
yield self.room_member_handler.change_membership(event)
yield room_handler.change_membership(event, context)
self.federation.handle_new_event.assert_called_once_with(
event, None, destinations=set(['red'])
)
self.datastore.persist_event.assert_called_once_with(
event, context=context
)
self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user]
)
leave_signal_observer.assert_called_with(
user=user, room_id=room_id
)
def _create_member(self, user_id, room_id, membership=Membership.JOIN):
return self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
user_id=user_id,
state_key=user_id,
room_id=room_id,
membership=membership,
content={"membership": membership},
)
class RoomCreationTest(unittest.TestCase):
@ -292,13 +331,9 @@ class RoomCreationTest(unittest.TestCase):
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
handlers=NonCallableMock(spec_set=[
"room_creation_handler",
"room_member_handler",
"federation_handler",
"message_handler",
]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
state_handler=NonCallableMock(spec_set=[
"annotate_event_with_state",
]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
@ -309,30 +344,12 @@ class RoomCreationTest(unittest.TestCase):
"handle_new_event",
])
self.datastore = hs.get_datastore()
self.handlers = hs.get_handlers()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.hs = hs
self.handlers.federation_handler = self.federation
self.handlers.room_creation_handler = RoomCreationHandler(self.hs)
self.handlers.room_creation_handler = RoomCreationHandler(hs)
self.room_creation_handler = self.handlers.room_creation_handler
self.handlers.room_member_handler = NonCallableMock(spec_set=[
"change_membership"
])
self.room_member_handler = self.handlers.room_member_handler
def annotate(event):
event.state_events = {}
return defer.succeed(None)
self.state_handler.annotate_event_with_state.side_effect = annotate
def hosts(room):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
self.message_handler = self.handlers.message_handler
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -349,14 +366,37 @@ class RoomCreationTest(unittest.TestCase):
config=config,
)
self.assertTrue(self.room_member_handler.change_membership.called)
join_event = self.room_member_handler.change_membership.call_args[0][0]
self.assertTrue(self.message_handler.create_and_send_event.called)
self.assertEquals(RoomMemberEvent.TYPE, join_event.type)
self.assertEquals(room_id, join_event.room_id)
self.assertEquals(user_id, join_event.user_id)
self.assertEquals(user_id, join_event.state_key)
event_dicts = [
e[0][0]
for e in self.message_handler.create_and_send_event.call_args_list
]
self.assertTrue(self.state_handler.annotate_event_with_state.called)
self.assertTrue(len(event_dicts) > 3)
self.assertTrue(self.federation.handle_new_event.called)
self.assertDictContainsSubset(
{
"type": EventTypes.Create,
"sender": user_id,
"room_id": room_id,
},
event_dicts[0]
)
self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
self.assertDictContainsSubset(
{
"type": EventTypes.Member,
"sender": user_id,
"room_id": room_id,
"state_key": user_id,
},
event_dicts[1]
)
self.assertEqual(
Membership.JOIN,
event_dicts[1]["content"]["membership"]
)

View File

@ -137,7 +137,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
if ignore_user is not None and member == ignore_user:
continue
if member.is_mine:
if hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:

View File

@ -113,9 +113,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
@ -127,7 +124,6 @@ class EventStreamPermissionsTestCase(RestTestCase):
db_pool=db_pool,
http_client=None,
replication_layer=Mock(),
persistence_service=persistence_service,
clock=Mock(spec=[
"call_later",
"cancel_call_later",

View File

@ -503,7 +503,7 @@ class RoomsMemberListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_get_member_list_mixed_memberships(self):
room_creator = "@some_other_guy:blue"
room_creator = "@some_other_guy:red"
room_id = yield self.create_room_as(room_creator)
room_path = "/rooms/%s/members" % room_id
yield self.invite(room=room_id, src=room_creator,

View File

@ -18,12 +18,11 @@ from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import (
RoomMemberEvent, MessageEvent, RoomRedactionEvent,
)
from synapse.api.constants import EventTypes, Membership
from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class RedactionTestCase(unittest.TestCase):
@ -33,13 +32,21 @@ class RedactionTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer(
"test",
db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
)
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
@ -49,35 +56,23 @@ class RedactionTestCase(unittest.TestCase):
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, prev_state=None,
def inject_room_member(self, room, user, membership, replaces_state=None,
extra_content={}):
self.depth += 1
content = {"membership": membership}
content.update(extra_content)
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": content,
})
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.content.update(extra_content)
if prev_state:
event.prev_state = prev_state
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@ -85,46 +80,38 @@ class RedactionTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1
event = self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.state_events = None
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason):
event = self.event_factory.create_event(
etype=RoomRedactionEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"reason": reason},
depth=self.depth,
redacts=event_id,
prev_events=[],
builder = self.event_builder_factory.new({
"type": EventTypes.Redaction,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"reason": reason},
"redacts": event_id,
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.state_events = None
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)
defer.returnValue(event)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def test_redact(self):
@ -152,14 +139,14 @@ class RedactionTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {"body": "t", "msgtype": "message"},
},
event,
)
self.assertFalse(hasattr(event, "redacted_because"))
self.assertFalse("redacted_because" in event.unsigned)
# Redact event
reason = "Because I said so"
@ -180,24 +167,26 @@ class RedactionTestCase(unittest.TestCase):
event = results[0]
self.assertEqual(msg_event.event_id, event.event_id)
self.assertTrue("redacted_because" in event.unsigned)
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {},
},
event,
)
self.assertTrue(hasattr(event, "redacted_because"))
self.assertObjectHasAttributes(
{
"type": RoomRedactionEvent.TYPE,
"type": EventTypes.Redaction,
"user_id": self.u_alice.to_string(),
"content": {"reason": reason},
},
event.redacted_because,
event.unsigned["redacted_because"],
)
@defer.inlineCallbacks
@ -229,7 +218,7 @@ class RedactionTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{
"type": RoomMemberEvent.TYPE,
"type": EventTypes.Member,
"user_id": self.u_bob.to_string(),
"content": {"membership": Membership.JOIN, "blue": "red"},
},
@ -257,22 +246,22 @@ class RedactionTestCase(unittest.TestCase):
event = results[0]
self.assertTrue("redacted_because" in event.unsigned)
self.assertObjectHasAttributes(
{
"type": RoomMemberEvent.TYPE,
"type": EventTypes.Member,
"user_id": self.u_bob.to_string(),
"content": {"membership": Membership.JOIN},
},
event,
)
self.assertTrue(hasattr(event, "redacted_because"))
self.assertObjectHasAttributes(
{
"type": RoomRedactionEvent.TYPE,
"type": EventTypes.Redaction,
"user_id": self.u_alice.to_string(),
"content": {"reason": reason},
},
event.redacted_because,
event.unsigned["redacted_because"],
)

View File

@ -18,9 +18,7 @@ from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.events.room import (
RoomNameEvent, RoomTopicEvent
)
from synapse.api.constants import EventTypes
from tests.utils import SQLiteMemoryDbPool
@ -131,7 +129,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
name = u"A-Room-Name"
yield self.inject_room_event(
etype=RoomNameEvent.TYPE,
etype=EventTypes.Name,
name=name,
content={"name": name},
depth=1,
@ -154,7 +152,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
topic = u"A place for things"
yield self.inject_room_event(
etype=RoomTopicEvent.TYPE,
etype=EventTypes.Topic,
topic=topic,
content={"topic": topic},
depth=1,

View File

@ -18,10 +18,11 @@ from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import EventTypes, Membership
from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class RoomMemberStoreTestCase(unittest.TestCase):
@ -31,14 +32,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
db_pool=db_pool,
)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer(
"test",
db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
)
# We can't test the RoomMemberStore on its own without the other event
# storage logic
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
@ -49,27 +58,22 @@ class RoomMemberStoreTestCase(unittest.TestCase):
self.room = hs.parse_roomid("!abc123:test")
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
# Have to create a join event using the eventfactory
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
prev_events=[],
def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.state_events = None
event.hashes = {}
event.prev_state = {}
event.auth_events = {}
yield self.store.persist_event(event, context)
yield self.store.persist_event(
event
)
defer.returnValue(event)
@defer.inlineCallbacks
def test_one_member(self):

View File

@ -18,10 +18,11 @@ from tests import unittest
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent, MessageEvent
from synapse.api.constants import EventTypes, Membership
from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey
from mock import Mock
class StreamStoreTestCase(unittest.TestCase):
@ -31,13 +32,21 @@ class StreamStoreTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer(
"test",
db_pool=db_pool,
config=self.mock_config,
resource_for_federation=Mock(),
http_client=None,
)
self.store = hs.get_datastore()
self.event_factory = hs.get_event_factory()
self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers()
self.message_handler = self.handlers.message_handler
self.u_alice = hs.parse_userid("@alice:test")
self.u_bob = hs.parse_userid("@bob:test")
@ -48,33 +57,22 @@ class StreamStoreTestCase(unittest.TestCase):
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None):
def inject_room_member(self, room, user, membership):
self.depth += 1
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"membership": membership},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []
if replaces_state:
event.prev_state = [(replaces_state, "hash")]
event.replaces_state = replaces_state
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
yield self.store.persist_event(event, context)
defer.returnValue(event)
@ -82,23 +80,19 @@ class StreamStoreTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1
event = self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
builder = self.event_builder_factory.new({
"type": EventTypes.Message,
"sender": user.to_string(),
"state_key": user.to_string(),
"room_id": room.to_string(),
"content": {"body": body, "msgtype": u"message"},
})
event, context = yield self.message_handler._create_new_client_event(
builder
)
event.state_events = None
event.hashes = {}
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
)
yield self.store.persist_event(event, context)
@defer.inlineCallbacks
def test_event_stream_get_other(self):
@ -130,7 +124,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
@ -167,7 +161,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.assertObjectHasAttributes(
{
"type": MessageEvent.TYPE,
"type": EventTypes.Message,
"user_id": self.u_alice.to_string(),
"content": {"body": "test", "msgtype": "message"},
},
@ -220,7 +214,6 @@ class StreamStoreTestCase(unittest.TestCase):
event2 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
replaces_state=event1.event_id,
)
end = yield self.store.get_room_events_max_id()
@ -238,6 +231,6 @@ class StreamStoreTestCase(unittest.TestCase):
event = results[0]
self.assertTrue(
hasattr(event, "prev_content"),
"prev_content" in event.unsigned,
msg="No prev_content key"
)

View File

@ -23,21 +23,21 @@ mock_homeserver = BaseHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
def test_parse(self):
user = UserID.from_string("@1234abcd:my.domain", hs=mock_homeserver)
user = UserID.from_string("@1234abcd:my.domain")
self.assertEquals("1234abcd", user.localpart)
self.assertEquals("my.domain", user.domain)
self.assertEquals(True, user.is_mine)
self.assertEquals(True, mock_homeserver.is_mine(user))
def test_build(self):
user = UserID("5678efgh", "my.domain", True)
user = UserID("5678efgh", "my.domain")
self.assertEquals(user.to_string(), "@5678efgh:my.domain")
def test_compare(self):
userA = UserID.from_string("@userA:my.domain", hs=mock_homeserver)
userAagain = UserID.from_string("@userA:my.domain", hs=mock_homeserver)
userB = UserID.from_string("@userB:my.domain", hs=mock_homeserver)
userA = UserID.from_string("@userA:my.domain")
userAagain = UserID.from_string("@userA:my.domain")
userB = UserID.from_string("@userB:my.domain")
self.assertTrue(userA == userAagain)
self.assertTrue(userA != userB)
@ -52,14 +52,14 @@ class UserIDTestCase(unittest.TestCase):
class RoomAliasTestCase(unittest.TestCase):
def test_parse(self):
room = RoomAlias.from_string("#channel:my.domain", hs=mock_homeserver)
room = RoomAlias.from_string("#channel:my.domain")
self.assertEquals("channel", room.localpart)
self.assertEquals("my.domain", room.domain)
self.assertEquals(True, room.is_mine)
self.assertEquals(True, mock_homeserver.is_mine(room))
def test_build(self):
room = RoomAlias("channel", "my.domain", True)
room = RoomAlias("channel", "my.domain")
self.assertEquals(room.to_string(), "#channel:my.domain")

View File

@ -15,21 +15,17 @@
from synapse.http.server import HttpServer
from synapse.api.errors import cs_error, CodeMessageException, StoreError
from synapse.api.constants import Membership
from synapse.api.constants import EventTypes
from synapse.storage import prepare_database
from synapse.util.logcontext import LoggingContext
from synapse.api.events.room import (
RoomMemberEvent, MessageEvent
)
from twisted.internet import defer, reactor
from twisted.enterprise.adbapi import ConnectionPool
from collections import namedtuple
from mock import patch, Mock
import json
import urllib
import urlparse
from inspect import getcallargs
@ -103,9 +99,14 @@ class MockHttpResource(HttpServer):
matcher = pattern.match(path)
if matcher:
try:
args = [
urllib.unquote(u).decode("UTF-8")
for u in matcher.groups()
]
(code, response) = yield func(
mock_request,
*matcher.groups()
*args
)
defer.returnValue((code, response))
except CodeMessageException as e:
@ -271,7 +272,7 @@ class MemoryDataStore(object):
return defer.succeed([])
def persist_event(self, event):
if event.type == RoomMemberEvent.TYPE:
if event.type == EventTypes.Member:
room_id = event.room_id
user = event.state_key
membership = event.membership